Commit 33ee7168 authored by Daniel Hiltgen, committed by GitHub

Add experimental MLX backend and engine with imagegen support (#13648)



* WIP - MLX backend with gemma3

* MLX: add cmake and go tag build toggles

To build the new MLX backend code:
  cmake --preset MLX
  cmake --build --preset MLX --parallel
  cmake --install build --component MLX
  go build -tags mlx .

Note: the main.go entrypoint for the MLX engine will change in a follow-up commit.

* add experimental image generation runtime

* add experimental image generation runtime

* MLX: wire up cuda build for linux

* MLX: get dependencies correct and dedup

This is still too large for a unified GitHub artifact, but is now "correct" for the mlx_cuda_v13
directory.

* fix relative link bug in dedup

* Add darwin build and readme

* add go build tag for mlx dependent code and wire up build_darwin.sh

* lint cleanup

* macos: build mlx for x86

This will be CPU only.

* cuda build instructions and fix drift from mlx bump

* stale comment

* Delete agent helper doc

* Clean up readme.md

* Revise README for tokenizer clarity and details

Updated README to clarify tokenizer functionality and removed correctness section.

---------
Co-authored-by: jmorganca <jmorganca@gmail.com>
Parent: 34d0c55e
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
type Cache interface {
Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array)
Offset() int
Len() int
State() []*mlx.Array
}
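// Example use of the Cache interface from an attention layer (a minimal
// sketch; the k/v projections and their shapes are illustrative, not part of
// this package):
//
// // k, v: [B, H, seqLen, D] projections for the new tokens
// keys, values := layerCache.Update(k, v, seqLen)
// // keys/values now cover positions [0, Offset()) and can be passed
// // straight to scaled dot-product attention.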
type KVCache struct {
keys, values *mlx.Array
offset int
step int
}
func NewKVCache() *KVCache {
return &KVCache{step: 256}
}
func (c *KVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
prev := c.offset
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if needed
if c.keys == nil || (prev+seqLen) > int(c.keys.Shape()[2]) {
nSteps := (c.step + seqLen - 1) / c.step
newK := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(nSteps * c.step), Dv}, v.Dtype())
if c.keys != nil {
if prev%c.step != 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(prev), Dv})
}
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
c.offset += seqLen
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(prev), 0}, []int32{B, H, int32(c.offset), Dv})
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, int32(c.offset), Dv})
}
func (c *KVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
offset int
maxSize int
step int
idx int
}
func NewRotatingKVCache(maxSize int) *RotatingKVCache {
return &RotatingKVCache{maxSize: maxSize, step: 256}
}
func (c *RotatingKVCache) Update(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
if seqLen > 1 {
return c.updateConcat(k, v, seqLen)
}
return c.updateInPlace(k, v)
}
func (c *RotatingKVCache) updateInPlace(k, v *mlx.Array) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
// Grow buffer if not yet at max
if c.keys == nil || (c.idx >= int(c.keys.Shape()[2]) && int(c.keys.Shape()[2]) < c.maxSize) {
var cap int
if c.keys != nil {
cap = int(c.keys.Shape()[2])
}
newSize := min(c.step, c.maxSize-cap)
newK := mlx.Zeros([]int32{B, H, int32(newSize), Dk}, k.Dtype())
newV := mlx.Zeros([]int32{B, H, int32(newSize), Dv}, v.Dtype())
if c.keys != nil {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, newK}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, newV}, 2)
} else {
c.keys, c.values = newK, newV
}
}
// Rotate when hitting max
if c.idx >= c.maxSize {
c.idx = 0
}
c.keys = mlx.SliceUpdateInplace(c.keys, k, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dk})
c.values = mlx.SliceUpdateInplace(c.values, v, []int32{0, 0, int32(c.idx), 0}, []int32{B, H, int32(c.idx + 1), Dv})
c.offset++
c.idx++
validLen := int32(min(c.offset, c.maxSize))
return mlx.Slice(c.keys, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dk}),
mlx.Slice(c.values, []int32{0, 0, 0, 0}, []int32{B, H, validLen, Dv})
}
func (c *RotatingKVCache) updateConcat(k, v *mlx.Array, seqLen int) (*mlx.Array, *mlx.Array) {
shape := k.Shape()
B, H, Dk := shape[0], shape[1], shape[3]
Dv := v.Shape()[3]
if c.keys == nil {
c.keys, c.values = k, v
} else {
c.keys = mlx.Concatenate([]*mlx.Array{c.keys, k}, 2)
c.values = mlx.Concatenate([]*mlx.Array{c.values, v}, 2)
}
c.offset += seqLen
// Trim to max_size to maintain sliding window
cap := int(c.keys.Shape()[2])
if trim := cap - c.maxSize; trim > 0 {
c.keys = mlx.Slice(c.keys, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dk})
c.values = mlx.Slice(c.values, []int32{0, 0, int32(trim), 0}, []int32{B, H, int32(cap), Dv})
}
c.idx = int(c.keys.Shape()[2])
return c.keys, c.values
}
func (c *RotatingKVCache) State() []*mlx.Array {
if c.keys == nil {
return nil
}
return []*mlx.Array{c.keys, c.values}
}
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
//go:build mlx
package cache
import "github.com/ollama/ollama/x/imagegen/mlx"
// StepCache caches layer outputs across diffusion denoising steps.
// Based on DeepCache (CVPR 2024) and Learning-to-Cache (NeurIPS 2024):
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
// Usage (single-stream):
//
// cache := NewStepCache(15) // cache first 15 layers
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3) // refresh every 3 steps
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// output = cache.Get(i) // reuse cached
// } else {
// output = layer.Forward(input)
// if i < 15 && refresh {
// cache.Set(i, output)
// }
// }
// }
// }
// cache.Free() // cleanup when done
//
// Usage (dual-stream):
//
// cache := NewStepCache(15)
// for step := 0; step < numSteps; step++ {
// refresh := cache.ShouldRefresh(step, 3)
// for i, layer := range layers {
// if i < 15 && !refresh && cache.Get(i) != nil {
// imgH, txtH = cache.Get(i), cache.Get2(i)
// } else {
// imgH, txtH = layer.Forward(imgH, txtH, ...)
// if i < 15 && refresh {
// cache.Set(i, imgH)
// cache.Set2(i, txtH)
// }
// }
// }
// }
type StepCache struct {
layers []*mlx.Array // cached layer outputs (stream 1)
layers2 []*mlx.Array // cached layer outputs (stream 2, for dual-stream models)
constant *mlx.Array // optional constant (e.g., text embeddings)
}
// NewStepCache creates a cache for the given number of layers.
func NewStepCache(numLayers int) *StepCache {
return &StepCache{
layers: make([]*mlx.Array, numLayers),
layers2: make([]*mlx.Array, numLayers),
}
}
// ShouldRefresh returns true if the cache should be refreshed at this step.
// Refresh happens on step 0, interval, 2*interval, etc.
func (c *StepCache) ShouldRefresh(step, interval int) bool {
return step%interval == 0
}
// Get returns the cached output for a layer, or nil if not cached.
func (c *StepCache) Get(layer int) *mlx.Array {
if layer < len(c.layers) {
return c.layers[layer]
}
return nil
}
// Set stores a layer output (stream 1), freeing any previous value.
func (c *StepCache) Set(layer int, arr *mlx.Array) {
if layer < len(c.layers) {
if c.layers[layer] != nil {
c.layers[layer].Free()
}
c.layers[layer] = arr
}
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
}
return nil
}
// Set2 stores a layer output (stream 2), freeing any previous value.
// Used for dual-stream architectures like Qwen-Image.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
c.layers2[layer].Free()
}
c.layers2[layer] = arr
}
}
// GetConstant returns the cached constant value.
func (c *StepCache) GetConstant() *mlx.Array {
return c.constant
}
// SetConstant stores a constant value, freeing any previous value.
func (c *StepCache) SetConstant(arr *mlx.Array) {
if c.constant != nil {
c.constant.Free()
}
c.constant = arr
}
// Arrays returns all non-nil cached arrays (for pool.Keep).
func (c *StepCache) Arrays() []*mlx.Array {
var result []*mlx.Array
if c.constant != nil {
result = append(result, c.constant)
}
for _, arr := range c.layers {
if arr != nil {
result = append(result, arr)
}
}
for _, arr := range c.layers2 {
if arr != nil {
result = append(result, arr)
}
}
return result
}
// Free releases all cached arrays. Call when generation completes.
func (c *StepCache) Free() {
if c.constant != nil {
c.constant.Free()
c.constant = nil
}
for i, arr := range c.layers {
if arr != nil {
arr.Free()
c.layers[i] = nil
}
}
for i, arr := range c.layers2 {
if arr != nil {
arr.Free()
c.layers2[i] = nil
}
}
}
// NumLayers returns the number of layers this cache can store.
func (c *StepCache) NumLayers() int {
return len(c.layers)
}
//go:build mlx
package main
import (
"context"
"fmt"
"time"
"unicode/utf8"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Dedicated stream for generation (like mlx-lm's generation_stream)
var generationStream *mlx.Stream
// utf8Streamer buffers decoded text and emits only complete UTF-8 characters.
// This handles cases where tokenizers output partial multi-byte sequences.
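//
// Example (a minimal sketch of the intended behavior):
//
// s := &utf8Streamer{}
// s.Write("\xe4\xb8") // "" - incomplete 3-byte sequence stays buffered
// s.Write("\xad")     // "中" - the sequence is now complete
// s.Flush()           // returns any trailing partial bytes as-is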
type utf8Streamer struct {
buffer []byte
}
// Write adds decoded text to the buffer and returns complete UTF-8 characters.
func (s *utf8Streamer) Write(text string) string {
s.buffer = append(s.buffer, text...)
// Find the last position that ends with a complete UTF-8 character
validLen := 0
for i := 0; i < len(s.buffer); {
r, size := utf8.DecodeRune(s.buffer[i:])
if r == utf8.RuneError && size == 1 {
// Invalid or incomplete UTF-8 sequence at this position
// Check if it could be a valid start of a multi-byte sequence
if len(s.buffer)-i < 4 {
// Might be incomplete, keep it in buffer
break
}
// Definitely invalid, skip this byte
i++
validLen = i
} else {
i += size
validLen = i
}
}
if validLen == 0 {
return ""
}
result := string(s.buffer[:validLen])
s.buffer = s.buffer[validLen:]
return result
}
// Flush returns any remaining buffered bytes (may be incomplete UTF-8).
func (s *utf8Streamer) Flush() string {
if len(s.buffer) == 0 {
return ""
}
result := string(s.buffer)
s.buffer = nil
return result
}
func init() {
generationStream = mlx.NewStream()
}
// withStream runs fn with the generation stream as default
func withStream(fn func()) {
orig := mlx.GetDefaultStream()
mlx.SetDefaultStream(generationStream)
fn()
mlx.SetDefaultStream(orig)
}
type Model interface {
Tokenizer() *tokenizer.Tokenizer
VocabSize() int32
NewCache(maxSeqLen int32) []cache.Cache
Forward(input *mlx.Array, caches []cache.Cache) *mlx.Array
}
// ChatModel is an optional interface for models that support chat formatting
type ChatModel interface {
FormatPrompt(prompt string) string
}
// MultimodalModel is for models that support image input
type MultimodalModel interface {
Model
FormatPromptWithImage(prompt string) string
ExpandImageTokens(tokens []int32) []int32
ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array
ImageSize() int32 // Returns expected image size for preprocessing
}
// ImageLoader loads and preprocesses an image for multimodal models
// Returns nil if path is empty
type ImageLoader func(path string, imageSize int32) (*mlx.Array, error)
type input struct {
Prompt string
Image *mlx.Array // Optional preprocessed image for multimodal models
MaxTokens int
Temperature float32
TopP float32
TopK int
WiredLimitGB int // Metal wired memory limit in GB (default 32)
}
type output struct {
Text string
Done bool
PrefillTokSec float64
GenTokSec float64
}
// Decoder wraps model + cache for autoregressive generation.
type Decoder struct {
model Model
caches []cache.Cache
vocabSize int32
temp float32
topK int
topP float32
token *mlx.Array // Current token (kept across pools)
oldCacheState []*mlx.Array // Preallocated slice for old cache state
image *mlx.Array // Optional image for multimodal prefill
}
func NewDecoder(m Model, temp float32, topK int, topP float32) *Decoder {
caches := m.NewCache(0)
return &Decoder{
model: m,
caches: caches,
vocabSize: m.VocabSize(),
temp: temp,
topK: topK,
topP: topP,
oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
}
}
// SetImage sets the image for multimodal prefill (call before prefill)
func (d *Decoder) SetImage(img *mlx.Array) {
d.image = img
}
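// prefill runs the prompt through the model, filling the KV caches. Text-only
// prompts are processed in chunks of up to 2048 tokens; multimodal prompts are
// processed in a single forward pass so the image embeddings can be inserted.
// The first token is sampled from the final forward pass and stored in d.token.
// Returns the number of prompt tokens processed.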
func (d *Decoder) prefill(inputIDs []int32) int {
processed := 0
// Track old cache state to free after each chunk
var oldCacheState []*mlx.Array
// For multimodal models with an image, we need to process all tokens together
// in the first forward pass so the image embeddings can be inserted properly.
// Skip chunking for multimodal prefill.
isMultimodal := d.image != nil
// Process all-but-1 tokens in chunks, eval cache state for memory management
// Skip chunking for multimodal - process everything in the final step
if !isMultimodal {
for len(inputIDs) > 1 {
chunkSize := min(2048, len(inputIDs)-1)
if chunkSize <= 0 {
break
}
chunk := inputIDs[:chunkSize]
// Save old cache state before forward
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
var cacheState []*mlx.Array
withStream(func() {
x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
d.model.Forward(x, d.caches)
for _, c := range d.caches {
cacheState = append(cacheState, c.State()...)
}
})
mlx.Eval(cacheState...)
// Free old cache state
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
inputIDs = inputIDs[chunkSize:]
processed += chunkSize
}
}
// Save old cache state before final step
oldCacheState = oldCacheState[:0]
for _, c := range d.caches {
oldCacheState = append(oldCacheState, c.State()...)
}
// Final token + sampling (or all tokens for multimodal)
withStream(func() {
x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
mlx.Eval(x) // Materialize before any other evals
var logits *mlx.Array
// Use ForwardWithImage if we have an image and model supports it
if d.image != nil {
if mm, ok := d.model.(MultimodalModel); ok {
logits = mm.ForwardWithImage(x, d.image, d.caches)
d.image = nil // Only use image for first forward
} else {
logits = d.model.Forward(x, d.caches)
}
} else {
logits = d.model.Forward(x, d.caches)
}
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep cache state (token auto-kept by AsyncEval)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Free old cache state from before final step
for _, arr := range oldCacheState {
if arr != nil {
arr.Free()
}
}
mlx.ClearCache()
return processed + len(inputIDs)
}
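// step performs one decode iteration: it forwards the previously sampled
// token, asynchronously samples the next one, and returns the previous token's
// value while the GPU keeps working on the next step.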
func (d *Decoder) step() int32 {
prevToken := d.token
// Save old cache state (reuse preallocated slice)
d.oldCacheState = d.oldCacheState[:0]
for _, c := range d.caches {
d.oldCacheState = append(d.oldCacheState, c.State()...)
}
withStream(func() {
logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
d.token = sample(logits, d.temp, d.topK, d.topP, d.vocabSize)
})
// Keep token and new cache state so they survive cleanup
mlx.Keep(d.token)
for _, c := range d.caches {
mlx.Keep(c.State()...)
}
mlx.AsyncEval(d.token)
// Sync on previous token (GPU already working on next step)
val := prevToken.ItemInt32()
// Free old token and old cache state
prevToken.Free()
for _, arr := range d.oldCacheState {
arr.Free()
}
return val
}
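// generate tokenizes the prompt (applying a chat or image template when the
// model supports one), prefills the caches, then streams decoded tokens to cb
// until EOS or the token limit, reporting prefill and generation throughput.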
func generate(ctx context.Context, m Model, in input, cb func(output)) error {
mlx.EnableCompile()
wiredLimit := in.WiredLimitGB
if wiredLimit <= 0 {
wiredLimit = 32 // default 32GB
}
mlx.MetalSetWiredLimit(uint64(wiredLimit) << 30)
temp := in.Temperature
if temp < 0 {
temp = 0.7
}
tok := m.Tokenizer()
dec := NewDecoder(m, temp, in.TopK, in.TopP)
// Apply chat template - use image template if we have an image
prompt := in.Prompt
var tokens []int32
if mm, ok := m.(MultimodalModel); ok && in.Image != nil {
prompt = mm.FormatPromptWithImage(prompt)
tokens = tok.Encode(prompt, true)
tokens = mm.ExpandImageTokens(tokens) // Expand <start_of_image> to 256 image tokens
dec.SetImage(in.Image)
} else if cm, ok := m.(ChatModel); ok {
prompt = cm.FormatPrompt(prompt)
tokens = tok.Encode(prompt, true)
} else {
tokens = tok.Encode(prompt, true)
}
prefillStart := time.Now()
prefillTokens := dec.prefill(tokens)
// Prefill measurement should include time to first token (like mlx-lm)
// step() waits for prefill to complete and returns the first token
firstToken := dec.step()
prefillTokSec := float64(prefillTokens) / time.Since(prefillStart).Seconds()
genStart := time.Now()
maxTokens := max(in.MaxTokens, 100)
var genTokens int
// UTF-8 streamer to handle partial multi-byte characters
streamer := &utf8Streamer{}
// Handle first token
genTokens++
if tok.IsEOS(firstToken) {
cb(output{Done: true, PrefillTokSec: prefillTokSec, GenTokSec: 0})
return nil
}
if text := streamer.Write(tok.Decode([]int32{firstToken})); text != "" {
cb(output{Text: text})
}
for n := 1; n < maxTokens; n++ {
if ctx.Err() != nil {
return ctx.Err()
}
token := dec.step()
genTokens++
if tok.IsEOS(token) {
break
}
if text := streamer.Write(tok.Decode([]int32{token})); text != "" {
cb(output{Text: text})
}
if n%256 == 0 {
mlx.ClearCache()
}
}
// Flush any remaining buffered bytes
if text := streamer.Flush(); text != "" {
cb(output{Text: text})
}
fmt.Printf("\nPeak memory: %.2fGB\n", float64(mlx.MetalGetPeakMemory())/(1<<30))
cb(output{Done: true, PrefillTokSec: prefillTokSec,
GenTokSec: float64(genTokens) / time.Since(genStart).Seconds()})
return nil
}
//go:build mlx
package main
import (
"fmt"
"image"
"image/png"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// saveImageArray saves an MLX array as a PNG image.
// Expected format: [B, C, H, W] with values in [0, 1] range and C=3 (RGB).
func saveImageArray(arr *mlx.Array, path string) error {
img, err := arrayToImage(arr)
if err != nil {
return err
}
return savePNG(img, path)
}
func savePNG(img *image.RGBA, path string) error {
if filepath.Ext(path) != ".png" {
path = path + ".png"
}
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
return png.Encode(f, img)
}
func arrayToImage(arr *mlx.Array) (*image.RGBA, error) {
shape := arr.Shape()
if len(shape) != 4 {
return nil, fmt.Errorf("expected 4D array [B, C, H, W], got %v", shape)
}
// Transform to [H, W, C] for image conversion
img := mlx.Squeeze(arr, 0)
arr.Free()
img = mlx.Transpose(img, 1, 2, 0)
img = mlx.Contiguous(img)
mlx.Eval(img)
imgShape := img.Shape()
H := int(imgShape[0])
W := int(imgShape[1])
C := int(imgShape[2])
if C != 3 {
img.Free()
return nil, fmt.Errorf("expected 3 channels (RGB), got %d", C)
}
// Copy to CPU and free GPU memory
data := img.Data()
img.Free()
// Write directly to Pix slice (faster than SetRGBA)
goImg := image.NewRGBA(image.Rect(0, 0, W, H))
pix := goImg.Pix
for y := 0; y < H; y++ {
for x := 0; x < W; x++ {
srcIdx := (y*W + x) * C
dstIdx := (y*W + x) * 4
pix[dstIdx+0] = uint8(clampF(data[srcIdx+0]*255+0.5, 0, 255))
pix[dstIdx+1] = uint8(clampF(data[srcIdx+1]*255+0.5, 0, 255))
pix[dstIdx+2] = uint8(clampF(data[srcIdx+2]*255+0.5, 0, 255))
pix[dstIdx+3] = 255
}
}
return goImg, nil
}
func clampF(v, min, max float32) float32 {
if v < min {
return min
}
if v > max {
return max
}
return v
}
//go:build mlx
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"runtime/pprof"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
"github.com/ollama/ollama/x/imagegen/models/qwen_image"
"github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
// stringSlice is a flag type that accumulates multiple values
type stringSlice []string
func (s *stringSlice) String() string {
return fmt.Sprintf("%v", *s)
}
func (s *stringSlice) Set(value string) error {
*s = append(*s, value)
return nil
}
func main() {
modelPath := flag.String("model", "", "Model directory")
prompt := flag.String("prompt", "Hello", "Prompt")
// Text generation params
maxTokens := flag.Int("max-tokens", 100, "Max tokens")
temperature := flag.Float64("temperature", 0.7, "Temperature")
topP := flag.Float64("top-p", 0.9, "Top-p sampling")
topK := flag.Int("top-k", 40, "Top-k sampling")
imagePath := flag.String("image", "", "Image path for multimodal models")
// Image generation params
width := flag.Int("width", 1024, "Image width")
height := flag.Int("height", 1024, "Image height")
steps := flag.Int("steps", 9, "Denoising steps")
seed := flag.Int64("seed", 42, "Random seed")
out := flag.String("output", "output.png", "Output path")
// Utility flags
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
cfgScale := flag.Float64("cfg-scale", 4.0, "CFG scale for image editing")
flag.Parse()
if *modelPath == "" {
flag.Usage()
return
}
// CPU profiling
if *cpuProfile != "" {
f, err := os.Create(*cpuProfile)
if err != nil {
log.Fatal(err)
}
defer f.Close()
if err := pprof.StartCPUProfile(f); err != nil {
log.Fatal(err)
}
defer pprof.StopCPUProfile()
}
var err error
// Handle legacy mode flags that aren't unified yet
switch {
case *zimageFlag:
m := &zimage.Model{}
if loadErr := m.Load(*modelPath); loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&zimage.GenerateConfig{
Prompt: *prompt,
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
CapturePath: *gpuCapture,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImage:
m, loadErr := qwen_image.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
var img *mlx.Array
img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: int32(*width),
Height: int32(*height),
Steps: *steps,
Seed: *seed,
LayerCache: *layerCache,
})
if err == nil {
err = saveImageArray(img, *out)
}
case *qwenImageEdit:
if len(inputImages) == 0 {
log.Fatal("qwen-image-edit requires at least one -input-image")
}
m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
if loadErr != nil {
log.Fatal(loadErr)
}
// For image editing, use 0 for dimensions to auto-detect from input image
// unless explicitly overridden from defaults
editWidth := int32(0)
editHeight := int32(0)
if *width != 1024 {
editWidth = int32(*width)
}
if *height != 1024 {
editHeight = int32(*height)
}
cfg := &qwen_image_edit.GenerateConfig{
Prompt: *prompt,
NegativePrompt: *negativePrompt,
CFGScale: float32(*cfgScale),
Width: editWidth,
Height: editHeight,
Steps: *steps,
Seed: *seed,
}
var img *mlx.Array
img, err = m.EditFromConfig(inputImages, cfg)
if err == nil {
err = saveImageArray(img, *out)
}
case *listTensors:
err = listModelTensors(*modelPath)
default:
// llm path
var m Model
m, err = load(*modelPath)
if err != nil {
log.Fatal(err)
}
// Load image if provided and model supports it
var image *mlx.Array
if *imagePath != "" {
if mm, ok := m.(interface{ ImageSize() int32 }); ok {
image, err = gemma3.ProcessImage(*imagePath, mm.ImageSize())
if err != nil {
log.Fatal("load image:", err)
}
} else {
log.Fatal("model does not support image input")
}
}
err = generate(context.Background(), m, input{
Prompt: *prompt,
Image: image,
MaxTokens: *maxTokens,
Temperature: float32(*temperature),
TopP: float32(*topP),
TopK: *topK,
WiredLimitGB: *wiredLimitGB,
}, func(out output) {
if out.Text != "" {
fmt.Print(out.Text)
}
if out.Done {
fmt.Printf("\n\n[prefill: %.1f tok/s, gen: %.1f tok/s]\n", out.PrefillTokSec, out.GenTokSec)
}
})
}
if err != nil {
log.Fatal(err)
}
}
func listModelTensors(modelPath string) error {
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return err
}
for _, name := range weights.ListTensors() {
info, _ := weights.GetTensorInfo(name)
fmt.Printf("%s: %v (%s)\n", name, info.Shape, info.Dtype)
}
return nil
}
// loadModel builds and evaluates a model using the common load pattern.
// Release safetensors BEFORE eval - lazy arrays have captured their data,
// and this reduces peak memory by ~6GB (matches mlx-lm behavior).
func loadModel[T Model](build func() T, cleanup func()) T {
m := build()
weights := mlx.Collect(m)
cleanup()
mlx.Eval(weights...)
return m
}
func load(modelPath string) (Model, error) {
kind, err := detectModelKind(modelPath)
if err != nil {
return nil, fmt.Errorf("detect model kind: %w", err)
}
switch kind {
case "gpt_oss":
return gpt_oss.Load(modelPath)
case "gemma3":
return gemma3.Load(modelPath)
case "gemma3_text":
return gemma3.LoadText(modelPath)
default:
return llama.Load(modelPath)
}
}
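// detectModelKind inspects the model directory: a model_index.json marks a
// diffusion pipeline (treated as "zimage"); otherwise the "model_type" field
// of config.json is returned.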
func detectModelKind(modelPath string) (string, error) {
indexPath := filepath.Join(modelPath, "model_index.json")
if _, err := os.Stat(indexPath); err == nil {
data, err := os.ReadFile(indexPath)
if err != nil {
return "zimage", nil
}
var index struct {
ClassName string `json:"_class_name"`
}
if err := json.Unmarshal(data, &index); err == nil {
switch index.ClassName {
case "FluxPipeline", "ZImagePipeline":
return "zimage", nil
}
}
return "zimage", nil
}
configPath := filepath.Join(modelPath, "config.json")
data, err := os.ReadFile(configPath)
if err != nil {
return "", fmt.Errorf("no config.json or model_index.json found: %w", err)
}
var cfg struct {
ModelType string `json:"model_type"`
}
if err := json.Unmarshal(data, &cfg); err != nil {
return "", fmt.Errorf("parse config.json: %w", err)
}
return cfg.ModelType, nil
}
//go:build mlx
package main
import "github.com/ollama/ollama/x/imagegen/mlx"
// sampleTopK samples from top-k logits using global random state
func sampleTopK(scaledLogits *mlx.Array, k int) *mlx.Array {
neg := mlx.Neg(scaledLogits)
indices := mlx.Argpartition(neg, k-1, -1)
topKIdx := mlx.Slice(indices, []int32{0}, []int32{int32(k)})
values := mlx.TakeAlongAxis(scaledLogits, topKIdx, -1)
sampled := mlx.RandomCategorical(values, -1, 1)
return mlx.Take(topKIdx, sampled, -1)
}
// sampleTopP samples using nucleus sampling with global random state
func sampleTopP(scaledLogits *mlx.Array, p float32, vocabSize int32) *mlx.Array {
sorted := mlx.Argsort(mlx.Neg(scaledLogits), -1)
sortedLogits := mlx.TakeAlongAxis(scaledLogits, sorted, -1)
probs := mlx.Softmax(sortedLogits, -1)
cumProbs := mlx.Cumsum(probs, -1)
mask := mlx.LessScalar(cumProbs, p)
negInf := mlx.FullDtype(float32(-1e9), scaledLogits.Dtype(), vocabSize)
masked := mlx.Where(mask, sortedLogits, negInf)
sampled := mlx.RandomCategorical(masked, -1, 1)
return mlx.Take(sorted, sampled, -1)
}
// sample samples from logits at the last position
func sample(logits *mlx.Array, temp float32, topK int, topP float32, vocab int32) *mlx.Array {
// Get last position logits: [1, L, vocab] -> [vocab]
shape := logits.Shape()
seqLen := shape[1]
lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocab})
lastLogits = mlx.Reshape(lastLogits, vocab)
if temp == 0 {
return mlx.Argmax(lastLogits, -1, false)
}
scaled := mlx.DivScalar(lastLogits, temp)
if topK > 0 && topK < int(vocab) {
return sampleTopK(scaled, topK)
}
if topP > 0 && topP < 1.0 {
return sampleTopP(scaled, topP, vocab)
}
return mlx.RandomCategorical(scaled, -1, 1)
}
# MLX Memory Management
> This package will get consolidated with `x/ml/backend/mlx` in the future.
## Automatic Tracking
All arrays are automatically tracked when created. On `Eval()`, non-kept arrays are freed.
### API
```go
result := mlx.Matmul(x, w) // arrays automatically tracked
mlx.Eval(result) // free non-kept, eval result (auto-kept)
```
### Key Functions
- `mlx.Eval(outputs...)` - free non-kept arrays, then evaluate (outputs auto-kept)
- `mlx.AsyncEval(outputs...)` - async version of Eval (outputs auto-kept)
- `mlx.Keep(arrays...)` - mark arrays to survive cleanup (for weights, caches)
- `array.Free()` - mark array for cleanup on next Eval
### Loop Pattern
```go
for step := 0; step < maxTokens; step++ {
logits := model.Forward(token, caches)
oldToken := token
token = sample(logits)
// Keep cache state across iterations
for _, c := range caches {
mlx.Keep(c.State()...)
}
oldToken.Free() // mark for cleanup
mlx.AsyncEval(token) // frees old, evals new
}
```
### Notes
- `Eval()` and `AsyncEval()` auto-keep their outputs
- `Free()` marks for cleanup - actual free happens during next Eval
- Use `Keep()` for weights and cache state that must survive multiple Eval cycles
- Arrays created inside compiled closures are managed by MLX, not tracked
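### Weight Setup

A complementary sketch for one-time weight setup (the `model` value is a hypothetical model struct; `Collect`, `Keep`, and `Eval` are the package functions described above):

```go
weights := mlx.Collect(model) // gather every live *mlx.Array reachable from the struct
mlx.Keep(weights...)          // weights survive every Eval() cleanup
mlx.Eval(weights...)          // materialize them once up front
```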
//go:build mlx
package mlx
/*
#include "mlx/c/mlx.h"
#include <stdlib.h>
// Forward declaration for Go callback
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);
// Destructor for payload (Go handle)
extern void goClosureDestructor(void* payload);
*/
import "C"
import (
"runtime/cgo"
"sync"
"unsafe"
)
// inClosureCallback is set to true during closure callback execution.
var inClosureCallback bool
var closureCallbackMu sync.Mutex
// InClosureCallback returns true if we're currently executing inside a closure callback.
func InClosureCallback() bool {
closureCallbackMu.Lock()
defer closureCallbackMu.Unlock()
return inClosureCallback
}
// CompiledFunc is a compiled MLX function that can be called efficiently.
// All intermediate arrays during execution stay inside MLX - only inputs
// and outputs cross the Go boundary.
type CompiledFunc struct {
closure C.mlx_closure
compiled C.mlx_closure
}
// ClosureFunc is the signature for functions that can be compiled.
// It takes a slice of input arrays and returns a slice of output arrays.
type ClosureFunc func(inputs []*Array) []*Array
// Compile compiles a Go function into an optimized MLX closure.
// The function is traced once during compilation, then subsequent calls
// run the optimized graph without creating Go intermediate arrays.
//
// Example:
//
// compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
// a, b := inputs[0], inputs[1]
// c := mlx.Add(a, b)
// d := mlx.Mul(c, c)
// return []*mlx.Array{d}
// })
// defer compiled.Free()
//
// result := compiled.Call(x, y)[0]
func Compile(fn ClosureFunc) *CompiledFunc {
return CompileShapeless(fn, false)
}
// CompileShapeless compiles with optional shapeless mode.
// If shapeless=true, the function works for any input shape after tracing.
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
// Create a cgo.Handle to prevent the Go function from being GC'd
handle := cgo.NewHandle(fn)
// Create the closure from the Go callback
closure := C.mlx_closure_new_func_payload(
(*[0]byte)(C.goClosureCallback),
unsafe.Pointer(handle),
(*[0]byte)(C.goClosureDestructor),
)
// Compile the closure
compiled := C.mlx_closure_new()
C.mlx_compile(&compiled, closure, C.bool(shapeless))
return &CompiledFunc{
closure: closure,
compiled: compiled,
}
}
// Call invokes the compiled function with the given inputs.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
// Pack inputs into vector
inputVec := C.mlx_vector_array_new()
for _, arr := range inputs {
C.mlx_vector_array_append_value(inputVec, arr.c)
}
// Apply compiled closure
outputVec := C.mlx_vector_array_new()
C.mlx_closure_apply(&outputVec, cf.compiled, inputVec)
C.mlx_vector_array_free(inputVec)
// Unpack outputs
numOutputs := int(C.mlx_vector_array_size(outputVec))
outputs := make([]*Array, numOutputs)
for i := 0; i < numOutputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, outputVec, C.size_t(i))
outputs[i] = newArray(arr)
}
C.mlx_vector_array_free(outputVec)
return outputs
}
// CallEval invokes the compiled function and evaluates the results.
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
outputs := cf.Call(inputs...)
Eval(outputs...)
return outputs
}
// Free releases the compiled function resources.
func (cf *CompiledFunc) Free() {
C.mlx_closure_free(cf.compiled)
C.mlx_closure_free(cf.closure)
}
// borrowArray wraps a C array WITHOUT setting up GC cleanup.
// Use this for arrays we don't own (e.g., borrowed references in callbacks).
func borrowArray(array C.mlx_array) *Array {
return &Array{c: array}
}
//export goClosureCallback
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
// Set flag to disable AddCleanup during callback
closureCallbackMu.Lock()
inClosureCallback = true
closureCallbackMu.Unlock()
defer func() {
closureCallbackMu.Lock()
inClosureCallback = false
closureCallbackMu.Unlock()
}()
// Recover the Go function from the handle
handle := cgo.Handle(payload)
fn := handle.Value().(ClosureFunc)
// Convert input vector to Go slice - use borrowArray since MLX owns these
numInputs := int(C.mlx_vector_array_size(input))
inputs := make([]*Array, numInputs)
for i := 0; i < numInputs; i++ {
var arr C.mlx_array
C.mlx_vector_array_get(&arr, input, C.size_t(i))
inputs[i] = borrowArray(arr) // Don't set up cleanup - MLX owns these
}
// Call the Go function
outputs := fn(inputs)
// Build output vector
*res = C.mlx_vector_array_new()
for _, arr := range outputs {
C.mlx_vector_array_append_value(*res, arr.c)
}
return 0
}
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
handle := cgo.Handle(payload)
handle.Delete()
}
//go:build mlx
package mlx
/*
#cgo CFLAGS: -O3 -I${SRCDIR}/../../../build/_deps/mlx-c-src
#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/
#cgo darwin LDFLAGS: -lc++ -framework Metal -framework Foundation -framework Accelerate
#cgo linux LDFLAGS: -lstdc++ -lcuda -lcudart -lnvrtc
#include "mlx/c/mlx.h"
#include <stdlib.h>
#include <stdint.h>
// Cached default GPU stream for all ops
static mlx_stream _default_stream = {0};
static mlx_stream _cpu_stream = {0};
static inline mlx_stream default_stream() {
if (_default_stream.ctx == NULL) {
_default_stream = mlx_default_gpu_stream_new();
}
return _default_stream;
}
static inline void set_default_stream(mlx_stream s) {
_default_stream = s;
}
// CPU stream for file loading (Load primitive only runs on CPU)
static inline mlx_stream cpu_stream() {
if (_cpu_stream.ctx == NULL) {
_cpu_stream = mlx_default_cpu_stream_new();
}
return _cpu_stream;
}
// CGO noescape/nocallback hints to reduce CGO overhead
// noescape: pointers won't escape, no heap allocation needed
// nocallback: function won't call back into Go
#cgo noescape mlx_add
#cgo nocallback mlx_add
#cgo noescape mlx_subtract
#cgo nocallback mlx_subtract
#cgo noescape mlx_multiply
#cgo nocallback mlx_multiply
#cgo noescape mlx_divide
#cgo nocallback mlx_divide
#cgo noescape mlx_negative
#cgo nocallback mlx_negative
#cgo noescape mlx_abs
#cgo nocallback mlx_abs
#cgo noescape mlx_exp
#cgo nocallback mlx_exp
#cgo noescape mlx_log
#cgo nocallback mlx_log
#cgo noescape mlx_sqrt
#cgo nocallback mlx_sqrt
#cgo noescape mlx_rsqrt
#cgo nocallback mlx_rsqrt
#cgo noescape mlx_square
#cgo nocallback mlx_square
#cgo noescape mlx_power
#cgo nocallback mlx_power
#cgo noescape mlx_erf
#cgo nocallback mlx_erf
#cgo noescape mlx_sigmoid
#cgo nocallback mlx_sigmoid
#cgo noescape mlx_tanh
#cgo nocallback mlx_tanh
#cgo noescape mlx_sin
#cgo nocallback mlx_sin
#cgo noescape mlx_cos
#cgo nocallback mlx_cos
#cgo noescape mlx_maximum
#cgo nocallback mlx_maximum
#cgo noescape mlx_minimum
#cgo nocallback mlx_minimum
#cgo noescape mlx_clip
#cgo nocallback mlx_clip
#cgo noescape mlx_sum
#cgo nocallback mlx_sum
#cgo noescape mlx_sum_axis
#cgo nocallback mlx_sum_axis
#cgo noescape mlx_mean
#cgo nocallback mlx_mean
#cgo noescape mlx_mean_axis
#cgo nocallback mlx_mean_axis
#cgo noescape mlx_var_axis
#cgo nocallback mlx_var_axis
#cgo noescape mlx_argmax
#cgo nocallback mlx_argmax
#cgo noescape mlx_argmax_axis
#cgo nocallback mlx_argmax_axis
#cgo noescape mlx_softmax_axis
#cgo nocallback mlx_softmax_axis
#cgo noescape mlx_cumsum
#cgo nocallback mlx_cumsum
#cgo noescape mlx_matmul
#cgo nocallback mlx_matmul
#cgo noescape mlx_addmm
#cgo nocallback mlx_addmm
#cgo noescape mlx_gather_mm
#cgo nocallback mlx_gather_mm
#cgo noescape mlx_gather_qmm
#cgo nocallback mlx_gather_qmm
#cgo noescape mlx_reshape
#cgo nocallback mlx_reshape
#cgo noescape mlx_transpose_axes
#cgo nocallback mlx_transpose_axes
#cgo noescape mlx_expand_dims
#cgo nocallback mlx_expand_dims
#cgo noescape mlx_squeeze_axis
#cgo nocallback mlx_squeeze_axis
#cgo noescape mlx_flatten
#cgo nocallback mlx_flatten
#cgo noescape mlx_concatenate_axis
#cgo nocallback mlx_concatenate_axis
#cgo noescape mlx_slice
#cgo nocallback mlx_slice
#cgo noescape mlx_slice_update
#cgo nocallback mlx_slice_update
#cgo noescape mlx_as_strided
#cgo nocallback mlx_as_strided
#cgo noescape mlx_view
#cgo nocallback mlx_view
#cgo noescape mlx_contiguous
#cgo nocallback mlx_contiguous
#cgo noescape mlx_pad
#cgo nocallback mlx_pad
#cgo noescape mlx_tile
#cgo nocallback mlx_tile
#cgo noescape mlx_take_axis
#cgo nocallback mlx_take_axis
#cgo noescape mlx_take_along_axis
#cgo nocallback mlx_take_along_axis
#cgo noescape mlx_put_along_axis
#cgo nocallback mlx_put_along_axis
#cgo noescape mlx_where
#cgo nocallback mlx_where
#cgo noescape mlx_argsort_axis
#cgo nocallback mlx_argsort_axis
#cgo noescape mlx_argpartition_axis
#cgo nocallback mlx_argpartition_axis
#cgo noescape mlx_topk_axis
#cgo nocallback mlx_topk_axis
#cgo noescape mlx_less
#cgo nocallback mlx_less
#cgo noescape mlx_greater_equal
#cgo nocallback mlx_greater_equal
#cgo noescape mlx_logical_and
#cgo nocallback mlx_logical_and
#cgo noescape mlx_zeros
#cgo nocallback mlx_zeros
#cgo noescape mlx_zeros_like
#cgo nocallback mlx_zeros_like
#cgo noescape mlx_ones
#cgo nocallback mlx_ones
#cgo noescape mlx_full
#cgo nocallback mlx_full
#cgo noescape mlx_arange
#cgo nocallback mlx_arange
#cgo noescape mlx_linspace
#cgo nocallback mlx_linspace
#cgo noescape mlx_tri
#cgo nocallback mlx_tri
#cgo noescape mlx_astype
#cgo nocallback mlx_astype
#cgo noescape mlx_fast_rms_norm
#cgo nocallback mlx_fast_rms_norm
#cgo noescape mlx_fast_rope
#cgo nocallback mlx_fast_rope
#cgo noescape mlx_fast_scaled_dot_product_attention
#cgo nocallback mlx_fast_scaled_dot_product_attention
#cgo noescape mlx_conv2d
#cgo nocallback mlx_conv2d
#cgo noescape mlx_conv3d
#cgo nocallback mlx_conv3d
#cgo noescape mlx_random_key
#cgo nocallback mlx_random_key
#cgo noescape mlx_random_split
#cgo nocallback mlx_random_split
#cgo noescape mlx_random_categorical_num_samples
#cgo nocallback mlx_random_categorical_num_samples
#cgo noescape mlx_random_normal
#cgo nocallback mlx_random_normal
#cgo noescape mlx_random_uniform
#cgo nocallback mlx_random_uniform
#cgo noescape mlx_array_eval
#cgo nocallback mlx_array_eval
#cgo noescape mlx_eval
#cgo nocallback mlx_eval
#cgo noescape mlx_async_eval
#cgo nocallback mlx_async_eval
#cgo noescape mlx_synchronize
#cgo nocallback mlx_synchronize
#cgo noescape mlx_array_new
#cgo nocallback mlx_array_new
#cgo noescape mlx_array_new_data
#cgo nocallback mlx_array_new_data
#cgo noescape mlx_array_new_float
#cgo nocallback mlx_array_new_float
#cgo noescape mlx_array_free
#cgo nocallback mlx_array_free
#cgo noescape mlx_array_size
#cgo nocallback mlx_array_size
#cgo noescape mlx_array_ndim
#cgo nocallback mlx_array_ndim
#cgo noescape mlx_array_dim
#cgo nocallback mlx_array_dim
#cgo noescape mlx_array_dtype
#cgo nocallback mlx_array_dtype
#cgo noescape mlx_array_item_int32
#cgo nocallback mlx_array_item_int32
#cgo noescape mlx_vector_array_new_data
#cgo nocallback mlx_vector_array_new_data
#cgo noescape mlx_vector_array_free
#cgo nocallback mlx_vector_array_free
#cgo noescape mlx_array_new_int
#cgo nocallback mlx_array_new_int
#cgo noescape mlx_stream_new_device
#cgo nocallback mlx_stream_new_device
#cgo noescape mlx_get_default_stream
#cgo nocallback mlx_get_default_stream
#cgo noescape mlx_set_default_stream
#cgo nocallback mlx_set_default_stream
*/
import "C"
import (
"fmt"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"
"unsafe"
)
// Dtype represents MLX data types
type Dtype int
const (
DtypeBool Dtype = C.MLX_BOOL
DtypeUint8 Dtype = C.MLX_UINT8
DtypeUint16 Dtype = C.MLX_UINT16
DtypeUint32 Dtype = C.MLX_UINT32
DtypeUint64 Dtype = C.MLX_UINT64
DtypeInt8 Dtype = C.MLX_INT8
DtypeInt16 Dtype = C.MLX_INT16
DtypeInt32 Dtype = C.MLX_INT32
DtypeInt64 Dtype = C.MLX_INT64
DtypeFloat16 Dtype = C.MLX_FLOAT16
DtypeFloat32 Dtype = C.MLX_FLOAT32
DtypeFloat64 Dtype = C.MLX_FLOAT64
DtypeBFloat16 Dtype = C.MLX_BFLOAT16
DtypeComplex64 Dtype = C.MLX_COMPLEX64
)
// String implements fmt.Stringer for Dtype
func (d Dtype) String() string {
switch d {
case DtypeBool:
return "bool"
case DtypeUint8:
return "u8"
case DtypeUint16:
return "u16"
case DtypeUint32:
return "u32"
case DtypeUint64:
return "u64"
case DtypeInt8:
return "i8"
case DtypeInt16:
return "i16"
case DtypeInt32:
return "i32"
case DtypeInt64:
return "i64"
case DtypeFloat16:
return "f16"
case DtypeFloat32:
return "f32"
case DtypeFloat64:
return "f64"
case DtypeBFloat16:
return "bf16"
case DtypeComplex64:
return "c64"
default:
return "unknown"
}
}
// Memory Management:
//
// All arrays are automatically tracked for cleanup. On Eval(), non-kept arrays are freed.
//
// x := mlx.Matmul(input, weight) // x is tracked for cleanup
// mlx.Keep(x) // mark x as persistent
// mlx.Eval(x) // eval + free non-kept arrays
//
// Use Keep() for arrays that should persist (weights, caches).
// Use Free() to mark a kept array for cleanup on next Eval().
//
// Note: Not goroutine-safe. Use from a single goroutine.
// Array wraps an MLX array handle.
// Arrays are freed via Eval() cleanup (deterministic) or GC (fallback).
type Array struct {
c C.mlx_array
freed bool // Prevents double-free
kept bool // If true, survives Eval() cleanup
}
// arrays tracks all live arrays. On Eval(), non-kept arrays are freed.
// Not goroutine-safe.
var arrays = make([]*Array, 0, 4096)
// evalHandles is a pre-allocated slice for passing arrays to MLX eval.
var evalHandles = make([]C.mlx_array, 0, 64)
// arrayPool reduces allocations for intermediate arrays
var arrayPool = sync.Pool{
New: func() any { return &Array{} },
}
func newArray(array C.mlx_array) *Array {
// In compiled closures, MLX manages memory - skip Go tracking
if InClosureCallback() {
return &Array{c: array}
}
// Use pooled Array struct for efficiency
a := arrayPool.Get().(*Array)
a.c = array
a.freed = false
a.kept = false
// Track in global list
arrays = append(arrays, a)
return a
}
// Collect uses reflection to find all *Array fields in a struct (recursively).
// Use this to automatically gather model weights, cache state, etc.
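//
// Example (a minimal sketch; the model struct here is hypothetical):
//
// type Attention struct{ Wq, Wk, Wv, Wo *Array }
// type Model struct{ Layers []*Attention }
//
// weights := Collect(model) // every live *Array reachable from model
// Keep(weights...)          // keep them across Eval() cleanup
// Eval(weights...)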
func Collect(v any) []*Array {
var arrays []*Array
seen := make(map[uintptr]bool)
collect(reflect.ValueOf(v), &arrays, seen)
return arrays
}
func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
if !v.IsValid() {
return
}
// Handle pointers
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return
}
// Avoid infinite loops
ptr := v.Pointer()
if seen[ptr] {
return
}
seen[ptr] = true
// Check if it's *Array
if arr, ok := v.Interface().(*Array); ok {
if arr != nil && arr.c.ctx != nil {
*arrays = append(*arrays, arr)
}
return
}
collect(v.Elem(), arrays, seen)
return
}
// Handle structs
if v.Kind() == reflect.Struct {
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanInterface() {
collect(field, arrays, seen)
}
}
return
}
// Handle slices
if v.Kind() == reflect.Slice {
for i := 0; i < v.Len(); i++ {
collect(v.Index(i), arrays, seen)
}
return
}
// Handle maps
if v.Kind() == reflect.Map {
for _, key := range v.MapKeys() {
collect(v.MapIndex(key), arrays, seen)
}
return
}
// Handle interfaces
if v.Kind() == reflect.Interface {
if !v.IsNil() {
collect(v.Elem(), arrays, seen)
}
return
}
}
// FreeStruct releases all *Array fields in a struct (recursively).
// Use this to free model weights when unloading a model.
func FreeStruct(v any) {
for _, arr := range Collect(v) {
arr.Free()
}
}
// Keep marks arrays to persist across Eval() cleanup.
// Kept arrays will NOT be freed when Eval() runs cleanup.
func Keep(arrays ...*Array) {
for _, a := range arrays {
if a != nil {
a.kept = true
}
}
}
// cleanup frees non-kept arrays and compacts the live array list.
// Returns number of arrays freed.
func cleanup() int {
freed := 0
n := 0
for _, a := range arrays {
if a.kept {
arrays[n] = a
n++
} else if a.c.ctx != nil && !a.freed {
C.mlx_array_free(a.c)
a.c.ctx = nil
arrayPool.Put(a)
freed++
}
}
arrays = arrays[:n]
return freed
}
// DebugArrays prints summary info about all tracked arrays.
func DebugArrays() {
var totalBytes int64
var keptCount, unkeptCount int
for _, a := range arrays {
if a.kept {
keptCount++
} else {
unkeptCount++
}
totalBytes += a.Nbytes()
}
fmt.Printf("[DEBUG] Arrays: %d kept, %d unkept, %.2f GB total\n",
keptCount, unkeptCount, float64(totalBytes)/(1024*1024*1024))
}
// DebugArraysVerbose prints detailed info about all tracked arrays, sorted by size.
func DebugArraysVerbose(topN int) {
type arrayInfo struct {
shape []int32
dtype Dtype
bytes int64
kept bool
}
var infos []arrayInfo
var totalBytes int64
for _, a := range arrays {
bytes := a.Nbytes()
infos = append(infos, arrayInfo{
shape: a.Shape(),
dtype: a.Dtype(),
bytes: bytes,
kept: a.kept,
})
totalBytes += bytes
}
// Sort by size descending
for i := 0; i < len(infos)-1; i++ {
for j := i + 1; j < len(infos); j++ {
if infos[j].bytes > infos[i].bytes {
infos[i], infos[j] = infos[j], infos[i]
}
}
}
fmt.Printf("[DEBUG] %d arrays, %.2f GB total:\n", len(infos), float64(totalBytes)/(1024*1024*1024))
for i, info := range infos {
if i >= topN {
break
}
keptStr := ""
if info.kept {
keptStr = " [kept]"
}
fmt.Printf(" %3d. %8.2f MB %v %v%s\n",
i+1, float64(info.bytes)/(1024*1024), info.shape, info.dtype, keptStr)
}
}
// Eval synchronously evaluates arrays and cleans up non-kept arrays.
// Outputs are automatically kept (survive cleanup). Returns them for chaining.
func Eval(outputs ...*Array) []*Array {
// Keep outputs so cleanup doesn't free them
for _, o := range outputs {
if o != nil {
o.kept = true
}
}
// Cleanup non-kept arrays
cleanup()
// Then evaluate
if len(outputs) > 0 {
evalHandles = evalHandles[:0]
for _, o := range outputs {
if o != nil {
evalHandles = append(evalHandles, o.c)
}
}
if len(evalHandles) > 0 {
vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles)))
C.mlx_eval(vec)
C.mlx_vector_array_free(vec)
}
}
return outputs
}
// AsyncEval dispatches async evaluation and cleans up non-kept arrays.
// Outputs are automatically kept (survive cleanup).
func AsyncEval(outputs ...*Array) {
// Keep outputs so cleanup doesn't free them
for _, o := range outputs {
if o != nil {
o.kept = true
}
}
// Cleanup non-kept arrays
cleanup()
// Then dispatch async eval
if len(outputs) > 0 {
evalHandles = evalHandles[:0]
for _, o := range outputs {
if o != nil {
evalHandles = append(evalHandles, o.c)
}
}
if len(evalHandles) > 0 {
vec := C.mlx_vector_array_new_data(&evalHandles[0], C.size_t(len(evalHandles)))
C.mlx_async_eval(vec)
C.mlx_vector_array_free(vec)
}
}
}
// Sync waits for all async operations to complete (no cleanup).
func Sync() {
C.mlx_synchronize(C.default_stream())
}
// Free marks this array for cleanup on the next Eval().
// The array is not immediately freed - cleanup happens during Eval().
//
// Pattern for loops:
//
// oldLatents.Free() // mark for cleanup
// mlx.Eval(newLatents) // frees old, evals new
func (a *Array) Free() {
if a != nil {
a.kept = false
}
}
// Eval evaluates this single array and runs cleanup.
func (a *Array) Eval() *Array {
Eval(a)
return a
}
// Valid returns true if the array hasn't been freed.
func (a *Array) Valid() bool {
return a != nil && a.c.ctx != nil
}
func int32ToCInt(s []int32) *C.int {
if len(s) == 0 {
return nil
}
return (*C.int)(unsafe.Pointer(&s[0]))
}
// NewArray creates a new MLX array from float32 data
func NewArray(data []float32, shape []int32) *Array {
handle := C.mlx_array_new_data(
unsafe.Pointer(&data[0]),
int32ToCInt(shape),
C.int(len(shape)),
C.MLX_FLOAT32,
)
return newArray(handle)
}
// NewArrayInt32 creates a new MLX array from int32 data
func NewArrayInt32(data []int32, shape []int32) *Array {
handle := C.mlx_array_new_data(
unsafe.Pointer(&data[0]),
int32ToCInt(shape),
C.int(len(shape)),
C.MLX_INT32,
)
return newArray(handle)
}
// NewArrayFloat32 creates a new float32 array from data
func NewArrayFloat32(data []float32, shape []int32) *Array {
return NewArray(data, shape)
}
// Zeros creates an array of zeros with optional dtype (default float32)
func Zeros(shape []int32, dtype ...Dtype) *Array {
res := C.mlx_array_new()
dt := DtypeFloat32
if len(dtype) > 0 {
dt = dtype[0]
}
C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dt), C.default_stream())
return newArray(res)
}
// ZerosLike creates a zeros array with the same dtype as a.
// If shape is provided, uses that shape; otherwise uses a's shape.
func ZerosLike(a *Array, shape ...int32) *Array {
res := C.mlx_array_new()
if len(shape) == 0 {
C.mlx_zeros_like(&res, a.c, C.default_stream())
} else {
dtype := a.Dtype()
C.mlx_zeros(&res, int32ToCInt(shape), C.size_t(len(shape)), C.mlx_dtype(dtype), C.default_stream())
}
return newArray(res)
}
// Ones creates an array of ones
func Ones(shape ...int32) *Array {
res := C.mlx_array_new()
C.mlx_ones(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, C.default_stream())
return newArray(res)
}
// Full creates an array filled with a value
func Full(value float32, shape ...int32) *Array {
vals := C.mlx_array_new_float(C.float(value))
res := C.mlx_array_new()
C.mlx_full(&res, int32ToCInt(shape), C.size_t(len(shape)), vals, C.MLX_FLOAT32, C.default_stream())
C.mlx_array_free(vals)
return newArray(res)
}
// Arange creates a range of values
func Arange(start, stop, step float32) *Array {
res := C.mlx_array_new()
C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.MLX_FLOAT32, C.default_stream())
return newArray(res)
}
// Linspace creates evenly spaced values
func Linspace(start, stop float32, steps int32) *Array {
res := C.mlx_array_new()
C.mlx_linspace(&res, C.double(start), C.double(stop), C.int(steps), C.MLX_FLOAT32, C.default_stream())
return newArray(res)
}
// ============ Math Operations ============
// Add adds two arrays element-wise
func Add(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_add(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// AddRaw is like Add - kept for API compatibility (now identical to Add)
func AddRaw(a, b *Array) *Array {
return Add(a, b)
}
// Sub subtracts two arrays element-wise
func Sub(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_subtract(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// Mul multiplies two arrays element-wise
func Mul(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_multiply(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// Div divides two arrays element-wise
func Div(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_divide(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// Matmul performs matrix multiplication
func Matmul(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_matmul(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// AddMM computes: result = beta*c + alpha*(a @ b)
// This fuses bias addition with matmul into a single op.
func AddMM(c, a, b *Array, alpha, beta float32) *Array {
res := C.mlx_array_new()
C.mlx_addmm(&res, c.c, a.c, b.c, C.float(alpha), C.float(beta), C.default_stream())
return newArray(res)
}
// Linear performs matrix multiplication: a @ weight
func Linear(a, weight *Array) *Array {
return Matmul(a, weight)
}
// Sqrt computes element-wise square root
func Sqrt(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_sqrt(&res, a.c, C.default_stream())
return newArray(res)
}
// RSqrt computes element-wise reciprocal square root
func RSqrt(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_rsqrt(&res, a.c, C.default_stream())
return newArray(res)
}
// Erf computes element-wise error function
func Erf(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_erf(&res, a.c, C.default_stream())
return newArray(res)
}
// Exp computes element-wise exponential
func Exp(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_exp(&res, a.c, C.default_stream())
return newArray(res)
}
// Log computes element-wise natural logarithm
func Log(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_log(&res, a.c, C.default_stream())
return newArray(res)
}
// Sin computes element-wise sine
func Sin(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_sin(&res, a.c, C.default_stream())
return newArray(res)
}
// Cos computes element-wise cosine
func Cos(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_cos(&res, a.c, C.default_stream())
return newArray(res)
}
// Neg negates the array
func Neg(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_negative(&res, a.c, C.default_stream())
return newArray(res)
}
// Abs computes element-wise absolute value
func Abs(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_abs(&res, a.c, C.default_stream())
return newArray(res)
}
// Square computes element-wise square
func Square(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_square(&res, a.c, C.default_stream())
return newArray(res)
}
// Pow raises a to the power of b element-wise
func Pow(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_power(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// Max computes element-wise maximum
func Max(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_maximum(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// Min computes element-wise minimum
func Min(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_minimum(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// scalarWithDtype creates a scalar array matching the dtype of a (critical for graph fusion!)
func scalarWithDtype(s float32, a *Array) C.mlx_array {
// Create float32 scalar, then cast to match input dtype
f32 := C.mlx_array_new_float(C.float(s))
dtype := a.Dtype()
if dtype == DtypeFloat32 {
return f32 // No cast needed
}
// Cast to match input dtype
casted := C.mlx_array_new()
C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), C.default_stream())
C.mlx_array_free(f32)
return casted
}
// AddScalar adds a scalar to an array (matches dtype for graph fusion)
func AddScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
res := C.mlx_array_new()
C.mlx_add(&res, a.c, scalar, C.default_stream())
C.mlx_array_free(scalar)
return newArray(res)
}
// MulScalar multiplies an array by a scalar (matches dtype for graph fusion)
func MulScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
res := C.mlx_array_new()
C.mlx_multiply(&res, a.c, scalar, C.default_stream())
C.mlx_array_free(scalar)
return newArray(res)
}
// DivScalar divides an array by a scalar (matches dtype for graph fusion)
func DivScalar(a *Array, s float32) *Array {
scalar := scalarWithDtype(s, a)
res := C.mlx_array_new()
C.mlx_divide(&res, a.c, scalar, C.default_stream())
C.mlx_array_free(scalar)
return newArray(res)
}
// DivScalarInt divides an int array by an int scalar (regular division, may return float)
func DivScalarInt(a *Array, s int32) *Array {
scalar := C.mlx_array_new_int(C.int(s))
res := C.mlx_array_new()
C.mlx_divide(&res, a.c, scalar, C.default_stream())
C.mlx_array_free(scalar)
return newArray(res)
}
// FloorDivideScalar performs integer floor division (a // s), preserving int dtype
func FloorDivideScalar(a *Array, s int32) *Array {
scalar := C.mlx_array_new_int(C.int(s))
res := C.mlx_array_new()
C.mlx_floor_divide(&res, a.c, scalar, C.default_stream())
C.mlx_array_free(scalar)
return newArray(res)
}
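// exampleScalarChain is a hypothetical usage sketch (not part of the engine):
// because the *Scalar helpers cast the scalar to the input dtype, a chain of
// them stays in bfloat16 end to end and remains eligible for graph fusion.
// The values and shapes are illustrative only.
func exampleScalarChain() *Array {
	x := ToBFloat16(NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{4}))
	y := DivScalar(MulScalar(AddScalar(x, 1.0), 0.5), 2.0) // every op stays in bfloat16
	Keep(y)
	Eval(y)
	return y
}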
// ============ Reduction Operations ============
// Sum reduces along an axis
func Sum(a *Array, axis int, keepdims bool) *Array {
res := C.mlx_array_new()
C.mlx_sum_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
return newArray(res)
}
// SumAll reduces the entire array to a scalar
func SumAll(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_sum(&res, a.c, false, C.default_stream())
return newArray(res)
}
// Mean reduces along an axis
func Mean(a *Array, axis int, keepdims bool) *Array {
res := C.mlx_array_new()
C.mlx_mean_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
return newArray(res)
}
// MeanAll reduces the entire array to a scalar
func MeanAll(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_mean(&res, a.c, false, C.default_stream())
return newArray(res)
}
// Var computes variance along an axis
func Var(a *Array, axis int, keepdims bool) *Array {
res := C.mlx_array_new()
C.mlx_var_axis(&res, a.c, C.int(axis), C._Bool(keepdims), 0, C.default_stream())
return newArray(res)
}
// Argmax returns indices of maximum values along an axis
func Argmax(a *Array, axis int, keepdims bool) *Array {
res := C.mlx_array_new()
C.mlx_argmax_axis(&res, a.c, C.int(axis), C._Bool(keepdims), C.default_stream())
return newArray(res)
}
// ArgmaxAll returns the index of the maximum element (flattened).
// Triggers cleanup of non-kept arrays.
func ArgmaxAll(a *Array) int32 {
cleanup()
// Flatten, then argmax with keepdims=false
flat := C.mlx_array_new()
C.mlx_flatten(&flat, a.c, 0, -1, C.default_stream())
res := C.mlx_array_new()
C.mlx_argmax(&res, flat, false, C.default_stream())
C.mlx_array_eval(res)
var val C.int32_t
C.mlx_array_item_int32(&val, res)
C.mlx_array_free(flat)
C.mlx_array_free(res)
return int32(val)
}
// Reshape reshapes the array
func Reshape(a *Array, shape ...int32) *Array {
res := C.mlx_array_new()
C.mlx_reshape(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream())
return newArray(res)
}
// Transpose permutes the dimensions
func Transpose(a *Array, axes ...int) *Array {
cAxes := make([]C.int, len(axes))
for i, ax := range axes {
cAxes[i] = C.int(ax)
}
res := C.mlx_array_new()
C.mlx_transpose_axes(&res, a.c, &cAxes[0], C.size_t(len(axes)), C.default_stream())
return newArray(res)
}
// AsStrided creates a view with custom strides. Useful for fusing reshape+transpose.
func AsStrided(a *Array, shape []int32, strides []int64, offset int64) *Array {
cShape := make([]C.int, len(shape))
for i, s := range shape {
cShape[i] = C.int(s)
}
cStrides := make([]C.int64_t, len(strides))
for i, s := range strides {
cStrides[i] = C.int64_t(s)
}
res := C.mlx_array_new()
C.mlx_as_strided(&res, a.c, &cShape[0], C.size_t(len(shape)), &cStrides[0], C.size_t(len(strides)), C.size_t(offset), C.default_stream())
return newArray(res)
}
// ExpandDims adds a dimension at the specified axis
func ExpandDims(a *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_expand_dims(&res, a.c, C.int(axis), C.default_stream())
return newArray(res)
}
// Squeeze removes a dimension at the specified axis
func Squeeze(a *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_squeeze_axis(&res, a.c, C.int(axis), C.default_stream())
return newArray(res)
}
// Flatten flattens the array to 1D
func Flatten(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_flatten(&res, a.c, 0, -1, C.default_stream())
return newArray(res)
}
// FlattenRange flattens consecutive axes from startAxis to endAxis (inclusive)
func FlattenRange(a *Array, startAxis, endAxis int) *Array {
res := C.mlx_array_new()
C.mlx_flatten(&res, a.c, C.int(startAxis), C.int(endAxis), C.default_stream())
return newArray(res)
}
// View reinterprets the array with a new dtype (no data copy)
func View(a *Array, dtype int) *Array {
res := C.mlx_array_new()
C.mlx_view(&res, a.c, C.mlx_dtype(dtype), C.default_stream())
return newArray(res)
}
// Contiguous returns a contiguous copy of the array
func Contiguous(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_contiguous(&res, a.c, true, C.default_stream())
return newArray(res)
}
// Clip clips values to [min, max]. Pass nil for no bound on that side.
func Clip(a *Array, aMin, aMax *Array) *Array {
res := C.mlx_array_new()
var minH, maxH C.mlx_array
if aMin != nil {
minH = aMin.c
}
if aMax != nil {
maxH = aMax.c
}
C.mlx_clip(&res, a.c, minH, maxH, C.default_stream())
return newArray(res)
}
// ClipScalar clips array values using scalar bounds (matches dtype for graph fusion).
// Set hasMin or hasMax to false to leave that side unbounded; the corresponding value is ignored.
func ClipScalar(a *Array, minVal, maxVal float32, hasMin, hasMax bool) *Array {
var minArr, maxArr C.mlx_array
if hasMin {
minArr = scalarWithDtype(minVal, a)
}
if hasMax {
maxArr = scalarWithDtype(maxVal, a)
}
res := C.mlx_array_new()
C.mlx_clip(&res, a.c, minArr, maxArr, C.default_stream())
if hasMin {
C.mlx_array_free(minArr)
}
if hasMax {
C.mlx_array_free(maxArr)
}
return newArray(res)
}
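// exampleClip is a hypothetical sketch (not part of the engine) of the
// hasMin/hasMax flags: the first call clips to [0, 6] (ReLU6-style), the
// second clips from below only, leaving the upper side unbounded.
func exampleClip(x *Array) (*Array, *Array) {
	relu6 := ClipScalar(x, 0, 6, true, true)
	reluOnly := ClipScalar(x, 0, 0, true, false) // maxVal ignored when hasMax is false
	return relu6, reluOnly
}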
// GreaterEqual returns element-wise a >= b
func GreaterEqual(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_greater_equal(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// LessArray returns element-wise a < b
func LessArray(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_less(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// LogicalAnd returns element-wise a && b
func LogicalAnd(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_logical_and(&res, a.c, b.c, C.default_stream())
return newArray(res)
}
// AllClose returns a scalar bool array that is true when all elements of a and b are within tolerance.
// Uses rtol (relative tolerance) and atol (absolute tolerance):
// |a - b| <= atol + rtol * |b|
func AllClose(a, b *Array, rtol, atol float64) *Array {
res := C.mlx_array_new()
C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream())
return newArray(res)
}
// AllCloseEqualNaN is like AllClose but treats NaN as equal to NaN.
func AllCloseEqualNaN(a, b *Array, rtol, atol float64) *Array {
res := C.mlx_array_new()
C.mlx_allclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream())
return newArray(res)
}
// ArrayEqual returns a scalar bool array that is true when both arrays have the same shape and all elements are equal.
func ArrayEqual(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_array_equal(&res, a.c, b.c, C.bool(false), C.default_stream())
return newArray(res)
}
// ArrayEqualNaN is like ArrayEqual but treats NaN as equal to NaN.
func ArrayEqualNaN(a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_array_equal(&res, a.c, b.c, C.bool(true), C.default_stream())
return newArray(res)
}
// IsClose returns element-wise bool array indicating if values are within tolerance.
// |a - b| <= atol + rtol * |b|
func IsClose(a, b *Array, rtol, atol float64) *Array {
res := C.mlx_array_new()
C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(false), C.default_stream())
return newArray(res)
}
// IsCloseEqualNaN is like IsClose but treats NaN as equal to NaN.
func IsCloseEqualNaN(a, b *Array, rtol, atol float64) *Array {
res := C.mlx_array_new()
C.mlx_isclose(&res, a.c, b.c, C.double(rtol), C.double(atol), C.bool(true), C.default_stream())
return newArray(res)
}
// ReduceMax reduces array to max value over all dimensions.
func ReduceMax(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_max(&res, a.c, C.bool(false), C.default_stream())
return newArray(res)
}
// ArangeInt creates an array with values from start to stop with step and specified dtype
func ArangeInt(start, stop, step int32, dtype Dtype) *Array {
res := C.mlx_array_new()
C.mlx_arange(&res, C.double(start), C.double(stop), C.double(step), C.mlx_dtype(dtype), C.default_stream())
return newArray(res)
}
// Concatenate concatenates arrays along an axis
func Concatenate(arrays []*Array, axis int) *Array {
handles := make([]C.mlx_array, len(arrays))
for i, arr := range arrays {
handles[i] = arr.c
}
vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
res := C.mlx_array_new()
C.mlx_concatenate_axis(&res, vec, C.int(axis), C.default_stream())
C.mlx_vector_array_free(vec)
return newArray(res)
}
// Concat is a convenience function to concatenate two arrays
func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1 // Default stride of 1
}
res := C.mlx_array_new()
C.mlx_slice(&res, a.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
return newArray(res)
}
// SliceStride slices with start:stop:stride like Python a[start:stop:stride]
func SliceStride(a *Array, start, stop, strides []int32) *Array {
cStart := make([]C.int, len(start))
cStop := make([]C.int, len(stop))
cStrides := make([]C.int, len(strides))
for i := range start {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = C.int(strides[i])
}
res := C.mlx_array_new()
C.mlx_slice(&res, a.c, &cStart[0], C.size_t(len(start)), &cStop[0], C.size_t(len(stop)), &cStrides[0], C.size_t(len(strides)), C.default_stream())
return newArray(res)
}
// Tile repeats the array along each dimension
func Tile(a *Array, reps []int32) *Array {
res := C.mlx_array_new()
C.mlx_tile(&res, a.c, int32ToCInt(reps), C.size_t(len(reps)), C.default_stream())
return newArray(res)
}
// BroadcastTo broadcasts an array to a given shape
func BroadcastTo(a *Array, shape []int32) *Array {
res := C.mlx_array_new()
C.mlx_broadcast_to(&res, a.c, int32ToCInt(shape), C.size_t(len(shape)), C.default_stream())
return newArray(res)
}
// ============ Neural Network Operations ============
// Softmax computes softmax along an axis
func Softmax(a *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_softmax_axis(&res, a.c, C.int(axis), false, C.default_stream())
return newArray(res)
}
// Take gathers elements along an axis using indices
func Take(a *Array, indices *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_take_axis(&res, a.c, indices.c, C.int(axis), C.default_stream())
return newArray(res)
}
// Argsort returns indices that would sort the array along an axis
func Argsort(a *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_argsort_axis(&res, a.c, C.int(axis), C.default_stream())
return newArray(res)
}
// Sigmoid computes element-wise sigmoid
func Sigmoid(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_sigmoid(&res, a.c, C.default_stream())
return newArray(res)
}
// ReLU computes element-wise ReLU: max(0, x)
func ReLU(a *Array) *Array {
// ReLU = maximum(x, 0) - mlx-c doesn't have mlx_relu, but we can use maximum
zero := C.mlx_array_new_float(0.0)
res := C.mlx_array_new()
C.mlx_maximum(&res, a.c, zero, C.default_stream())
C.mlx_array_free(zero)
return newArray(res)
}
// SiLU computes element-wise SiLU (Swish): x * sigmoid(x)
func SiLU(a *Array) *Array {
// SiLU = x * sigmoid(x)
sig := C.mlx_array_new()
C.mlx_sigmoid(&sig, a.c, C.default_stream())
res := C.mlx_array_new()
C.mlx_multiply(&res, a.c, sig, C.default_stream())
C.mlx_array_free(sig)
return newArray(res)
}
// GELU computes element-wise GELU (Gaussian Error Linear Unit)
// GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2)))
func GELU(a *Array) *Array {
sqrt2 := C.mlx_array_new_float(1.4142135623730951)
scaled := C.mlx_array_new()
C.mlx_divide(&scaled, a.c, sqrt2, C.default_stream())
erfd := C.mlx_array_new()
C.mlx_erf(&erfd, scaled, C.default_stream())
one := C.mlx_array_new_float(1.0)
erfdPlusOne := C.mlx_array_new()
C.mlx_add(&erfdPlusOne, erfd, one, C.default_stream())
half := C.mlx_array_new_float(0.5)
halfErfdPlusOne := C.mlx_array_new()
C.mlx_multiply(&halfErfdPlusOne, half, erfdPlusOne, C.default_stream())
res := C.mlx_array_new()
C.mlx_multiply(&res, a.c, halfErfdPlusOne, C.default_stream())
C.mlx_array_free(sqrt2)
C.mlx_array_free(scaled)
C.mlx_array_free(erfd)
C.mlx_array_free(one)
C.mlx_array_free(erfdPlusOne)
C.mlx_array_free(half)
C.mlx_array_free(halfErfdPlusOne)
return newArray(res)
}
// Tanh computes element-wise tanh
func Tanh(a *Array) *Array {
res := C.mlx_array_new()
C.mlx_tanh(&res, a.c, C.default_stream())
return newArray(res)
}
// RMSNorm computes RMS normalization using mlx.fast
func RMSNorm(x, weight *Array, eps float32) *Array {
res := C.mlx_array_new()
C.mlx_fast_rms_norm(&res, x.c, weight.c, C.float(eps), C.default_stream())
return newArray(res)
}
// RMSNormNoWeight applies RMS normalization without a weight
// x * rsqrt(mean(x^2) + eps)
// Uses mlx_fast_rms_norm with ones weight for f32 accumulation precision
func RMSNormNoWeight(x *Array, eps float32) *Array {
// Create weight of ones matching last dimension
lastDim := x.Shape()[len(x.Shape())-1]
ones := AsType(Full(1.0, lastDim), x.Dtype())
return RMSNorm(x, ones, eps)
}
// RoPE applies rotary position embeddings using mlx.fast
func RoPE(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
res := C.mlx_array_new()
optBase := C.mlx_optional_float{value: C.float(base), has_value: true}
C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), C.mlx_array{}, C.default_stream())
return newArray(res)
}
// RoPEWithFreqs applies rotary position embeddings with custom frequencies (for YaRN)
// freqs is required - use RoPE() if you don't have custom frequencies
func RoPEWithFreqs(x, freqs *Array, dims int, traditional bool, scale float32, offset int) *Array {
res := C.mlx_array_new()
optBase := C.mlx_optional_float{has_value: false} // No base when using freqs
C.mlx_fast_rope(&res, x.c, C.int(dims), C._Bool(traditional), optBase, C.float(scale), C.int(offset), freqs.c, C.default_stream())
return newArray(res)
}
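// exampleRoPE is a hypothetical sketch (not part of the engine) of applying
// rotary embeddings to a query tensor shaped [batch, heads, seq, head_dim].
// The base of 10000 and scale of 1.0 are illustrative defaults; cacheOffset is
// the number of tokens already in the KV cache.
func exampleRoPE(q *Array, cacheOffset int) *Array {
	headDim := int(q.Shape()[3])
	return RoPE(q, headDim, false, 10000.0, 1.0, cacheOffset)
}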
// ============ Indexing ============
// EmbeddingLookup performs embedding lookup (gathers from table)
// table: [vocab_size, hidden_size], indices: [batch, seq_len]
// returns: [batch, seq_len, hidden_size]
func EmbeddingLookup(table, indices *Array) *Array {
return Take(table, indices, 0)
}
// Gather gathers elements using indices - simplified to use take axis 0
func Gather(a, indices *Array) *Array {
return Take(a, indices, 0)
}
// ============ Array Properties ============
// Ndim returns the number of dimensions
func (a *Array) Ndim() int {
return int(C.mlx_array_ndim(a.c))
}
// Size returns the total number of elements
func (a *Array) Size() int {
return int(C.mlx_array_size(a.c))
}
// IsContiguous returns whether the array's data is contiguous in memory.
// Non-contiguous arrays (e.g., from SliceStride) must call Contiguous() before Data().
func (a *Array) IsContiguous() bool {
var res C.bool
C._mlx_array_is_contiguous(&res, a.c)
return bool(res)
}
// Dim returns the size of a dimension
func (a *Array) Dim(axis int) int32 {
return int32(C.mlx_array_dim(a.c, C.int(axis)))
}
// Shape returns the shape as a slice
func (a *Array) Shape() []int32 {
ndim := a.Ndim()
shape := make([]int32, ndim)
for i := 0; i < ndim; i++ {
shape[i] = a.Dim(i)
}
return shape
}
// IsValid returns true if the array hasn't been freed
func (a *Array) IsValid() bool {
return a != nil && a.c.ctx != nil
}
// Dtype returns the data type
func (a *Array) Dtype() Dtype {
return Dtype(C.mlx_array_dtype(a.c))
}
// Nbytes returns the total size in bytes
func (a *Array) Nbytes() int64 {
return int64(a.Size()) * a.Dtype().ItemSize()
}
// ItemSize returns the size in bytes of one element for this dtype
func (d Dtype) ItemSize() int64 {
switch d {
case DtypeBool, DtypeUint8, DtypeInt8:
return 1
case DtypeUint16, DtypeInt16, DtypeFloat16, DtypeBFloat16:
return 2
case DtypeUint32, DtypeInt32, DtypeFloat32:
return 4
case DtypeUint64, DtypeInt64, DtypeFloat64, DtypeComplex64:
return 8
default:
return 4
}
}
// ============ Data Access ============
// Data copies the float32 data out of the array.
// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first.
// Note: Arrays of other dtypes (bf16, f16, etc) are automatically converted to float32.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) Data() []float32 {
cleanup()
size := a.Size()
if size == 0 {
return nil
}
arr := a
if a.Dtype() != DtypeFloat32 {
arr = AsType(a, DtypeFloat32)
arr.Eval()
// Cast array will be cleaned up on next Eval
}
ptr := C.mlx_array_data_float32(arr.c)
if ptr == nil {
return nil
}
data := make([]float32, size)
copy(data, unsafe.Slice((*float32)(unsafe.Pointer(ptr)), size))
return data
}
// Item returns the scalar value from a 0-dimensional array.
// Converts to float32 if necessary. Triggers cleanup.
func (a *Array) Item() float32 {
data := a.Data() // Data() calls cleanup()
if len(data) == 0 {
return 0
}
return data[0]
}
// DataInt32 copies the int32 data out of the array.
// Note: For non-contiguous arrays (e.g., from SliceStride), call Contiguous() first.
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) DataInt32() []int32 {
cleanup()
size := a.Size()
if size == 0 {
return nil
}
ptr := C.mlx_array_data_int32(a.c)
if ptr == nil {
return nil
}
data := make([]int32, size)
copy(data, unsafe.Slice((*int32)(unsafe.Pointer(ptr)), size))
return data
}
// ItemInt32 gets a single scalar value efficiently (no array copy).
// Note: Triggers cleanup of non-kept arrays.
func (a *Array) ItemInt32() int32 {
cleanup()
var val C.int32_t
C.mlx_array_item_int32(&val, a.c)
return int32(val)
}
// ============ Utility ============
// String returns a string representation
func (a *Array) String() string {
shape := a.Shape()
size := a.Size()
if size <= 20 {
data := a.Data()
return fmt.Sprintf("Array(shape=%v, data=%v)", shape, data)
}
return fmt.Sprintf("Array(shape=%v, size=%d)", shape, size)
}
// ============ Safetensors Support ============
// NewArrayFromBytes creates an array from raw bytes (for safetensors)
func NewArrayFromBytes(data []byte, shape []int32, dtype Dtype) *Array {
cData := unsafe.Pointer(&data[0])
intShape := make([]C.int, len(shape))
for i, s := range shape {
intShape[i] = C.int(s)
}
handle := C.mlx_array_new_data(cData, &intShape[0], C.int(len(shape)), C.mlx_dtype(dtype))
return newArray(handle)
}
// ============ Device Control ============
// SetDefaultDeviceGPU sets the default device to GPU (Metal)
func SetDefaultDeviceGPU() {
dev := C.mlx_device_new_type(C.MLX_GPU, 0)
C.mlx_set_default_device(dev)
C.mlx_device_free(dev)
}
// SetDefaultDeviceCPU sets the default device to CPU
func SetDefaultDeviceCPU() {
dev := C.mlx_device_new_type(C.MLX_CPU, 0)
C.mlx_set_default_device(dev)
C.mlx_device_free(dev)
}
// MetalIsAvailable returns true if Metal GPU is available
func MetalIsAvailable() bool {
var available C._Bool
C.mlx_metal_is_available(&available)
return bool(available)
}
// MetalStartCapture starts a GPU trace capture to the given file path.
// The path must not already exist. Run with MTL_CAPTURE_ENABLED=1 env var.
// Open the resulting .gputrace file in Xcode for analysis.
func MetalStartCapture(path string) {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
C.mlx_metal_start_capture(cPath)
}
// MetalStopCapture stops the current GPU trace capture.
func MetalStopCapture() {
C.mlx_metal_stop_capture()
}
// GPUIsAvailable returns true if any GPU (Metal or CUDA) is available
func GPUIsAvailable() bool {
// On Linux with CUDA build, GPU is available
// On macOS, check Metal availability
if MetalIsAvailable() {
return true
}
// CUDA is available if we compiled with CUDA support (Linux)
return runtime.GOOS == "linux"
}
// GetDefaultDeviceType returns the current default device (0=CPU, 1=GPU)
func GetDefaultDeviceType() int {
var dev C.mlx_device
C.mlx_get_default_device(&dev)
var devType C.mlx_device_type
C.mlx_device_get_type(&devType, dev)
C.mlx_device_free(dev)
return int(devType)
}
// Synchronize waits for all GPU operations to complete
func Synchronize() {
C.mlx_synchronize(C.default_stream())
}
// ScaledDotProductAttention computes optimized attention using GPU kernel
// Q, K, V should be [batch, heads, seq, head_dim]
func ScaledDotProductAttention(q, k, v *Array, scale float32, causalMask bool) *Array {
res := C.mlx_array_new()
maskMode := "" // empty string for no mask
if causalMask {
maskMode = "causal"
}
cMaskMode := C.CString(maskMode)
defer C.free(unsafe.Pointer(cMaskMode))
C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, C.mlx_array{}, C.mlx_array{}, C.default_stream())
return newArray(res)
}
// ScaledDotProductAttentionWithSinks computes attention with sinks support
// maskMode: "causal", "sliding_window", or "" for none
// mask: optional attention mask array (nil for none)
// sinks: attention sinks array (nil for none)
func ScaledDotProductAttentionWithSinks(q, k, v *Array, scale float32, maskMode string, mask, sinks *Array) *Array {
res := C.mlx_array_new()
cMaskMode := C.CString(maskMode)
defer C.free(unsafe.Pointer(cMaskMode))
var maskH, sinksH C.mlx_array
if mask != nil {
maskH = mask.c
}
if sinks != nil {
sinksH = sinks.c
}
C.mlx_fast_scaled_dot_product_attention(&res, q.c, k.c, v.c, C.float(scale), cMaskMode, maskH, sinksH, C.default_stream())
return newArray(res)
}
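// exampleCausalAttention is a hypothetical sketch (not part of the engine) of
// the common decoder call: q, k, v are [batch, heads, seq, head_dim] and the
// caller supplies scale, typically 1/sqrt(head_dim).
func exampleCausalAttention(q, k, v *Array, scale float32) *Array {
	return ScaledDotProductAttention(q, k, v, scale, true)
}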
// ============ Native Safetensors Loading ============
// SafetensorsFile represents a loaded safetensors file
type SafetensorsFile struct {
arrays C.mlx_map_string_to_array
metadata C.mlx_map_string_to_string
}
// LoadSafetensorsNative loads a safetensors file using MLX's optimized loader
// Note: Uses CPU stream because Load primitive only runs on CPU
func LoadSafetensorsNative(path string) (*SafetensorsFile, error) {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var arrays C.mlx_map_string_to_array
var metadata C.mlx_map_string_to_string
if C.mlx_load_safetensors(&arrays, &metadata, cPath, C.cpu_stream()) != 0 {
return nil, fmt.Errorf("failed to load safetensors: %s", path)
}
return &SafetensorsFile{arrays: arrays, metadata: metadata}, nil
}
// Get retrieves a tensor by name
func (s *SafetensorsFile) Get(name string) *Array {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
var arr C.mlx_array
if C.mlx_map_string_to_array_get(&arr, s.arrays, cName) != 0 {
return nil
}
if arr.ctx == nil {
return nil
}
return newArray(arr)
}
// Set replaces a tensor in the map (like Python's weights[k] = v)
func (s *SafetensorsFile) Set(name string, arr *Array) {
cName := C.CString(name)
defer C.free(unsafe.Pointer(cName))
C.mlx_map_string_to_array_insert(s.arrays, cName, arr.c)
}
// Count is not implemented: mlx-c has no direct size accessor for the map
// (it would require iterating), so this always returns 0.
func (s *SafetensorsFile) Count() int {
return 0
}
// Free releases the safetensors file
func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_array_free(s.arrays)
C.mlx_map_string_to_string_free(s.metadata)
}
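// exampleLoadWeight is a hypothetical sketch (not part of the engine) of
// fetching one tensor from a safetensors file. The path and tensor name are
// placeholders; the SafetensorsFile must stay alive while the tensor is used.
func exampleLoadWeight(path, name string) (*SafetensorsFile, *Array, error) {
	f, err := LoadSafetensorsNative(path)
	if err != nil {
		return nil, nil, err
	}
	w := f.Get(name)
	if w == nil {
		f.Free()
		return nil, nil, fmt.Errorf("tensor %q not found in %s", name, path)
	}
	return f, w, nil
}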
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
// Note: Uses CPU stream because Load primitive only runs on CPU
func LoadNpy(path string) (*Array, error) {
cPath := C.CString(path)
defer C.free(unsafe.Pointer(cPath))
var arr C.mlx_array
if C.mlx_load(&arr, cPath, C.cpu_stream()) != 0 {
return nil, fmt.Errorf("failed to load npy: %s", path)
}
if arr.ctx == nil {
return nil, fmt.Errorf("failed to load npy: %s", path)
}
return newArray(arr), nil
}
// ============ Slice Update ============
// SliceUpdate updates a slice of the array with new values
func SliceUpdate(a, update *Array, start, stop []int32) *Array {
n := len(start)
cStart := make([]C.int, n)
cStop := make([]C.int, n)
cStrides := make([]C.int, n)
for i := 0; i < n; i++ {
cStart[i] = C.int(start[i])
cStop[i] = C.int(stop[i])
cStrides[i] = 1 // Default stride of 1
}
res := C.mlx_array_new()
C.mlx_slice_update(&res, a.c, update.c, &cStart[0], C.size_t(n), &cStop[0], C.size_t(n), &cStrides[0], C.size_t(n), C.default_stream())
return newArray(res)
}
// SliceUpdateInplace updates a slice and returns a new array.
// Note: Despite the name, this is NOT in-place - MLX arrays are immutable.
// The caller must use the returned value.
func SliceUpdateInplace(a, update *Array, start, stop []int32) *Array {
return SliceUpdate(a, update, start, stop)
}
// ============ Optimized Operations ============
// SampleArgmax returns the greedy token: argmax over the last (vocab) axis,
// synchronized to the host as an int32. The caller should pass logits for a single position.
func SampleArgmax(logits *Array) int32 {
result := Argmax(logits, -1, false)
return result.ItemInt32()
}
// ArgmaxKeepArray returns argmax as an Array (for pipelining, no sync)
// This is like mlx-lm's sampler that returns y as an array, not .item()
func ArgmaxKeepArray(logits *Array) *Array {
// For greedy decoding: logits shape is [1, 1, vocab]
// We want argmax over vocab dimension, return shape []
return Argmax(logits, -1, false)
}
// RandomState is the global PRNG state, analogous to mx.random.state in Python.
// It's a slice containing a single key array. Random functions use and update this state.
//
// Thread safety: Protected by randomStateMu, mimicking Python's GIL behavior.
// All random functions that use global state acquire this lock.
var RandomState = []*Array{nil}
var randomStateMu sync.Mutex
func init() {
// Lock main goroutine to OS thread for CUDA context stability.
// CUDA contexts are bound to threads; Go can migrate goroutines between threads.
runtime.LockOSThread()
RandomState[0] = RandomKey(uint64(time.Now().UnixMilli()))
Keep(RandomState[0]) // Global state should persist
}
// RandomKey creates a PRNG key from a seed
func RandomKey(seed uint64) *Array {
var res C.mlx_array
C.mlx_random_key(&res, C.uint64_t(seed))
return newArray(res)
}
// RandomSplit splits a PRNG key into two new keys
func RandomSplit(key *Array) (*Array, *Array) {
var key1, key2 C.mlx_array
C.mlx_random_split(&key1, &key2, key.c, C.default_stream())
return newArray(key1), newArray(key2)
}
// RandomCategoricalWithKey samples from categorical distribution using provided key.
func RandomCategoricalWithKey(logits, key *Array, axis int, numSamples int) *Array {
res := C.mlx_array_new()
C.mlx_random_categorical_num_samples(&res, logits.c, C.int(axis), C.int(numSamples), key.c, C.default_stream())
return newArray(res)
}
// RandomCategorical samples using the global RandomState.
// Intended for simple scripts; production code should use RandomCategoricalWithKey with explicit key management.
func RandomCategorical(logits *Array, axis int, numSamples int) *Array {
randomStateMu.Lock()
oldKey := RandomState[0]
key1, key2 := RandomSplit(oldKey)
Keep(key1) // key1 becomes the new global state
oldKey.Free()
RandomState[0] = key1
randomStateMu.Unlock()
return RandomCategoricalWithKey(logits, key2, axis, numSamples)
}
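// exampleSampleWithKey is a hypothetical sketch (not part of the engine) of
// the explicit key management recommended above: split the key each step,
// sample with one half, and carry the other half into the next step.
func exampleSampleWithKey(logits, key *Array) (int32, *Array) {
	nextKey, sampleKey := RandomSplit(key)
	Keep(nextKey) // survives cleanup so it can seed the next step
	sample := RandomCategoricalWithKey(logits, sampleKey, -1, 1)
	return sample.ItemInt32(), nextKey
}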
// RandomNormal creates a random normal (Gaussian) tensor
func RandomNormal(shape []int32, seed uint64) *Array {
key := RandomKey(seed)
res := C.mlx_array_new()
C.mlx_random_normal(&res, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, 0.0, 1.0, key.c, C.default_stream())
return newArray(res)
}
// RandomUniform generates uniform random values in [0, 1) with the given shape
func RandomUniform(shape []int32, seed uint64) *Array {
key := RandomKey(seed)
low := C.mlx_array_new_float(0.0)
high := C.mlx_array_new_float(1.0)
res := C.mlx_array_new()
C.mlx_random_uniform(&res, low, high, int32ToCInt(shape), C.size_t(len(shape)), C.MLX_FLOAT32, key.c, C.default_stream())
C.mlx_array_free(low)
C.mlx_array_free(high)
return newArray(res)
}
// Conv2d performs 2D convolution
// input: [N, H, W, C], weight: [O, kH, kW, C] (MLX uses NHWC layout)
// Returns: [N, H', W', O]
func Conv2d(input, weight *Array, stride, padding int32) *Array {
res := C.mlx_array_new()
C.mlx_conv2d(&res, input.c, weight.c, C.int(stride), C.int(stride), C.int(padding), C.int(padding), 1, 1, 1, C.default_stream())
return newArray(res)
}
// Conv3d performs 3D convolution
// input: [N, D, H, W, C], weight: [O, kD, kH, kW, C] (MLX uses NDHWC layout)
// Returns: [N, D', H', W', O]
func Conv3d(input, weight *Array, strideD, strideH, strideW, padD, padH, padW int32) *Array {
res := C.mlx_array_new()
C.mlx_conv3d(&res, input.c, weight.c, C.int(strideD), C.int(strideH), C.int(strideW), C.int(padD), C.int(padH), C.int(padW), 1, 1, 1, 1, C.default_stream())
return newArray(res)
}
// ============ Compilation Control ============
// EnableCompile enables global compilation/graph fusion
func EnableCompile() {
C.mlx_enable_compile()
}
// DisableCompile disables global compilation
func DisableCompile() {
C.mlx_disable_compile()
}
// SetCompileMode sets the compile mode
// 0=disabled, 1=no_simplify, 2=no_fuse, 3=enabled
func SetCompileMode(mode int) {
C.mlx_set_compile_mode(C.mlx_compile_mode(mode))
}
// ============ Stream Control ============
// Stream represents an MLX execution stream
type Stream struct {
c C.mlx_stream
}
// NewStream creates a new execution stream on the default device
func NewStream() *Stream {
var dev C.mlx_device
C.mlx_get_default_device(&dev)
stream := C.mlx_stream_new_device(dev)
C.mlx_device_free(dev)
return &Stream{c: stream}
}
// Free releases the stream
func (s *Stream) Free() {
if s.c.ctx != nil {
C.mlx_stream_free(s.c)
s.c.ctx = nil
}
}
// SetDefaultStream sets the default stream for operations
func SetDefaultStream(s *Stream) {
C.mlx_set_default_stream(s.c)
C.set_default_stream(s.c) // Also update our cached stream
}
// GetDefaultStream returns the current default stream
func GetDefaultStream() *Stream {
var stream C.mlx_stream
var dev C.mlx_device
C.mlx_get_default_device(&dev)
C.mlx_get_default_stream(&stream, dev)
C.mlx_device_free(dev)
return &Stream{c: stream}
}
// SynchronizeStream waits for all operations on the stream to complete
func SynchronizeStream(s *Stream) {
C.mlx_synchronize(s.c)
}
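// exampleWorkerStream is a hypothetical sketch (not part of the engine) of
// running a chunk of work on its own stream and waiting for just that work.
func exampleWorkerStream(work func()) {
	prev := GetDefaultStream()
	s := NewStream()
	SetDefaultStream(s)
	work()
	SynchronizeStream(s)
	SetDefaultStream(prev) // restore the previous default
	s.Free()
}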
// ============ Metal Memory Control ============
// MetalGetCacheMemory returns the current cache memory usage in bytes
func MetalGetCacheMemory() uint64 {
var size C.size_t
C.mlx_get_cache_memory(&size)
return uint64(size)
}
// MetalGetPeakMemory returns the peak memory usage in bytes
func MetalGetPeakMemory() uint64 {
var size C.size_t
C.mlx_get_peak_memory(&size)
return uint64(size)
}
// MetalResetPeakMemory resets the peak memory counter
func MetalResetPeakMemory() {
C.mlx_reset_peak_memory()
}
// MetalSetWiredLimit sets the wired memory limit and returns the previous limit
// This keeps tensors pinned in GPU memory for faster access
func MetalSetWiredLimit(limit uint64) uint64 {
var prev C.size_t
C.mlx_set_wired_limit(&prev, C.size_t(limit))
return uint64(prev)
}
// MetalGetActiveMemory returns the current active memory usage in bytes
func MetalGetActiveMemory() uint64 {
var size C.size_t
C.mlx_get_active_memory(&size)
return uint64(size)
}
// ClearCache clears the MLX memory cache
func ClearCache() {
C.mlx_clear_cache()
}
// SetCacheLimit sets the free cache limit in bytes
// Setting to 0 disables caching (useful for memory-constrained generation)
// Returns the previous cache limit
func SetCacheLimit(limit uint64) uint64 {
var prev C.size_t
C.mlx_set_cache_limit(&prev, C.size_t(limit))
return uint64(prev)
}
// SetMemoryLimit sets the overall memory limit in bytes
// This is a guideline for maximum memory during graph evaluation.
// When Metal is available, defaults to 1.5x the max recommended working set.
// Returns the previous memory limit
func SetMemoryLimit(limit uint64) uint64 {
var prev C.size_t
C.mlx_set_memory_limit(&prev, C.size_t(limit))
return uint64(prev)
}
// GetMemoryLimit returns the current memory limit in bytes
func GetMemoryLimit() uint64 {
var size C.size_t
C.mlx_get_memory_limit(&size)
return uint64(size)
}
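// exampleMemoryBudget is a hypothetical sketch (not part of the engine) of a
// memory-constrained configuration: disable buffer caching and cap overall
// usage. The 8 GiB figure is an arbitrary illustration, not a recommendation.
func exampleMemoryBudget() {
	SetCacheLimit(0)                       // no free-buffer caching
	SetMemoryLimit(8 * 1024 * 1024 * 1024) // guideline for graph evaluation
}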
// ============ MoE Operations ============
// GatherMM performs gather matrix multiplication for MoE
// a: input, b: weight matrices
// lhsIndices, rhsIndices: optional expert selection indices (nil for none)
func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
var lhs, rhs C.mlx_array
if lhsIndices != nil {
lhs = lhsIndices.c
}
if rhsIndices != nil {
rhs = rhsIndices.c
}
res := C.mlx_array_new()
C.mlx_gather_mm(&res, a.c, b.c, lhs, rhs, C._Bool(sortedIndices), C.default_stream())
return newArray(res)
}
// GatherQMM performs quantized gather matrix multiplication for MoE
// Used for MXFP4 and other quantized MoE inference
func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
var b, lhs, rhs C.mlx_array
if biases != nil {
b = biases.c
}
if lhsIndices != nil {
lhs = lhsIndices.c
}
if rhsIndices != nil {
rhs = rhsIndices.c
}
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_array_new()
C.mlx_gather_qmm(&res, x.c, w.c, scales.c, b, lhs, rhs, C._Bool(transpose), optGroupSize, optBits, cMode, C._Bool(sortedIndices), C.default_stream())
return newArray(res)
}
// ============ Quantization ============
// Quantize quantizes weights to specified bits per element.
// Returns (quantized_weights, scales, biases).
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
res := C.mlx_vector_array_new()
C.mlx_quantize(&res, w.c, optGroupSize, optBits, cMode, C.default_stream())
// Result is a vector of 3 arrays: [weights, scales, biases]
var w0, w1, w2 C.mlx_array
C.mlx_vector_array_get(&w0, res, 0)
C.mlx_vector_array_get(&w1, res, 1)
C.mlx_vector_array_get(&w2, res, 2)
C.mlx_vector_array_free(res)
return newArray(w0), newArray(w1), newArray(w2)
}
// Dequantize reconstructs weights from quantized form.
// groupSize: number of elements quantized together (default 64)
// bits: bits per element, 2, 4, or 8 (default 4)
// mode: "affine" (default) or "mxfp4"
func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
optDtype := C.mlx_optional_dtype{has_value: false}
var b C.mlx_array
if biases != nil {
b = biases.c
}
res := C.mlx_array_new()
C.mlx_dequantize(&res, w.c, scales.c, b, optGroupSize, optBits, cMode, optDtype, C.default_stream())
return newArray(res)
}
// QuantizedMatmul performs matrix multiplication with quantized weights.
// x: input tensor [batch..., in_features]
// w: quantized weights
// scales, biases: from Quantize
// transpose: if true, compute x @ w.T (typical for Linear layers)
// groupSize, bits, mode: must match what was used in Quantize
func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
cMode := C.CString(mode)
defer C.free(unsafe.Pointer(cMode))
optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
var b C.mlx_array
if biases != nil {
b = biases.c
}
res := C.mlx_array_new()
C.mlx_quantized_matmul(&res, x.c, w.c, scales.c, b, C._Bool(transpose), optGroupSize, optBits, cMode, C.default_stream())
return newArray(res)
}
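// exampleQuantizedLinear is a hypothetical sketch (not part of the engine) of
// the quantize/matmul round trip: quantize a weight once, then replace a
// Linear-style x @ w.T with QuantizedMatmul. Group size 64, 4 bits and
// "affine" mode mirror the defaults noted above and are illustrative only.
func exampleQuantizedLinear(x, w *Array) *Array {
	qw, scales, biases := Quantize(w, 64, 4, "affine")
	Keep(qw, scales, biases)
	Eval(qw, scales, biases)
	return QuantizedMatmul(x, qw, scales, biases, true, 64, 4, "affine")
}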
// ============ Sorting and Top-K ============
// TopK returns the k largest elements along an axis
func TopK(a *Array, k int, axis int) *Array {
res := C.mlx_array_new()
C.mlx_topk_axis(&res, a.c, C.int(k), C.int(axis), C.default_stream())
return newArray(res)
}
// Argpartition returns indices for partial sort (k-th smallest first)
func Argpartition(a *Array, kth int, axis int) *Array {
res := C.mlx_array_new()
C.mlx_argpartition_axis(&res, a.c, C.int(kth), C.int(axis), C.default_stream())
return newArray(res)
}
// TakeAlongAxis takes elements from array using indices along axis
func TakeAlongAxis(a, indices *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_take_along_axis(&res, a.c, indices.c, C.int(axis), C.default_stream())
return newArray(res)
}
// PutAlongAxis puts values into array at indices along axis
func PutAlongAxis(a, indices, values *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_put_along_axis(&res, a.c, indices.c, values.c, C.int(axis), C.default_stream())
return newArray(res)
}
// Cumsum computes cumulative sum along an axis
func Cumsum(a *Array, axis int) *Array {
res := C.mlx_array_new()
C.mlx_cumsum(&res, a.c, C.int(axis), false, false, C.default_stream())
return newArray(res)
}
// Where selects elements: condition ? a : b
func Where(condition, a, b *Array) *Array {
res := C.mlx_array_new()
C.mlx_where(&res, condition.c, a.c, b.c, C.default_stream())
return newArray(res)
}
// LessScalar returns element-wise a < scalar
func LessScalar(a *Array, s float32) *Array {
scalar := C.mlx_array_new_float(C.float(s))
res := C.mlx_array_new()
C.mlx_less(&res, a.c, scalar, C.default_stream())
C.mlx_array_free(scalar)
return newArray(res)
}
// FullDtype creates an array filled with a value with specific dtype
func FullDtype(value float32, dtype Dtype, shape ...int32) *Array {
intShape := make([]C.int, len(shape))
for i, s := range shape {
intShape[i] = C.int(s)
}
vals := C.mlx_array_new_float(C.float(value))
res := C.mlx_array_new()
C.mlx_full(&res, &intShape[0], C.size_t(len(shape)), vals, C.mlx_dtype(dtype), C.default_stream())
C.mlx_array_free(vals)
return newArray(res)
}
// AsType casts an array to a different dtype
func AsType(a *Array, dtype Dtype) *Array {
res := C.mlx_array_new()
C.mlx_astype(&res, a.c, C.mlx_dtype(dtype), C.default_stream())
return newArray(res)
}
// ToBFloat16 casts an array to bfloat16
func ToBFloat16(a *Array) *Array {
return AsType(a, DtypeBFloat16)
}
// ============ VibeVoice Helper Functions ============
// NewScalarArray creates a true 0-dimensional scalar array from a float32 value
func NewScalarArray(value float32) *Array {
return newArray(C.mlx_array_new_float(C.float(value)))
}
// Global random seed counter for RandN
var randnSeedCounter uint64 = uint64(time.Now().UnixNano())
// RandN creates an array of random samples from a standard normal distribution
func RandN(shape []int32) *Array {
// Use incrementing seed for unique random values each call
seed := atomic.AddUint64(&randnSeedCounter, 1)
return RandomNormal(shape, seed)
}
// Pad pads an array with zeros
// paddings: [before_0, after_0, before_1, after_1, ...] for each dimension
func Pad(a *Array, paddings []int32) *Array {
numAxes := len(paddings) / 2
// Convert to low/high pairs
lowPad := make([]C.int, numAxes)
highPad := make([]C.int, numAxes)
for i := 0; i < numAxes; i++ {
lowPad[i] = C.int(paddings[i*2])
highPad[i] = C.int(paddings[i*2+1])
}
zero := C.mlx_array_new_float(0.0)
res := C.mlx_array_new()
// mlx_pad takes axes, low, high arrays
axes := make([]C.int, numAxes)
for i := 0; i < numAxes; i++ {
axes[i] = C.int(i)
}
cMode := C.CString("constant")
defer C.free(unsafe.Pointer(cMode))
C.mlx_pad(&res, a.c, &axes[0], C.size_t(numAxes), &lowPad[0], C.size_t(numAxes), &highPad[0], C.size_t(numAxes), zero, cMode, C.default_stream())
C.mlx_array_free(zero)
return newArray(res)
}
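// examplePadLength is a hypothetical sketch (not part of the engine) of the
// paddings layout documented above: (before, after) pairs per dimension. Here
// a [B, L, C] tensor gets `extra` zero frames appended on the length axis.
func examplePadLength(x *Array, extra int32) *Array {
	return Pad(x, []int32{0, 0, 0, extra, 0, 0})
}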
// Conv1d performs 1D convolution
// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout)
// bias: optional (nil for no bias)
func Conv1d(x, weight *Array, bias *Array, stride int32) *Array {
res := C.mlx_array_new()
C.mlx_conv1d(&res, x.c, weight.c, C.int(stride), C.int(0), C.int(1), 1, C.default_stream())
// Apply bias if provided
if bias != nil {
biased := C.mlx_array_new()
C.mlx_add(&biased, res, bias.c, C.default_stream())
C.mlx_array_free(res)
return newArray(biased)
}
return newArray(res)
}
// ConvTranspose1d performs transposed 1D convolution
// x: [B, L, Cin], weight: [Cout, K, Cin] (MLX uses NLC layout)
// bias: optional (nil for no bias)
func ConvTranspose1d(x, weight *Array, bias *Array, stride int32) *Array {
res := C.mlx_array_new()
// stride, padding, dilation, output_padding, groups
C.mlx_conv_transpose1d(&res, x.c, weight.c, C.int(stride), 0, 1, 0, 1, C.default_stream())
// Apply bias if provided
if bias != nil {
biased := C.mlx_array_new()
C.mlx_add(&biased, res, bias.c, C.default_stream())
C.mlx_array_free(res)
return newArray(biased)
}
return newArray(res)
}
// DepthwiseConv1d performs depthwise 1D convolution (groups=Cin)
// x: [B, L, C], weight: [C, K, 1] (out channels = C, groups = C)
// bias: optional (nil for no bias)
func DepthwiseConv1d(x, weight *Array, bias *Array) *Array {
// Get number of input channels for groups
shape := x.Shape()
groups := int(shape[len(shape)-1])
res := C.mlx_array_new()
C.mlx_conv1d(&res, x.c, weight.c, 1, 0, 1, C.int(groups), C.default_stream())
// Apply bias if provided
if bias != nil {
biased := C.mlx_array_new()
C.mlx_add(&biased, res, bias.c, C.default_stream())
C.mlx_array_free(res)
return newArray(biased)
}
return newArray(res)
}
// SliceAxis extracts a slice along a specific axis
func SliceAxis(a *Array, axis int, start, stop int32) *Array {
shape := a.Shape()
// Build start and stop indices for all dimensions
starts := make([]int32, len(shape))
stops := make([]int32, len(shape))
for i := range shape {
if i == axis {
starts[i] = start
stops[i] = stop
} else {
starts[i] = 0
stops[i] = shape[i]
}
}
return Slice(a, starts, stops)
}
// Tri creates a lower triangular matrix
func Tri(n, m int32, k int) *Array {
res := C.mlx_array_new()
C.mlx_tri(&res, C.int(n), C.int(m), C.int(k), C.MLX_FLOAT32, C.default_stream())
return newArray(res)
}
//go:build mlx
package mlx
import (
"fmt"
"testing"
)
// TestBasicCleanup verifies non-kept arrays are freed and kept arrays survive.
func TestBasicCleanup(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
Keep(weight)
weight.Eval()
intermediate := NewArrayFloat32([]float32{1, 1}, []int32{1, 2})
result := Matmul(intermediate, weight)
Keep(result)
// Before eval: intermediate should be valid
if !intermediate.Valid() {
t.Fatal("intermediate should be valid before Eval")
}
Eval(result)
// After eval: intermediate should be freed
if intermediate.Valid() {
t.Fatal("intermediate should be freed after Eval")
}
// Result should have correct values
data := result.Data()
if data[0] != 4 || data[1] != 6 {
t.Errorf("expected [4, 6], got %v", data)
}
// Weight should survive
if !weight.Valid() {
t.Error("weight was freed")
}
}
// TestKeptSurvives verifies kept arrays are not freed.
func TestKeptSurvives(t *testing.T) {
a := NewArrayFloat32([]float32{1, 2}, []int32{2})
b := NewArrayFloat32([]float32{3, 4}, []int32{2})
result := Add(a, b)
Keep(result)
Eval(result)
if !result.Valid() {
t.Error("kept result was freed")
}
data := result.Data()
if data[0] != 4 || data[1] != 6 {
t.Errorf("expected [4, 6], got %v", data)
}
}
// TestEvalAutoKeeps verifies Eval automatically keeps its outputs.
func TestEvalAutoKeeps(t *testing.T) {
a := NewArrayFloat32([]float32{1, 2}, []int32{2})
b := NewArrayFloat32([]float32{3, 4}, []int32{2})
result := Add(a, b)
// Don't call Keep(result) - Eval should auto-keep it
Eval(result)
// Result should survive (auto-kept by Eval)
if !result.Valid() {
t.Error("Eval output was freed - should be auto-kept")
}
// Inputs should be freed (not kept)
if a.Valid() {
t.Error("input 'a' should be freed")
}
if b.Valid() {
t.Error("input 'b' should be freed")
}
// Verify data is correct
data := result.Data()
if data[0] != 4 || data[1] != 6 {
t.Errorf("expected [4, 6], got %v", data)
}
}
// TestWeightsSurvive verifies kept arrays survive multiple Eval cycles.
func TestWeightsSurvive(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
Keep(weight)
weight.Eval()
for i := 0; i < 5; i++ {
x := NewArrayFloat32([]float32{1, 1}, []int32{1, 2})
result := Matmul(x, weight)
Keep(result)
Eval(result)
}
if !weight.Valid() {
t.Error("weight was freed after multiple iterations")
}
}
// TestAsyncEvalCleanup verifies AsyncEval cleans up and dispatches.
func TestAsyncEvalCleanup(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2}) // Identity matrix
Keep(weight)
weight.Eval()
// First async step
x1 := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
result1 := Matmul(x1, weight)
Keep(result1)
AsyncEval(result1)
// Second async step
x2 := NewArrayFloat32([]float32{3, 4}, []int32{1, 2})
result2 := Matmul(x2, weight)
Keep(result2)
AsyncEval(result2)
// Sync and verify results
result1.Eval()
d1 := result1.Data()
if d1[0] != 1 || d1[1] != 2 {
t.Errorf("result1: expected [1, 2], got %v", d1)
}
result2.Eval()
d2 := result2.Data()
if d2[0] != 3 || d2[1] != 4 {
t.Errorf("result2: expected [3, 4], got %v", d2)
}
if !weight.Valid() {
t.Error("weight was freed during async")
}
}
// TestMultiOutput verifies multiple kept arrays survive.
func TestMultiOutput(t *testing.T) {
a := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
sum := Add(a, a)
prod := Mul(a, a)
Keep(sum, prod)
Eval(sum, prod)
// Both kept arrays should be valid
if !sum.Valid() || !prod.Valid() {
t.Error("kept arrays should survive cleanup")
}
// Verify values
sumData := sum.Data()
prodData := prod.Data()
if sumData[0] != 2 || prodData[0] != 1 {
t.Errorf("unexpected results: sum=%v prod=%v", sumData, prodData)
}
}
// TestChaining verifies output from one step can be used in next.
func TestChaining(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
Keep(weight)
weight.Eval()
// First step
x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
out1 := Matmul(x, weight)
Keep(out1)
AsyncEval(out1)
// Second step uses output of first
out2 := Add(out1, out1)
Keep(out2)
Eval(out2)
// out1 should survive (was kept)
if !out1.Valid() {
t.Error("out1 was freed but used by second step")
}
// Final result should be correct
data := out2.Data()
if data[0] != 2 || data[1] != 4 {
t.Errorf("expected [2, 4], got %v", data)
}
}
// TestGenerationLoop simulates the LLM generation pattern with cache.
func TestGenerationLoop(t *testing.T) {
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
Keep(weight)
weight.Eval()
// Simulate cache - starts as zeros
cache := NewArrayFloat32([]float32{0, 0}, []int32{1, 2})
Keep(cache)
cache.Eval()
var lastToken *Array
// Simulate 5 generation steps
for step := 0; step < 5; step++ {
oldCache := cache
// Simulate forward pass
input := NewArrayFloat32([]float32{float32(step + 1), float32(step + 2)}, []int32{1, 2})
output := Matmul(input, weight)
// Simulate cache update
newCache := Add(output, cache)
// Mark what survives
Keep(output, newCache)
if step < 4 {
AsyncEval(output, newCache)
} else {
Eval(output, newCache)
}
// Free old cache, update references
oldCache.Free()
lastToken = output
cache = newCache
}
// Token output should be valid
if !lastToken.Valid() {
t.Error("token output was freed")
}
// Cache should be valid
if !cache.Valid() {
t.Error("cache was freed")
}
// Weight should survive all iterations
if !weight.Valid() {
t.Error("weight was freed")
}
}
// BenchmarkCleanupOnly isolates cleanup cost without MLX ops.
func BenchmarkCleanupOnly(b *testing.B) {
// Pre-create weight
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
Keep(weight)
weight.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Create 100 arrays - minimal ops
arrays := make([]*Array, 100)
for j := range arrays {
arrays[j] = NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
}
Keep(arrays[0])
Eval() // Just cleanup
}
}
// BenchmarkNewArrayOnly measures array creation overhead.
func BenchmarkNewArrayOnly(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
}
}
// BenchmarkCGOCallOverhead measures raw CGO call cost.
func BenchmarkCGOCallOverhead(b *testing.B) {
arr := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
Keep(arr)
arr.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = arr.Ndim() // Simple CGO call
}
}
// BenchmarkCleanup_50 measures cleanup with 50 arrays.
func BenchmarkCleanup_50(b *testing.B) {
benchCleanup(b, 50)
}
// BenchmarkCleanup_500 measures cleanup with 500 arrays (LLM scale).
func BenchmarkCleanup_500(b *testing.B) {
benchCleanup(b, 500)
}
// BenchmarkCleanup_1000 measures cleanup with 1000 arrays.
func BenchmarkCleanup_1000(b *testing.B) {
benchCleanup(b, 1000)
}
func benchCleanup(b *testing.B, numArrays int) {
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
Keep(weight)
weight.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
for j := 0; j < numArrays; j++ {
x = Add(x, x)
}
result := Matmul(x, weight)
Keep(result)
Eval(result)
}
}
// BenchmarkGenerationLoop_10 simulates 10 token generation steps.
func BenchmarkGenerationLoop_10(b *testing.B) {
benchGenerationLoop(b, 10)
}
// BenchmarkGenerationLoop_100 simulates 100 token generation steps.
func BenchmarkGenerationLoop_100(b *testing.B) {
benchGenerationLoop(b, 100)
}
func benchGenerationLoop(b *testing.B, steps int) {
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
Keep(weight)
weight.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cache := NewArrayFloat32([]float32{0, 0}, []int32{1, 2})
Keep(cache)
cache.Eval()
for step := 0; step < steps; step++ {
oldCache := cache
input := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
output := Matmul(input, weight)
newCache := Add(output, cache)
Keep(output, newCache)
if step < steps-1 {
AsyncEval(output, newCache)
} else {
Eval(output, newCache)
}
oldCache.Free()
cache = newCache
}
}
}
// BenchmarkLLMForward simulates a realistic LLM forward pass with ~500 ops.
func BenchmarkLLMForward(b *testing.B) {
// Simulate weights for 32 layers
numLayers := 32
weights := make([]*Array, numLayers*4) // q, k, v, o per layer
for i := range weights {
weights[i] = NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
}
Keep(weights...)
Eval(weights...)
b.ResetTimer()
for i := 0; i < b.N; i++ {
x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
// Simulate 32 transformer layers
for layer := 0; layer < numLayers; layer++ {
// Attention block (simplified)
q := Matmul(x, weights[layer*4])
k := Matmul(x, weights[layer*4+1])
v := Matmul(x, weights[layer*4+2])
attn := Matmul(Softmax(Matmul(q, Transpose(k, 1, 0)), -1), v)
attnOut := Matmul(attn, weights[layer*4+3])
// Residual + layernorm (simplified)
x = Add(x, attnOut)
x = RMSNormNoWeight(x, 1e-5)
// FFN (simplified as single matmul)
ffn := Matmul(x, weights[layer*4])
ffn = SiLU(ffn)
x = Add(x, ffn)
}
Keep(x)
Eval(x)
}
}
// ============ Compile Tests ============
// gelu implements GELU activation: x * 0.5 * (1 + erf(x / sqrt(2)))
func gelu(x *Array) *Array {
sqrt2 := NewScalarArray(1.4142135623730951)
half := NewScalarArray(0.5)
one := NewScalarArray(1.0)
scaled := Div(x, sqrt2)
erfd := Erf(scaled)
return Mul(Mul(x, half), Add(one, erfd))
}
// TestCompileBasic verifies compiled function produces correct output.
func TestCompileBasic(t *testing.T) {
x := NewArrayFloat32([]float32{-1, 0, 1, 2}, []int32{4})
Keep(x)
x.Eval()
// Uncompiled
expected := gelu(x)
Keep(expected)
Eval(expected)
// Compiled
compiled := Compile(func(inputs []*Array) []*Array {
return []*Array{gelu(inputs[0])}
})
defer compiled.Free()
result := compiled.Call(x)[0]
Keep(result)
Eval(result)
// Compare with tolerance
expData := expected.Data()
resData := result.Data()
for i := range expData {
diff := expData[i] - resData[i]
if diff < 0 {
diff = -diff
}
if diff > 1e-5 {
t.Errorf("mismatch at %d: expected %f, got %f (diff=%e)", i, expData[i], resData[i], diff)
}
}
}
// TestCompileMultipleInputs verifies compiled function with multiple inputs.
func TestCompileMultipleInputs(t *testing.T) {
a := NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{4})
b := NewArrayFloat32([]float32{5, 6, 7, 8}, []int32{4})
Keep(a, b)
Eval(a, b)
compiled := Compile(func(inputs []*Array) []*Array {
sum := Add(inputs[0], inputs[1])
prod := Mul(inputs[0], inputs[1])
return []*Array{sum, prod}
})
defer compiled.Free()
outputs := compiled.Call(a, b)
Keep(outputs...)
Eval(outputs...)
sumData := outputs[0].Data()
prodData := outputs[1].Data()
if sumData[0] != 6 || prodData[0] != 5 {
t.Errorf("unexpected: sum[0]=%f, prod[0]=%f", sumData[0], prodData[0])
}
}
// TestCompileReuse verifies compiled function can be called multiple times.
func TestCompileReuse(t *testing.T) {
compiled := Compile(func(inputs []*Array) []*Array {
return []*Array{Add(inputs[0], inputs[0])}
})
defer compiled.Free()
for i := 0; i < 5; i++ {
x := NewArrayFloat32([]float32{float32(i)}, []int32{1})
Keep(x)
x.Eval()
result := compiled.Call(x)[0]
Keep(result)
Eval(result)
data := result.Data()
expected := float32(i * 2)
if data[0] != expected {
t.Errorf("iteration %d: expected %f, got %f", i, expected, data[0])
}
}
}
// BenchmarkGELUUncompiled benchmarks uncompiled GELU.
func BenchmarkGELUUncompiled(b *testing.B) {
x := RandomNormal([]int32{1000, 1024}, 42)
Keep(x)
x.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
y := x
for j := 0; j < 10; j++ {
y = gelu(y)
}
Keep(y)
Eval(y)
}
}
// BenchmarkGELUCompiled benchmarks compiled GELU.
func BenchmarkGELUCompiled(b *testing.B) {
x := RandomNormal([]int32{1000, 1024}, 42)
Keep(x)
x.Eval()
compiled := Compile(func(inputs []*Array) []*Array {
y := inputs[0]
for j := 0; j < 10; j++ {
y = gelu(y)
}
return []*Array{y}
})
defer compiled.Free()
b.ResetTimer()
for i := 0; i < b.N; i++ {
result := compiled.Call(x)
Keep(result[0])
Eval(result[0])
}
}
// TestCompileNoMemoryLeak verifies compiled functions don't leak memory.
func TestCompileNoMemoryLeak(t *testing.T) {
x := RandomNormal([]int32{100, 100}, 42)
Keep(x)
x.Eval()
compiled := Compile(func(inputs []*Array) []*Array {
y := inputs[0]
for j := 0; j < 5; j++ {
y = gelu(y)
}
return []*Array{y}
})
defer compiled.Free()
// Warmup to establish baseline
for i := 0; i < 10; i++ {
result := compiled.Call(x)
Keep(result[0])
Eval(result[0])
result[0].Free()
}
MetalResetPeakMemory()
initialMem := MetalGetActiveMemory()
for i := 0; i < 100; i++ {
result := compiled.Call(x)
Keep(result[0])
Eval(result[0])
result[0].Free()
}
Eval() // Final cleanup
finalMem := MetalGetActiveMemory()
peakMem := MetalGetPeakMemory()
// Memory should not grow significantly (allow 10MB slack for caching)
growth := int64(finalMem) - int64(initialMem)
if growth > 10*1024*1024 {
t.Errorf("memory grew by %d bytes over 100 iterations", growth)
}
t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB",
initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024)
}
// TestCompileWithRandomState verifies compiled function can capture and update random state.
func TestCompileWithRandomState(t *testing.T) {
// Simulate logits for sampling
logits := NewArrayFloat32([]float32{0.1, 0.2, 0.3, 0.4}, []int32{1, 4})
Keep(logits)
logits.Eval()
// Initial random key
key := RandomKey(42)
Keep(key)
// Compile a sampling function that splits the key
compiled := Compile(func(inputs []*Array) []*Array {
logits := inputs[0]
keyIn := inputs[1]
// Split key: one for sampling, one for next iteration
key1, key2 := RandomSplit(keyIn)
// Sample from logits
sample := RandomCategoricalWithKey(logits, key2, -1, 1)
return []*Array{sample, key1}
})
defer compiled.Free()
// Run multiple sampling steps
samples := make([]int32, 10)
for i := 0; i < 10; i++ {
outputs := compiled.Call(logits, key)
Keep(outputs...)
Eval(outputs...)
samples[i] = outputs[0].ItemInt32()
key.Free()
key = outputs[1]
}
// Verify we got valid samples (0-3)
for i, s := range samples {
if s < 0 || s > 3 {
t.Errorf("sample %d out of range: %d", i, s)
}
}
t.Logf("samples: %v", samples)
// Verify samples aren't all the same (randomness works)
allSame := true
for i := 1; i < len(samples); i++ {
if samples[i] != samples[0] {
allSame = false
break
}
}
if allSame {
t.Error("all samples are the same - random state may not be updating")
}
}
// swiGLU implements the GPT-OSS custom SwiGLU activation.
func swiGLU(gate, up *Array, alpha, limit float32) *Array {
gateClipped := ClipScalar(gate, 0, limit, false, true)
upClipped := ClipScalar(up, -limit, limit, true, true)
gluScaled := MulScalar(gateClipped, alpha)
sig := Sigmoid(gluScaled)
outGlu := Mul(gateClipped, sig)
return Mul(outGlu, AddScalar(upClipped, 1.0))
}
// TestCompileSwiGLU verifies compiled SwiGLU produces correct output.
func TestCompileSwiGLU(t *testing.T) {
gate := NewArrayFloat32([]float32{-1, 0, 1, 2, 5, 10}, []int32{6})
up := NewArrayFloat32([]float32{-5, -1, 0, 1, 5, 10}, []int32{6})
Keep(gate, up)
Eval(gate, up)
const alpha float32 = 1.702
const limit float32 = 7.0
// Uncompiled
expected := swiGLU(gate, up, alpha, limit)
Keep(expected)
Eval(expected)
// Compiled
compiled := Compile(func(inputs []*Array) []*Array {
return []*Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
})
defer compiled.Free()
result := compiled.Call(gate, up)[0]
Keep(result)
Eval(result)
// Compare
expData := expected.Data()
resData := result.Data()
for i := range expData {
diff := expData[i] - resData[i]
if diff < 0 {
diff = -diff
}
if diff > 1e-5 {
t.Errorf("mismatch at %d: expected %f, got %f", i, expData[i], resData[i])
}
}
t.Logf("SwiGLU results: %v", resData)
}
// BenchmarkSwiGLUUncompiled benchmarks uncompiled SwiGLU.
func BenchmarkSwiGLUUncompiled(b *testing.B) {
gate := RandomNormal([]int32{1, 2880}, 42)
up := RandomNormal([]int32{1, 2880}, 43)
Keep(gate, up)
Eval(gate, up)
const alpha float32 = 1.702
const limit float32 = 7.0
b.ResetTimer()
for i := 0; i < b.N; i++ {
result := swiGLU(gate, up, alpha, limit)
Keep(result)
Eval(result)
}
}
// BenchmarkSwiGLUCompiled benchmarks compiled SwiGLU.
func BenchmarkSwiGLUCompiled(b *testing.B) {
gate := RandomNormal([]int32{1, 2880}, 42)
up := RandomNormal([]int32{1, 2880}, 43)
Keep(gate, up)
Eval(gate, up)
const alpha float32 = 1.702
const limit float32 = 7.0
compiled := Compile(func(inputs []*Array) []*Array {
return []*Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
})
defer compiled.Free()
b.ResetTimer()
for i := 0; i < b.N; i++ {
result := compiled.Call(gate, up)
Keep(result[0])
Eval(result[0])
}
}
// BenchmarkSwiGLU10xUncompiled benchmarks 10 chained SwiGLU ops uncompiled.
func BenchmarkSwiGLU10xUncompiled(b *testing.B) {
x := RandomNormal([]int32{1, 2880}, 42)
Keep(x)
x.Eval()
const alpha float32 = 1.702
const limit float32 = 7.0
b.ResetTimer()
for i := 0; i < b.N; i++ {
y := x
for j := 0; j < 10; j++ {
y = swiGLU(y, y, alpha, limit)
}
Keep(y)
Eval(y)
}
}
// BenchmarkSwiGLU10xCompiled benchmarks 10 chained SwiGLU ops compiled.
func BenchmarkSwiGLU10xCompiled(b *testing.B) {
x := RandomNormal([]int32{1, 2880}, 42)
Keep(x)
x.Eval()
const alpha float32 = 1.702
const limit float32 = 7.0
compiled := Compile(func(inputs []*Array) []*Array {
y := inputs[0]
for j := 0; j < 10; j++ {
y = swiGLU(y, y, alpha, limit)
}
return []*Array{y}
})
defer compiled.Free()
b.ResetTimer()
for i := 0; i < b.N; i++ {
result := compiled.Call(x)
Keep(result[0])
Eval(result[0])
}
}
// ============ Sampler Benchmarks ============
// sampleTopK implements top-k sampling
func sampleTopK(logits, key *Array, k int) (*Array, *Array) {
neg := Neg(logits)
indices := Argpartition(neg, k-1, -1)
topK := Slice(indices, []int32{0}, []int32{int32(k)})
values := TakeAlongAxis(logits, topK, -1)
key1, key2 := RandomSplit(key)
sampled := RandomCategoricalWithKey(values, key2, -1, 1)
return Take(topK, sampled, -1), key1
}
// sampleTopP implements top-p (nucleus) sampling
func sampleTopP(logits, key *Array, p float32, vocabSize int32) (*Array, *Array) {
sorted := Argsort(Neg(logits), -1)
sortedLogits := TakeAlongAxis(logits, sorted, -1)
probs := Softmax(sortedLogits, -1)
cumProbs := Cumsum(probs, -1)
mask := LessScalar(cumProbs, p)
negInf := FullDtype(float32(-1e9), logits.Dtype(), vocabSize)
masked := Where(mask, sortedLogits, negInf)
key1, key2 := RandomSplit(key)
sampled := RandomCategoricalWithKey(masked, key2, -1, 1)
return Take(sorted, sampled, -1), key1
}
// BenchmarkSampleTopKUncompiled benchmarks uncompiled top-k sampling.
func BenchmarkSampleTopKUncompiled(b *testing.B) {
vocabSize := int32(32000)
logits := RandomNormal([]int32{vocabSize}, 42)
key := RandomKey(42)
Keep(logits, key)
Eval(logits, key)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var token *Array
token, key = sampleTopK(logits, key, 40)
Keep(token, key)
Eval(token)
}
}
// BenchmarkSampleTopKCompiled benchmarks compiled top-k sampling.
func BenchmarkSampleTopKCompiled(b *testing.B) {
vocabSize := int32(32000)
logits := RandomNormal([]int32{vocabSize}, 42)
key := RandomKey(42)
Keep(logits, key)
Eval(logits, key)
compiled := Compile(func(inputs []*Array) []*Array {
token, newKey := sampleTopK(inputs[0], inputs[1], 40)
return []*Array{token, newKey}
})
defer compiled.Free()
b.ResetTimer()
for i := 0; i < b.N; i++ {
outputs := compiled.Call(logits, key)
Keep(outputs...)
Eval(outputs[0])
key = outputs[1]
}
}
// BenchmarkSampleTopPUncompiled benchmarks uncompiled top-p sampling.
func BenchmarkSampleTopPUncompiled(b *testing.B) {
vocabSize := int32(32000)
logits := RandomNormal([]int32{vocabSize}, 42)
key := RandomKey(42)
Keep(logits, key)
Eval(logits, key)
b.ResetTimer()
for i := 0; i < b.N; i++ {
var token *Array
token, key = sampleTopP(logits, key, 0.9, vocabSize)
Keep(token, key)
Eval(token)
}
}
// BenchmarkSampleTopPCompiled benchmarks compiled top-p sampling.
func BenchmarkSampleTopPCompiled(b *testing.B) {
vocabSize := int32(32000)
logits := RandomNormal([]int32{vocabSize}, 42)
key := RandomKey(42)
Keep(logits, key)
Eval(logits, key)
compiled := Compile(func(inputs []*Array) []*Array {
token, newKey := sampleTopP(inputs[0], inputs[1], 0.9, vocabSize)
return []*Array{token, newKey}
})
defer compiled.Free()
b.ResetTimer()
for i := 0; i < b.N; i++ {
outputs := compiled.Call(logits, key)
Keep(outputs...)
Eval(outputs[0])
key = outputs[1]
}
}
// TestCompiledSamplerMemoryStable verifies compiled samplers don't leak memory.
func TestCompiledSamplerMemoryStable(t *testing.T) {
vocabSize := int32(32000)
logits := RandomNormal([]int32{vocabSize}, 42)
key := RandomKey(42)
Keep(logits, key)
Eval(logits, key)
compiledTopK := Compile(func(inputs []*Array) []*Array {
token, newKey := sampleTopK(inputs[0], inputs[1], 40)
return []*Array{token, newKey}
})
defer compiledTopK.Free()
compiledTopP := Compile(func(inputs []*Array) []*Array {
token, newKey := sampleTopP(inputs[0], inputs[1], 0.9, vocabSize)
return []*Array{token, newKey}
})
defer compiledTopP.Free()
// Warmup
for i := 0; i < 10; i++ {
out := compiledTopK.Call(logits, key)
Keep(out...)
Eval(out[0])
out[0].Free()
key = out[1]
}
MetalResetPeakMemory()
initialMem := MetalGetActiveMemory()
// Run 500 iterations of each sampler
for i := 0; i < 500; i++ {
// TopK
out := compiledTopK.Call(logits, key)
Keep(out...)
Eval(out[0])
out[0].Free()
key = out[1]
// TopP
out = compiledTopP.Call(logits, key)
Keep(out...)
Eval(out[0])
out[0].Free()
key = out[1]
}
Eval() // Final cleanup
finalMem := MetalGetActiveMemory()
peakMem := MetalGetPeakMemory()
growth := int64(finalMem) - int64(initialMem)
t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB",
initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024)
// Memory should stay bounded (allow 20MB for caching overhead)
if growth > 20*1024*1024 {
t.Errorf("memory grew by %d bytes over 1000 sampler calls - possible leak!", growth)
}
}
// BenchmarkSimpleOps measures simple ops with cleanup
func BenchmarkSimpleOps(b *testing.B) {
weight := NewArrayFloat32([]float32{1, 0, 0, 1}, []int32{2, 2})
Keep(weight)
weight.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
x := NewArrayFloat32([]float32{1, 2}, []int32{1, 2})
result := Matmul(x, weight)
Keep(result)
AsyncEval(result)
result.Eval()
}
}
// BenchmarkLayerLike measures layer-like ops (~15 ops)
func BenchmarkLayerLike(b *testing.B) {
hidden := int32(256)
w := Ones(hidden, hidden)
Keep(w)
w.Eval()
b.ResetTimer()
for i := 0; i < b.N; i++ {
x := Ones(1, hidden)
// Simulate attention-like ops with proper shapes
h := Matmul(x, w) // [1, 256] @ [256, 256] = [1, 256]
h = Add(h, Matmul(h, w)) // residual
h = Mul(h, Sigmoid(Matmul(h, w))) // gating
h = Matmul(h, w) // output projection
h = Add(x, RMSNormNoWeight(h, 1e-5)) // residual + norm
Keep(h)
AsyncEval(h)
Eval(h)
}
}
// BenchmarkManyOps measures with increasing op counts
func BenchmarkManyOps(b *testing.B) {
w := Ones(64, 64)
Keep(w)
w.Eval()
for _, numOps := range []int{10, 50, 100, 500, 1000} {
b.Run(fmt.Sprintf("ops_%d", numOps), func(b *testing.B) {
for i := 0; i < b.N; i++ {
x := Ones(1, 64)
for j := 0; j < numOps; j++ {
x = Add(x, Matmul(x, w))
}
Keep(x)
AsyncEval(x)
Eval(x)
}
})
}
}
// BenchmarkLLMScale measures at LLM-realistic scale (~1348 arrays)
func BenchmarkLLMScale(b *testing.B) {
// Simulate Qwen-like model: 24 layers, each with ~56 ops = 1344 arrays
numLayers := 24
opsPerLayer := 56
// Create weights
hidden := int32(64)
weights := make([]*Array, numLayers*4)
for i := range weights {
weights[i] = Ones(hidden, hidden)
}
Keep(weights...)
Eval(weights...)
b.ResetTimer()
for i := 0; i < b.N; i++ {
x := Ones(1, hidden)
for layer := 0; layer < numLayers; layer++ {
for op := 0; op < opsPerLayer/4; op++ {
x = Add(x, Matmul(x, weights[layer*4]))
x = Mul(x, Sigmoid(x))
}
}
Keep(x)
AsyncEval(x)
Eval(x)
}
}
// BenchmarkArrayFreeLoop measures the cost of freeing N arrays
func BenchmarkArrayFreeLoop(b *testing.B) {
for _, count := range []int{100, 500, 1000, 1500} {
b.Run(fmt.Sprintf("arrays_%d", count), func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
arrays := make([]*Array, count)
for j := 0; j < count; j++ {
arrays[j] = NewArrayFloat32([]float32{1, 2, 3, 4}, []int32{2, 2})
}
b.StartTimer()
// Cleanup all arrays
Eval()
}
})
}
}
// BenchmarkCleanupIsolated measures just cleanup time
func BenchmarkCleanupIsolated(b *testing.B) {
w := NewArrayFloat32([]float32{1}, []int32{1, 1})
Keep(w)
w.Eval()
for _, count := range []int{100, 500, 1000, 1500} {
b.Run(fmt.Sprintf("arrays_%d", count), func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
x := NewArrayFloat32([]float32{1}, []int32{1})
for j := 0; j < count; j++ {
x = Add(x, x)
}
Keep(x)
b.StartTimer()
Eval() // Just cleanup
}
})
}
}
// TestMemoryStable verifies that cleanup doesn't cause unbounded memory growth.
func TestMemoryStable(t *testing.T) {
if testing.Short() {
t.Skip("skipping memory test in short mode")
}
// Create realistic-sized arrays (like KV cache)
batchSize := int32(1)
numHeads := int32(8)
seqLen := int32(256)
headDim := int32(64)
cacheShape := []int32{batchSize, numHeads, seqLen, headDim}
cacheSize := batchSize * numHeads * seqLen * headDim * 4 // float32 = 4 bytes
// Initial cache
keys := Zeros(cacheShape, DtypeFloat32)
values := Zeros(cacheShape, DtypeFloat32)
Keep(keys, values)
Eval(keys, values)
// Warmup
for i := 0; i < 5; i++ {
oldKeys, oldValues := keys, values
newKeys := Add(keys, keys)
newValues := Add(values, values)
Keep(newKeys, newValues)
Eval(newKeys, newValues)
oldKeys.Free()
oldValues.Free()
keys, values = newKeys, newValues
}
MetalResetPeakMemory()
initialMem := MetalGetActiveMemory()
// Run 100 steps
for step := 0; step < 100; step++ {
oldKeys, oldValues := keys, values
newKeys := Add(keys, keys)
newValues := Add(values, values)
Keep(newKeys, newValues)
Eval(newKeys, newValues)
oldKeys.Free()
oldValues.Free()
keys, values = newKeys, newValues
}
Eval() // Final cleanup
finalMem := MetalGetActiveMemory()
peakMem := MetalGetPeakMemory()
growth := int64(finalMem) - int64(initialMem)
expectedMaxGrowth := int64(cacheSize * 4 * 10)
t.Logf("cache size: %d bytes", cacheSize*2)
t.Logf("memory: initial=%dMB, final=%dMB, peak=%dMB, growth=%dKB",
initialMem/(1<<20), finalMem/(1<<20), peakMem/(1<<20), growth/1024)
if growth > expectedMaxGrowth {
t.Errorf("memory grew by %d bytes over 100 steps (expected max %d) - possible leak",
growth, expectedMaxGrowth)
}
}
//go:build mlx
package gemma3
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// TextConfig holds configuration for the text model
type TextConfig struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
HeadDim int32 `json:"head_dim"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
SlidingWindow int32 `json:"sliding_window"`
SlidingWindowPattern int32 `json:"sliding_window_pattern"`
// Computed fields
Scale float32 `json:"-"`
}
// TextModel is the Gemma 3 text-only model
type TextModel struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*DecoderLayer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"-"` // Tied to EmbedTokens, set manually
// Precomputed (1 + weight) for Gemma-style RMSNorm to avoid allocation per forward
NormScaled *mlx.Array `weight:"-"`
tok *tokenizer.Tokenizer
*TextConfig
}
// DecoderLayer is a single transformer block
type DecoderLayer struct {
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
Attention *Attention
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
PreFFNorm *nn.RMSNorm `weight:"pre_feedforward_layernorm"`
MLP *MLP
PostFFNorm *nn.RMSNorm `weight:"post_feedforward_layernorm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
InputNormScaled *mlx.Array `weight:"-"`
PostAttnNormScaled *mlx.Array `weight:"-"`
PreFFNormScaled *mlx.Array `weight:"-"`
PostFFNormScaled *mlx.Array `weight:"-"`
// Whether this layer uses sliding window attention
IsSliding bool
LayerIdx int32
}
// Attention implements Gemma 3 attention with Q/K normalization
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
QNorm *nn.RMSNorm `weight:"self_attn.q_norm"`
KNorm *nn.RMSNorm `weight:"self_attn.k_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
QNormScaled *mlx.Array `weight:"-"`
KNormScaled *mlx.Array `weight:"-"`
}
// MLP is the feed-forward network with GELU activation
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
// LoadText loads the text-only Gemma 3 model
func LoadText(modelPath string) (*TextModel, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg TextConfig
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Compute scale
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
// Set defaults if not specified
if cfg.RopeTheta == 0 {
cfg.RopeTheta = 1000000
}
if cfg.RopeLocalBaseFreq == 0 {
cfg.RopeLocalBaseFreq = 10000
}
if cfg.RMSNormEps == 0 {
cfg.RMSNormEps = 1e-6
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &TextModel{
Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
TextConfig: &cfg,
tok: tok,
}
// Initialize layer metadata
for i := range m.Layers {
m.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.SlidingWindowPattern),
}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for output
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm to avoid per-forward allocation
precomputeGemmaScaledWeights(m)
return m, nil
}
// precomputeGemmaScaledWeights computes (1 + weight) for all RMSNorm layers
// This avoids creating temporary arrays on every forward pass
func precomputeGemmaScaledWeights(m *TextModel) {
m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
for _, layer := range m.Layers {
layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
}
// Eval all the precomputed weights
var scaled []*mlx.Array
scaled = append(scaled, m.NormScaled)
for _, layer := range m.Layers {
scaled = append(scaled, layer.InputNormScaled, layer.PostAttnNormScaled,
layer.PreFFNormScaled, layer.PostFFNormScaled,
layer.Attention.QNormScaled, layer.Attention.KNormScaled)
}
mlx.Eval(scaled...)
}
// isLayerSliding determines if a layer uses sliding window attention
// Pattern N means every Nth layer (indices N-1, 2N-1, ...) uses full attention; all other layers use sliding window attention.
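// With the Gemma 3 default pattern of 6, that is layers 5, 11, 17, ...; a pattern of 0
// disables the sliding window entirely (every layer uses full attention).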
func isLayerSliding(layerIdx, pattern int32) bool {
if pattern <= 0 {
return false // No sliding window
}
// Layer is full attention if (layerIdx + 1) % pattern == 0
return (layerIdx+1)%pattern != 0
}
// Forward runs the text model forward pass
func (m *TextModel) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
// Get embeddings and scale by sqrt(hidden_size)
h := m.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextConfig)
}
// Final norm and output projection
return m.Output.Forward(mlx.RMSNorm(h, m.NormScaled, m.RMSNormEps))
}
// Forward runs a decoder layer
func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
// Pre-attention norm (use precomputed scaled weight)
normed := mlx.RMSNorm(x, l.InputNormScaled, cfg.RMSNormEps)
// Attention
attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
// Post-attention norm and residual
attnOut = mlx.RMSNorm(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
h := mlx.Add(x, attnOut)
// Pre-FFN norm
normed = mlx.RMSNorm(h, l.PreFFNormScaled, cfg.RMSNormEps)
// MLP
mlpOut := l.MLP.Forward(normed)
// Post-FFN norm and residual
mlpOut = mlx.RMSNorm(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
return mlx.Add(h, mlpOut)
}
// Forward runs attention with Q/K normalization
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
// Q/K normalization after reshaping (use precomputed scaled weight)
q = mlx.RMSNorm(q, a.QNormScaled, cfg.RMSNormEps)
k = mlx.RMSNorm(k, a.KNormScaled, cfg.RMSNormEps)
// Apply RoPE with appropriate theta
ropeTheta := cfg.RopeTheta
if isSliding {
ropeTheta = cfg.RopeLocalBaseFreq
}
q = mlx.RoPE(q, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, ropeTheta, 1.0, c.Offset())
// Update cache
k, v = c.Update(k, v, int(L))
// Repeat K/V for GQA if needed
repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
if repeatFactor > 1 {
k = nn.RepeatKV(k, repeatFactor)
v = nn.RepeatKV(v, repeatFactor)
}
// Attention
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// compiledGeluApprox is a singleton compiled GELU function shared across all layers
var compiledGeluApprox *mlx.CompiledFunc
// getCompiledGeluApprox returns the compiled GELU function, creating it once if needed
func getCompiledGeluApprox() *mlx.CompiledFunc {
if compiledGeluApprox == nil {
compiledGeluApprox = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{geluApproxImpl(inputs[0])}
}, true)
}
return compiledGeluApprox
}
// Forward runs the MLP with GELU approximation (tanh variant)
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
gate := getCompiledGeluApprox().Call(m.GateProj.Forward(x))[0]
return m.DownProj.Forward(mlx.Mul(gate, m.UpProj.Forward(x)))
}
// geluApproxImpl computes GELU using the tanh approximation (gelu_pytorch_tanh):
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
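// Rough sanity check: the approximation gives 0 at x=0 and ≈0.841 at x=1, close to exact GELU.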
func geluApproxImpl(x *mlx.Array) *mlx.Array {
// Constants
const sqrt2OverPi = 0.7978845608028654 // sqrt(2/pi)
const coeff = 0.044715
// x^3
x3 := mlx.Mul(mlx.Mul(x, x), x)
// x + 0.044715 * x^3
inner := mlx.Add(x, mlx.MulScalar(x3, coeff))
// sqrt(2/pi) * (x + 0.044715 * x^3)
scaled := mlx.MulScalar(inner, sqrt2OverPi)
// tanh(...)
tanh := mlx.Tanh(scaled)
// 1 + tanh(...)
onePlusTanh := mlx.AddScalar(tanh, 1.0)
// 0.5 * x * (1 + tanh(...))
return mlx.Mul(mlx.MulScalar(x, 0.5), onePlusTanh)
}
// gemmaRMSNorm applies Gemma-style RMS normalization: x * rsqrt(mean(x^2) + eps) * (1 + weight)
// It computes (1 + weight) on each call before passing it to the mlx.RMSNorm fast kernel;
// the model forward paths use the precomputed *Scaled weights instead to avoid that allocation.
func gemmaRMSNorm(x, weight *mlx.Array, eps float32) *mlx.Array {
// Gemma uses (1 + weight) instead of weight
scaledWeight := mlx.AddScalar(weight, 1.0)
return mlx.RMSNorm(x, scaledWeight, eps)
}
// Interface methods
func (m *TextModel) NumLayers() int { return len(m.Layers) }
func (m *TextModel) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *TextModel) VocabSize() int32 { return m.TextConfig.VocabSize }
// Tokenizer returns the tokenizer wrapped to add BOS and apply chat template
func (m *TextModel) Tokenizer() *tokenizer.Tokenizer {
return m.tok
}
// FormatPrompt applies the Gemma 3 chat template to a prompt
func (m *TextModel) FormatPrompt(prompt string) string {
// Gemma 3 chat format: <start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
func (m *TextModel) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
if m.Layers[i].IsSliding {
// Use rotating cache for sliding window layers
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
// Use regular cache for global attention layers
caches[i] = cache.NewKVCache()
}
}
return caches
}
// Config holds config for the full multimodal model
type Config struct {
TextConfig TextConfig `json:"text_config"`
VisionConfig VisionConfig `json:"vision_config"`
// Image token config (from config.json)
BOITokenIndex int32 `json:"boi_token_index"` // <start_of_image> = 255999
EOITokenIndex int32 `json:"eoi_token_index"` // <end_of_image> = 256000
ImageTokenIndex int32 `json:"image_token_index"` // <image_soft_token> = 262144
MMTokensPerImage int32 `json:"mm_tokens_per_image"` // 256
}
// Model is the full Gemma 3 multimodal model
type Model struct {
VisionTower *VisionTower `weight:"vision_tower"`
Projector *MultiModalProjector `weight:"multi_modal_projector"`
TextModel *TextModel `weight:"language_model"`
Config *Config
tok *tokenizer.Tokenizer
}
// Load loads the full multimodal Gemma 3 model
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
// Set defaults for text config (multimodal config often has incomplete text_config)
// These defaults match transformers.Gemma3TextConfig defaults
tc := &cfg.TextConfig
if tc.HeadDim == 0 {
tc.HeadDim = 256 // Gemma 3 uses head_dim=256
}
if tc.NumAttentionHeads == 0 {
// Gemma 3 4B uses 8 attention heads (cannot infer from hidden_size/head_dim)
tc.NumAttentionHeads = 8
}
if tc.NumKeyValueHeads == 0 {
// Gemma 3 4B uses 4 KV heads (GQA with 2:1 ratio)
tc.NumKeyValueHeads = 4
}
if tc.VocabSize == 0 {
tc.VocabSize = 262208 // Gemma 3 vocab size (not 262144!)
}
if tc.RopeTheta == 0 {
tc.RopeTheta = 1000000
}
if tc.RopeLocalBaseFreq == 0 {
tc.RopeLocalBaseFreq = 10000
}
if tc.RMSNormEps == 0 {
tc.RMSNormEps = 1e-6
}
if tc.SlidingWindowPattern == 0 {
tc.SlidingWindowPattern = 6
}
if tc.MaxPositionEmbeddings == 0 {
tc.MaxPositionEmbeddings = 131072 // Gemma 3 4B default
}
// Compute text model scale
tc.Scale = float32(1.0 / math.Sqrt(float64(tc.HeadDim)))
// Set defaults for image token config
if cfg.BOITokenIndex == 0 {
cfg.BOITokenIndex = 255999 // <start_of_image>
}
if cfg.EOITokenIndex == 0 {
cfg.EOITokenIndex = 256000 // <end_of_image>
}
if cfg.ImageTokenIndex == 0 {
cfg.ImageTokenIndex = 262144 // <image_soft_token>
}
if cfg.MMTokensPerImage == 0 {
cfg.MMTokensPerImage = 256
}
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
VisionTower: &VisionTower{
Embeddings: &VisionEmbeddings{},
Encoder: make([]*VisionEncoderLayer, cfg.VisionConfig.NumHiddenLayers),
Config: &cfg.VisionConfig,
},
Projector: &MultiModalProjector{},
TextModel: &TextModel{
Layers: make([]*DecoderLayer, cfg.TextConfig.NumHiddenLayers),
TextConfig: &cfg.TextConfig,
},
Config: &cfg,
tok: tok,
}
// Initialize text layer metadata
for i := range m.TextModel.Layers {
m.TextModel.Layers[i] = &DecoderLayer{
LayerIdx: int32(i),
IsSliding: isLayerSliding(int32(i), cfg.TextConfig.SlidingWindowPattern),
}
}
// Initialize vision encoder layers
for i := range m.VisionTower.Encoder {
m.VisionTower.Encoder[i] = &VisionEncoderLayer{}
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Tied embeddings for text output
m.TextModel.Output = nn.NewLinear(m.TextModel.EmbedTokens.Weight, nil)
m.TextModel.tok = tok
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
// Precompute (1 + weight) for Gemma-style RMSNorm
precomputeGemmaScaledWeights(m.TextModel)
// Precompute projector's scaled weight
m.Projector.SoftEmbNormScaled = mlx.AddScalar(m.Projector.SoftEmbNorm.Weight, 1.0)
mlx.Eval(m.Projector.SoftEmbNormScaled)
return m, nil
}
// Forward runs the text-only forward pass
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
return m.TextModel.Forward(tokens, caches)
}
// ForwardWithImage runs the multimodal forward pass
// tokens: [B, L] input token IDs (with image placeholder tokens)
// image: [B, H, W, C] preprocessed image tensor
func (m *Model) ForwardWithImage(tokens *mlx.Array, image *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
cfg := m.Config.TextConfig
// Find image token position FIRST before any eval that might free tokens
imageStartPos := int32(-1)
if image != nil && B == 1 {
tokenData := tokens.DataInt32() // This evals tokens
for i, t := range tokenData {
if t == m.Config.ImageTokenIndex {
imageStartPos = int32(i)
break
}
}
}
// Get text embeddings and scale
h := m.TextModel.EmbedTokens.Forward(tokens)
h = mlx.MulScalar(h, float32(math.Sqrt(float64(cfg.HiddenSize))))
// Process image if provided
if image != nil && imageStartPos >= 0 {
// Vision tower: [B, H, W, C] -> [B, num_patches, vision_hidden]
visionFeatures := m.VisionTower.Forward(image)
// Project to text space: [B, num_patches, vision_hidden] -> [B, 256, text_hidden]
imageEmbeds := m.Projector.Forward(visionFeatures, cfg.RMSNormEps)
// Eval h and imageEmbeds together so neither gets freed
mlx.Eval(h, imageEmbeds)
// Cast imageEmbeds to match text embeddings dtype (bf16)
if imageEmbeds.Dtype() != h.Dtype() {
imageEmbeds = mlx.AsType(imageEmbeds, h.Dtype())
mlx.Eval(imageEmbeds)
}
// Insert image embeddings at the known position
h = m.insertImageEmbeddingsAt(h, imageEmbeds, imageStartPos)
}
// Run through text model layers
for i, layer := range m.TextModel.Layers {
h = layer.Forward(h, caches[i], B, L, m.TextModel.TextConfig)
}
// Final norm and output projection
return m.TextModel.Output.Forward(mlx.RMSNorm(h, m.TextModel.NormScaled, cfg.RMSNormEps))
}
// insertImageEmbeddingsAt replaces image placeholder tokens with actual image embeddings
// at a known position (to avoid re-scanning tokens after eval)
// textEmbeds: [B, L, hidden_size] text embeddings
// imageEmbeds: [B, 256, hidden_size] image embeddings from projector
// startPos: starting position of image tokens in the sequence
func (m *Model) insertImageEmbeddingsAt(textEmbeds, imageEmbeds *mlx.Array, startPos int32) *mlx.Array {
numImageTokens := imageEmbeds.Shape()[1]
L := textEmbeds.Shape()[1]
// Split text embeddings: [0:startPos] + imageEmbeds + [startPos+256:L]
afterStart := startPos + numImageTokens
// Slice before image tokens: textEmbeds[:, 0:startPos, :]
before := mlx.SliceAxis(textEmbeds, 1, 0, startPos)
// Slice after image tokens: textEmbeds[:, startPos+256:L, :]
after := mlx.SliceAxis(textEmbeds, 1, afterStart, L)
// Concatenate: before + imageEmbeds + after along axis 1
return mlx.Concatenate([]*mlx.Array{before, imageEmbeds, after}, 1)
}
// Interface methods for Model
func (m *Model) NumLayers() int { return len(m.TextModel.Layers) }
func (m *Model) MaxContextLength() int32 { return m.Config.TextConfig.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.TextConfig.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache { return m.TextModel.NewCache(maxSeqLen) }
func (m *Model) ImageSize() int32 { return m.Config.VisionConfig.ImageSize }
// FormatPrompt applies the Gemma 3 multimodal chat template
func (m *Model) FormatPrompt(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// FormatPromptWithImage applies the Gemma 3 multimodal chat template with image
func (m *Model) FormatPromptWithImage(prompt string) string {
return fmt.Sprintf("<start_of_turn>user\n<start_of_image>%s<end_of_turn>\n<start_of_turn>model\n", prompt)
}
// ExpandImageTokens expands <start_of_image> into 256 image placeholder tokens
// Input tokens containing boi_token (255999) are expanded to:
// boi_token + 256 * image_token + eoi_token
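// For example, with illustrative surrounding tokens a and b, [a, 255999, b] becomes
// [a, 255999, 262144 (repeated 256 times), 256000, b].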
func (m *Model) ExpandImageTokens(tokens []int32) []int32 {
result := make([]int32, 0, len(tokens)+int(m.Config.MMTokensPerImage)+1)
for _, t := range tokens {
if t == m.Config.BOITokenIndex {
// Expand: boi + 256 * image_token + eoi
result = append(result, m.Config.BOITokenIndex)
for i := int32(0); i < m.Config.MMTokensPerImage; i++ {
result = append(result, m.Config.ImageTokenIndex)
}
result = append(result, m.Config.EOITokenIndex)
} else {
result = append(result, t)
}
}
return result
}
//go:build mlx
package gemma3
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"os"
"github.com/ollama/ollama/x/imagegen/mlx"
"golang.org/x/image/draw"
)
// ProcessImage loads and preprocesses an image for the vision tower
// Returns [1, H, W, C] tensor in NHWC format normalized for SigLIP
func ProcessImage(path string, imageSize int32) (*mlx.Array, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open image: %w", err)
}
defer f.Close()
img, _, err := image.Decode(f)
if err != nil {
return nil, fmt.Errorf("decode image: %w", err)
}
return ProcessImageData(img, imageSize)
}
// ProcessImageData preprocesses an image.Image for the vision tower
func ProcessImageData(img image.Image, imageSize int32) (*mlx.Array, error) {
// Resize to target size using bilinear interpolation
resized := image.NewRGBA(image.Rect(0, 0, int(imageSize), int(imageSize)))
draw.BiLinear.Scale(resized, resized.Bounds(), img, img.Bounds(), draw.Over, nil)
// Convert to float32 array [H, W, C] and normalize
// SigLIP normalization: (pixel / 255.0 - 0.5) / 0.5 = pixel / 127.5 - 1.0
data := make([]float32, imageSize*imageSize*3)
idx := 0
for y := int32(0); y < imageSize; y++ {
for x := int32(0); x < imageSize; x++ {
r, g, b, _ := resized.At(int(x), int(y)).RGBA()
// RGBA returns 16-bit values, convert to 8-bit
data[idx] = float32(r>>8)/127.5 - 1.0
data[idx+1] = float32(g>>8)/127.5 - 1.0
data[idx+2] = float32(b>>8)/127.5 - 1.0
idx += 3
}
}
// Create MLX array [1, H, W, C] for NHWC layout
arr := mlx.NewArrayFloat32(data, []int32{1, imageSize, imageSize, 3})
mlx.Eval(arr) // Materialize to prevent use-after-free
return arr, nil
}
//go:build mlx
package gemma3
import (
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// MultiModalProjector projects vision features to text embedding space
type MultiModalProjector struct {
// mm_input_projection_weight: [vision_hidden, text_hidden]
InputProjection *mlx.Array `weight:"mm_input_projection_weight"`
SoftEmbNorm *nn.RMSNorm `weight:"mm_soft_emb_norm"`
// Precomputed (1 + weight) for Gemma-style RMSNorm
SoftEmbNormScaled *mlx.Array `weight:"-"`
}
// Forward projects vision features to text space
// Input: [B, num_patches, vision_hidden] (e.g., [1, 4096, 1152])
// Output: [B, num_image_tokens, text_hidden] (e.g., [1, 256, 2560])
func (p *MultiModalProjector) Forward(visionFeatures *mlx.Array, eps float32) *mlx.Array {
// Average pool 4x4: [B, 4096, 1152] -> [B, 256, 1152]
// 4096 patches = 64x64 grid, pool to 16x16 = 256 tokens
B := visionFeatures.Shape()[0]
visionHidden := visionFeatures.Shape()[2]
// Reshape to [B, 64, 64, hidden]
gridSize := int32(64) // sqrt(4096)
pooledSize := int32(16) // 64/4
h := mlx.Reshape(visionFeatures, B, gridSize, gridSize, visionHidden)
// Reshape to [B, 16, 4, 16, 4, hidden] for 4x4 pooling
h = mlx.Reshape(h, B, pooledSize, 4, pooledSize, 4, visionHidden)
// Average over pooling dimensions (axes 2 and 4)
h = mlx.Mean(h, 4, false)
h = mlx.Mean(h, 2, false)
// h is now [B, 16, 16, hidden], reshape to [B, 256, hidden]
numTokens := pooledSize * pooledSize
h = mlx.Reshape(h, B, numTokens, visionHidden)
// Apply Gemma-style RMS norm (use precomputed 1 + weight)
h = mlx.RMSNorm(h, p.SoftEmbNormScaled, eps)
// Project to text space: [B, 256, vision_hidden] @ [vision_hidden, text_hidden]
return mlx.Linear(h, p.InputProjection)
}
//go:build mlx
package gemma3
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
)
// VisionConfig holds configuration for the SigLIP vision tower
type VisionConfig struct {
HiddenSize int32 `json:"hidden_size"`
ImageSize int32 `json:"image_size"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
PatchSize int32 `json:"patch_size"`
}
// VisionTower is the SigLIP vision encoder
type VisionTower struct {
Embeddings *VisionEmbeddings `weight:"vision_model.embeddings"`
Encoder []*VisionEncoderLayer `weight:"vision_model.encoder.layers"`
PostLayerNorm *nn.LayerNorm `weight:"vision_model.post_layernorm"`
Config *VisionConfig
}
// VisionEmbeddings handles patch and position embeddings
type VisionEmbeddings struct {
// PatchWeight: [O, C, kH, kW] from PyTorch, transposed to [O, kH, kW, C] for MLX
PatchWeight *mlx.Array `weight:"patch_embedding.weight"`
PatchBias *mlx.Array `weight:"patch_embedding.bias"`
PosEmbed *nn.Embedding `weight:"position_embedding"`
}
// VisionEncoderLayer is a single transformer encoder layer
type VisionEncoderLayer struct {
LayerNorm1 *nn.LayerNorm `weight:"layer_norm1"`
Attention *VisionAttention `weight:"self_attn"`
LayerNorm2 *nn.LayerNorm `weight:"layer_norm2"`
MLP *VisionMLP `weight:"mlp"`
}
// VisionAttention implements multi-head self-attention
type VisionAttention struct {
QProj *nn.Linear `weight:"q_proj"`
KProj *nn.Linear `weight:"k_proj"`
VProj *nn.Linear `weight:"v_proj"`
OutProj *nn.Linear `weight:"out_proj"`
}
// VisionMLP is the feed-forward network
type VisionMLP struct {
FC1 *nn.Linear `weight:"fc1"`
FC2 *nn.Linear `weight:"fc2"`
}
// Forward runs the vision tower on preprocessed images
// Input: [B, H, W, C] normalized image tensor (NHWC layout for MLX)
// Output: [B, num_patches, hidden_size]
func (v *VisionTower) Forward(x *mlx.Array) *mlx.Array {
// Patch embedding conv: input [B, H, W, C], weight [O, kH, kW, C] -> [B, grid, grid, O]
// Weight comes as [O, C, kH, kW] from PyTorch, transpose to [O, kH, kW, C]
weight := mlx.Transpose(v.Embeddings.PatchWeight, 0, 2, 3, 1)
h := mlx.Conv2d(x, weight, v.Config.PatchSize, 0) // stride=patch_size, no padding
// Add bias: [O] -> [1, 1, 1, O] for broadcasting
bias := mlx.Reshape(v.Embeddings.PatchBias, 1, 1, 1, v.Embeddings.PatchBias.Shape()[0])
h = mlx.Add(h, bias)
// h is [B, grid, grid, hidden], flatten to [B, num_patches, hidden]
B := h.Shape()[0]
gridH, gridW := h.Shape()[1], h.Shape()[2]
hidden := h.Shape()[3]
numPatches := gridH * gridW
h = mlx.Reshape(h, B, numPatches, hidden)
// Add position embeddings
posIds := mlx.ArangeInt(0, numPatches, 1, mlx.DtypeInt32)
posEmbed := v.Embeddings.PosEmbed.Forward(posIds)
h = mlx.Add(h, posEmbed)
// Encoder layers
headDim := float32(v.Config.HiddenSize / v.Config.NumAttentionHeads)
scale := float32(1.0 / math.Sqrt(float64(headDim)))
for _, layer := range v.Encoder {
h = layer.Forward(h, v.Config, scale)
}
// Final layer norm
h = v.PostLayerNorm.Forward(h)
return h
}
// Forward runs a vision encoder layer
func (l *VisionEncoderLayer) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
// Pre-norm attention
h := l.LayerNorm1.Forward(x)
h = l.Attention.Forward(h, cfg, scale)
x = mlx.Add(x, h)
// Pre-norm MLP
h = l.LayerNorm2.Forward(x)
h = l.MLP.Forward(h)
return mlx.Add(x, h)
}
// Forward runs multi-head self-attention
func (a *VisionAttention) Forward(x *mlx.Array, cfg *VisionConfig, scale float32) *mlx.Array {
B, L := x.Shape()[0], x.Shape()[1]
headDim := cfg.HiddenSize / cfg.NumAttentionHeads
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape to [B, num_heads, L, head_dim]
q = mlx.Transpose(mlx.Reshape(q, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
k = mlx.Transpose(mlx.Reshape(k, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
v = mlx.Transpose(mlx.Reshape(v, B, L, cfg.NumAttentionHeads, headDim), 0, 2, 1, 3)
// Scaled dot-product attention (no causal mask for vision)
out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
// Reshape back: [B, num_heads, L, head_dim] -> [B, L, hidden]
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.HiddenSize)
return a.OutProj.Forward(out)
}
// Forward runs the MLP with GELU activation
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
h := mlx.GELU(m.FC1.Forward(x))
return m.FC2.Forward(h)
}
//go:build mlx
package gpt_oss
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// RopeScaling holds YaRN or other RoPE scaling configuration
type RopeScaling struct {
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings int32 `json:"original_max_position_embeddings"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
}
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
HeadDim int32 `json:"head_dim"`
SlidingWindow int32 `json:"sliding_window"`
NumLocalExperts int32 `json:"num_local_experts"`
NumExpertsPerTok int32 `json:"num_experts_per_tok"`
LayerTypes []string `json:"layer_types"`
SwiGLULimit float32 `json:"swiglu_limit"`
RopeScaling *RopeScaling `json:"rope_scaling"`
Scale float32 `json:"-"` // computed: 1/sqrt(HeadDim)
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
Sinks *mlx.Array `weight:"self_attn.sinks,optional"`
YarnFreqs *mlx.Array // computed
YarnMscale float32
}
// swiGLU applies the GPT-OSS custom SwiGLU activation.
// Formula: (gate * sigmoid(alpha * gate)) * (up + 1)
// with clipping: gate to [None, limit], up to [-limit, limit]
func swiGLU(gate, up *mlx.Array, alpha, limit float32) *mlx.Array {
// Clip gate to [None, limit]
gateClipped := mlx.ClipScalar(gate, 0, limit, false, true)
// Clip up to [-limit, limit]
upClipped := mlx.ClipScalar(up, -limit, limit, true, true)
// glu_scaled = alpha * gate_clipped
gluScaled := mlx.MulScalar(gateClipped, alpha)
// sig = sigmoid(glu_scaled)
sig := mlx.Sigmoid(gluScaled)
// out_glu = gate_clipped * sig
outGlu := mlx.Mul(gateClipped, sig)
// result = out_glu * (up_clipped + 1)
return mlx.Mul(outGlu, mlx.AddScalar(upClipped, 1.0))
}
// compiledSwiGLU is a singleton compiled SwiGLU function shared across all layers
var compiledSwiGLU *mlx.CompiledFunc
// getCompiledSwiGLU returns the compiled SwiGLU function, creating it once if needed
func getCompiledSwiGLU() *mlx.CompiledFunc {
if compiledSwiGLU == nil {
const alpha float32 = 1.702
const limit float32 = 7.0
compiledSwiGLU = mlx.CompileShapeless(func(inputs []*mlx.Array) []*mlx.Array {
return []*mlx.Array{swiGLU(inputs[0], inputs[1], alpha, limit)}
}, true) // shapeless=true so it works for any input size
}
return compiledSwiGLU
}
// ComputeYarnFreqs computes YaRN-modified RoPE frequencies
// Based on mlx-lm's YarnRoPE implementation
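// Dims at or below the low correction index keep the original frequency term (base^(2i/d)),
// dims at or above the high index use the scaled term (factor * base^(2i/d)), and dims in
// between blend the two along a linear ramp. The returned mscale is applied to the queries
// in Attention.Forward as YaRN's attention scaling term.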
func ComputeYarnFreqs(dims int32, base, scalingFactor float32, origMaxPos int32, betaFast, betaSlow float32) (*mlx.Array, float32) {
// yarn_find_correction_dim
yarnFindCorrectionDim := func(numRotations float64) float64 {
return float64(dims) * math.Log(float64(origMaxPos)/(numRotations*2*math.Pi)) / (2 * math.Log(float64(base)))
}
// yarn_find_correction_range
low := int(math.Floor(yarnFindCorrectionDim(float64(betaFast))))
high := int(math.Ceil(yarnFindCorrectionDim(float64(betaSlow))))
if low < 0 {
low = 0
}
if high > int(dims)-1 {
high = int(dims) - 1
}
// yarn_get_mscale
yarnGetMscale := func(scale, mscale float64) float64 {
if scale <= 1 {
return 1.0
}
return 0.1*mscale*math.Log(scale) + 1.0
}
mscale := float32(yarnGetMscale(float64(scalingFactor), 1.0) / yarnGetMscale(float64(scalingFactor), 0.0))
// Compute frequencies
// freq_extra = base ** (arange(0, dims, 2) / dims)
// freq_inter = scaling_factor * freq_extra
halfDims := dims / 2
freqData := make([]float32, halfDims)
for i := int32(0); i < halfDims; i++ {
exp := float64(2*i) / float64(dims)
freqExtra := math.Pow(float64(base), exp)
freqInter := float64(scalingFactor) * freqExtra
// linear ramp mask
var freqMask float64
if low == high {
freqMask = 0.0
} else {
t := (float64(i) - float64(low)) / float64(high-low)
if t < 0 {
t = 0
}
if t > 1 {
t = 1
}
freqMask = 1.0 - t
}
// Combined frequency: (inter * extra) / (inter * mask + extra * (1 - mask))
freqData[i] = float32((freqInter * freqExtra) / (freqInter*freqMask + freqExtra*(1-freqMask)))
}
return mlx.NewArray(freqData, []int32{halfDims}), mscale
}
// initYarn initializes YaRN RoPE if configured
func (a *Attention) initYarn(cfg *Config) {
a.YarnMscale = 1.0
if cfg.RopeScaling != nil && cfg.RopeScaling.RopeType == "yarn" {
a.YarnFreqs, a.YarnMscale = ComputeYarnFreqs(
cfg.HeadDim,
cfg.RopeTheta,
cfg.RopeScaling.Factor,
cfg.RopeScaling.OriginalMaxPositionEmbeddings,
cfg.RopeScaling.BetaFast,
cfg.RopeScaling.BetaSlow,
)
}
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
// Reshape via AsStrided: [B, L, n_heads * head_dim] -> [B, n_heads, L, head_dim]
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
offset := 0
if c != nil {
offset = c.Offset()
}
if a.YarnFreqs != nil {
if a.YarnMscale != 1.0 {
q = mlx.MulScalar(q, a.YarnMscale)
}
q = mlx.RoPEWithFreqs(q, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
k = mlx.RoPEWithFreqs(k, a.YarnFreqs, int(cfg.HeadDim), false, 1.0, offset)
} else {
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
}
if c != nil {
k, v = c.Update(k, v, int(L))
}
out := mlx.ScaledDotProductAttentionWithSinks(q, k, v, cfg.Scale, maskMode, mask, a.Sinks)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
// CreateSlidingWindowMask creates a causal mask with sliding window
// Mirrors mlx-lm's create_causal_mask with window_size
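// A query at absolute position i may attend to key position j only when j <= i (causal)
// and i-j < windowSize, i.e. the most recent windowSize keys including the current one.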
func CreateSlidingWindowMask(seqLen, queryStart, keyStart, keyLen, windowSize int) *mlx.Array {
// Build mask aligned to actual cache length (may be rotated)
// rinds covers existing keys: [keyStart, keyStart+keyLen)
// linds covers new queries: [queryStart, queryStart+seqLen)
rinds := mlx.Arange(float32(keyStart), float32(keyStart+keyLen), 1) // [keyLen]
linds := mlx.Arange(float32(queryStart), float32(queryStart+seqLen), 1) // [seqLen]
linds = mlx.ExpandDims(linds, 1) // [seqLen, 1]
rinds = mlx.ExpandDims(rinds, 0) // [1, keyLen]
causalMask := mlx.GreaterEqual(linds, rinds) // [seqLen, keyLen]
windowLimit := mlx.AddScalar(rinds, float32(windowSize))
windowMask := mlx.LessArray(linds, windowLimit) // [seqLen, keyLen]
return mlx.LogicalAnd(causalMask, windowMask)
}
// MoE represents the Mixture of Experts SwiGLU layer with quantized experts.
type MoE struct {
Router *nn.Linear `weight:"mlp.router"`
TopK int32
HiddenSize int32
GroupSize int
Bits int
// Expert weights (loaded manually via sanitizeExpertWeights)
GateBlocks, GateScales, GateBias *mlx.Array
UpBlocks, UpScales, UpBias *mlx.Array
DownBlocks, DownScales, DownBias *mlx.Array
}
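// Forward routes each token to its TopK experts: Argpartition selects the expert indices,
// softmax over the selected router logits gives the mixture weights, tokens are optionally
// sorted by expert index (for batches of 64+ rows) so GatherQMM can process them grouped by
// expert, the quantized gate/up/down projections run with the compiled SwiGLU in between,
// and the expert outputs are summed weighted by the router probabilities.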
func (moe *MoE) Forward(x *mlx.Array, B, L int32) *mlx.Array {
logits := moe.Router.Forward(x)
neg := mlx.Neg(logits)
part := mlx.Argpartition(neg, int(moe.TopK)-1, -1)
topKIdx := mlx.Slice(part, []int32{0, 0, 0}, []int32{B, L, moe.TopK})
topKVal := mlx.TakeAlongAxis(logits, topKIdx, -1)
weights := mlx.Softmax(topKVal, -1)
xFlat := mlx.Reshape(x, B*L, 1, 1, moe.HiddenSize)
idxFlat := mlx.Reshape(topKIdx, B*L, moe.TopK)
doSort := B*L >= 64
var invOrder *mlx.Array
sorted := false
n := B * L * moe.TopK
if doSort {
idxAll := mlx.Flatten(idxFlat)
order := mlx.Argsort(idxAll, 0)
invOrder = mlx.Argsort(order, 0)
xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, moe.TopK), 0), 1)
idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
sorted = true
}
gate := mlx.GatherQMM(xFlat, moe.GateBlocks, moe.GateScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
up := mlx.GatherQMM(xFlat, moe.UpBlocks, moe.UpScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.GateBias != nil {
gate = mlx.Add(gate, mlx.ExpandDims(mlx.Take(moe.GateBias, idxFlat, 0), 2))
}
if moe.UpBias != nil {
up = mlx.Add(up, mlx.ExpandDims(mlx.Take(moe.UpBias, idxFlat, 0), 2))
}
hidden := getCompiledSwiGLU().Call(gate, up)[0]
down := mlx.GatherQMM(hidden, moe.DownBlocks, moe.DownScales, nil, nil, idxFlat, true, moe.GroupSize, moe.Bits, "mxfp4", sorted)
if moe.DownBias != nil {
down = mlx.Add(down, mlx.ExpandDims(mlx.Take(moe.DownBias, idxFlat, 0), 2))
}
if doSort {
down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, moe.TopK, moe.HiddenSize)
} else {
down = mlx.Squeeze(down, 2)
}
ewFlat := mlx.Reshape(weights, B*L, moe.TopK, 1)
return mlx.Reshape(mlx.Sum(mlx.Mul(down, ewFlat), 1, false), B, L, moe.HiddenSize)
}
type Block struct {
Attention *Attention
MLP *MoE
InputNorm *nn.RMSNorm `weight:"input_layernorm"`
PostAttnNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
LayerType string // "sliding_attention" or "full_attention"
}
func (b *Block) Forward(x *mlx.Array, c cache.Cache, B, L int32, mask *mlx.Array, maskMode string, cfg *Config) *mlx.Array {
h := mlx.Add(x, b.Attention.Forward(b.InputNorm.Forward(x, cfg.RMSNormEps), c, B, L, mask, maskMode, cfg))
return mlx.Add(h, b.MLP.Forward(b.PostAttnNorm.Forward(h, cfg.RMSNormEps), B, L))
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Block `weight:"-"` // loaded manually due to MoE sanitization
Norm *nn.RMSNorm `weight:"model.norm"`
LMHead *nn.Linear `weight:"lm_head"`
tok *tokenizer.Tokenizer
*Config
}
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) NewCache(int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" && m.SlidingWindow > 0 {
caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
} else {
caches[i] = cache.NewKVCache()
}
}
return caches
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
x := m.EmbedTokens.Forward(tokens)
// Find a representative cache index for the sliding window attention layers
var swaIdx int = -1
for i, layer := range m.Layers {
if layer.LayerType == "sliding_attention" {
swaIdx = i
break
}
}
// Create masks once at model level
var fullMask, swaMask *mlx.Array
var fullMaskMode, swaMaskMode string
if L > 1 {
fullMaskMode = "causal"
if swaIdx >= 0 && m.SlidingWindow > 0 && caches != nil {
c := caches[swaIdx]
offset := c.Offset()
windowSize := int(m.SlidingWindow)
cacheLen := min(int(L), windowSize)
if offset > 0 {
cacheLen = min(c.Len()+int(L), windowSize)
}
if int(L) > windowSize {
swaMask = CreateSlidingWindowMask(int(L), offset, offset+int(L)-cacheLen, cacheLen, windowSize)
} else {
swaMaskMode = "causal"
}
} else {
swaMaskMode = "causal"
}
}
for i, layer := range m.Layers {
var c cache.Cache
if caches != nil {
c = caches[i]
}
mask, maskMode := fullMask, fullMaskMode
if layer.LayerType == "sliding_attention" {
mask, maskMode = swaMask, swaMaskMode
}
x = layer.Forward(x, c, B, L, mask, maskMode, m.Config)
}
return m.LMHead.Forward(m.Norm.Forward(x, m.RMSNormEps))
}
// sanitizeExpertWeights splits merged gate_up weights into separate gate/up arrays.
// MXFP4 quantized weights require contiguous memory - strided views give wrong results.
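// The checkpoint interleaves gate and up along the second axis (gate at even indices, up at
// odd), so the stride-2 slices below start at offsets 0 and 1 and are copied to contiguous
// memory before use.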
func sanitizeExpertWeights(weights *safetensors.ModelWeights, prefix string) (moe *MoE) {
gateUpBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_blocks")
gateUpScales, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_scales")
gateUpBias, _ := weights.GetTensor(prefix + ".mlp.experts.gate_up_proj_bias")
downBlocks, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_blocks")
downScales, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_scales")
downBias, _ := weights.GetTensor(prefix + ".mlp.experts.down_proj_bias")
moe = &MoE{GroupSize: 32, Bits: 4, DownScales: downScales, DownBias: downBias}
if gateUpBlocks != nil {
gub := mlx.FlattenRange(mlx.View(gateUpBlocks, int(mlx.DtypeUint32)), -2, -1)
s := gub.Shape()
moe.GateBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpBlocks = mlx.Contiguous(mlx.SliceStride(gub, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpScales != nil {
s := gateUpScales.Shape()
moe.GateScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 0, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
moe.UpScales = mlx.Contiguous(mlx.SliceStride(gateUpScales, []int32{0, 1, 0}, []int32{s[0], s[1], s[2]}, []int32{1, 2, 1}))
}
if gateUpBias != nil {
s := gateUpBias.Shape()
moe.GateBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 0}, []int32{s[0], s[1]}, []int32{1, 2}))
moe.UpBias = mlx.Contiguous(mlx.SliceStride(gateUpBias, []int32{0, 1}, []int32{s[0], s[1]}, []int32{1, 2}))
}
if downBlocks != nil {
moe.DownBlocks = mlx.FlattenRange(mlx.View(downBlocks, int(mlx.DtypeUint32)), -2, -1)
}
return moe
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Block, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
// Load simple weights via struct tags
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
// Load layers with custom MoE handling
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
prefix := fmt.Sprintf("model.layers.%d", i)
layer := &Block{}
if err := safetensors.LoadModule(layer, weights, prefix); err != nil {
return nil, fmt.Errorf("layer %d: %w", i, err)
}
// Initialize attention YaRN
layer.Attention.initYarn(&cfg)
// Load MoE with weight sanitization
moe := sanitizeExpertWeights(weights, prefix)
moe.Router = layer.MLP.Router // Router was loaded by LoadModule
moe.TopK = cfg.NumExpertsPerTok
moe.HiddenSize = cfg.HiddenSize
layer.MLP = moe
// Set layer type
layer.LayerType = "full_attention"
if int(i) < len(cfg.LayerTypes) {
layer.LayerType = cfg.LayerTypes[i]
}
m.Layers[i] = layer
}
// Release safetensors BEFORE eval: the lazy arrays have already captured the data they
// reference, so freeing the mmap here reduces peak memory during materialization.
weights.ReleaseAll()
mlx.Eval(mlx.Collect(m)...)
return m, nil
}
func (m *Model) MaxContextLength() int32 {
if m.RopeScaling != nil && m.RopeScaling.OriginalMaxPositionEmbeddings > 0 {
return m.RopeScaling.OriginalMaxPositionEmbeddings
}
return 131072
}
//go:build mlx
package llama
import (
"encoding/json"
"fmt"
"math"
"os"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
type Config struct {
HiddenSize int32 `json:"hidden_size"`
NumHiddenLayers int32 `json:"num_hidden_layers"`
IntermediateSize int32 `json:"intermediate_size"`
NumAttentionHeads int32 `json:"num_attention_heads"`
NumKeyValueHeads int32 `json:"num_key_value_heads"`
VocabSize int32 `json:"vocab_size"`
RMSNormEps float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
HeadDim int32 `json:"-"`
Scale float32 `json:"-"`
}
type Model struct {
EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
Layers []*Layer `weight:"model.layers"`
Norm *nn.RMSNorm `weight:"model.norm"`
Output *nn.Linear `weight:"lm_head,optional"`
tok *tokenizer.Tokenizer
*Config
}
type Layer struct {
Attention *Attention
MLP *MLP
AttentionNorm *nn.RMSNorm `weight:"input_layernorm"`
MLPNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
}
type Attention struct {
QProj *nn.Linear `weight:"self_attn.q_proj"`
KProj *nn.Linear `weight:"self_attn.k_proj"`
VProj *nn.Linear `weight:"self_attn.v_proj"`
OProj *nn.Linear `weight:"self_attn.o_proj"`
}
type MLP struct {
GateProj *nn.Linear `weight:"mlp.gate_proj"`
UpProj *nn.Linear `weight:"mlp.up_proj"`
DownProj *nn.Linear `weight:"mlp.down_proj"`
}
func Load(modelPath string) (*Model, error) {
data, err := os.ReadFile(filepath.Join(modelPath, "config.json"))
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parse config: %w", err)
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
weights, err := safetensors.LoadModelWeights(modelPath)
if err != nil {
return nil, fmt.Errorf("load weights: %w", err)
}
tok, err := tokenizer.Load(filepath.Join(modelPath, "tokenizer.json"))
if err != nil {
return nil, fmt.Errorf("load tokenizer: %w", err)
}
m := &Model{
Layers: make([]*Layer, cfg.NumHiddenLayers),
Config: &cfg,
tok: tok,
}
if err := safetensors.LoadModule(m, weights, ""); err != nil {
return nil, err
}
m.Output = nn.NewLinear(m.EmbedTokens.Weight, nil)
mlx.Eval(mlx.Collect(m)...)
weights.ReleaseAll()
return m, nil
}
func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
B, L := tokens.Shape()[0], tokens.Shape()[1]
h := m.EmbedTokens.Forward(tokens)
for i, layer := range m.Layers {
h = layer.Forward(h, caches[i], B, L, m.Config)
}
return m.Output.Forward(m.Norm.Forward(h, m.RMSNormEps))
}
func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
}
func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
q := a.QProj.Forward(x)
k := a.KProj.Forward(x)
v := a.VProj.Forward(x)
q = mlx.AsStrided(q, []int32{B, cfg.NumAttentionHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumAttentionHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumAttentionHeads * cfg.HeadDim), 1}, 0)
k = mlx.AsStrided(k, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
v = mlx.AsStrided(v, []int32{B, cfg.NumKeyValueHeads, L, cfg.HeadDim},
[]int64{int64(L * cfg.NumKeyValueHeads * cfg.HeadDim), int64(cfg.HeadDim), int64(cfg.NumKeyValueHeads * cfg.HeadDim), 1}, 0)
q = mlx.RoPE(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k = mlx.RoPE(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, c.Offset())
k, v = c.Update(k, v, int(L))
out := mlx.ScaledDotProductAttention(q, k, v, cfg.Scale, L > 1)
out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
return a.OProj.Forward(out)
}
func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
}
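// Forward above is the standard SwiGLU feed-forward used by Llama-style models:
// down_proj(SiLU(gate_proj(x)) * up_proj(x)). Illustrative shapes, assuming a
// hypothetical hidden_size=4096 and intermediate_size=11008: x and the output
// are [B, L, 4096], while the gate and up projections are [B, L, 11008].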
// Interface methods
func (m *Model) NumLayers() int { return len(m.Layers) }
func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
caches := make([]cache.Cache, len(m.Layers))
for i := range caches {
caches[i] = cache.NewKVCache()
}
return caches
}
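// Illustrative usage sketch (not part of this file): a minimal single forward
// pass wiring the pieces above together. The model path is hypothetical, and
// sampling from the logits is left to the caller.
//
//	m, err := Load("/path/to/llama-model")
//	if err != nil { /* handle error */ }
//	caches := m.NewCache(m.MaxContextLength())
//	ids := m.Tokenizer().Encode("Hello", false)
//	in := mlx.NewArrayInt32(ids, []int32{1, int32(len(ids))})
//	logits := m.Forward(in, caches) // [1, L, vocab]; sample from the last position
//	mlx.Eval(logits)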
//go:build mlx
package qwen_image
import (
"os"
"testing"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// TestPipelineOutput runs the full pipeline (integration test).
// Skips if the model weights are not found. Requires ~50GB of VRAM.
func TestPipelineOutput(t *testing.T) {
modelPath := "../../../weights/Qwen-Image-2512"
if _, err := os.Stat(modelPath); os.IsNotExist(err) {
t.Skip("Skipping: model weights not found at " + modelPath)
}
// Load model
pm, err := LoadPersistent(modelPath)
if err != nil {
t.Skipf("Skipping: failed to load model: %v", err)
}
// Run 2-step pipeline (minimum for stable scheduler)
cfg := &GenerateConfig{
Prompt: "a cat",
Width: 256,
Height: 256,
Steps: 2,
Seed: 42,
}
output, err := pm.GenerateFromConfig(cfg)
if err != nil {
t.Fatalf("Pipeline failed: %v", err)
}
mlx.Eval(output)
// Verify output shape [1, C, H, W]
shape := output.Shape()
if len(shape) != 4 {
t.Errorf("Expected 4D output, got %v", shape)
}
if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
}
// Verify values in expected range [0, 1]
data := output.Data()
minVal, maxVal := float32(1.0), float32(0.0)
for _, v := range data {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
}
t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
if minVal < -0.1 || maxVal > 1.1 {
t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
}
}
//go:build mlx
package qwen_image
import (
"errors"
"fmt"
"math"
"path/filepath"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// Qwen25VLConfig holds Qwen2.5-VL configuration
type Qwen25VLConfig struct {
// Text model config
HiddenSize int32 `json:"hidden_size"` // 3584
NumHiddenLayers int32 `json:"num_hidden_layers"` // 28
IntermediateSize int32 `json:"intermediate_size"` // 18944
NumAttentionHeads int32 `json:"num_attention_heads"` // 28
NumKeyValueHeads int32 `json:"num_key_value_heads"` // 4
VocabSize int32 `json:"vocab_size"` // 152064
RMSNormEps float32 `json:"rms_norm_eps"` // 1e-6
RopeTheta float32 `json:"rope_theta"` // 1000000
HeadDim int32 // Calculated: HiddenSize / NumAttentionHeads
MRoPESection []int32 // [16, 24, 24] for temporal, height, width
// Vision config
VisionHiddenSize int32 `json:"vision_hidden_size"` // 1280
VisionNumLayers int32 `json:"vision_num_layers"` // 32
VisionNumHeads int32 `json:"vision_num_heads"` // 16
VisionIntermSize int32 `json:"vision_intermediate"` // 3420
VisionPatchSize int32 `json:"vision_patch_size"` // 14
VisionOutHiddenSize int32 `json:"vision_out_hidden"` // 3584
VisionSpatialMerge int32 `json:"vision_spatial_merge"` // 2
VisionWindowSize int32 `json:"vision_window_size"` // 112
VisionFullAttIdx []int32 // [7, 15, 23, 31]
// Special tokens
ImageTokenID int32 // 151655
VisionStartTokenID int32 // 151652
VisionEndTokenID int32 // 151653
}
// defaultQwen25VLConfig returns default config
func defaultQwen25VLConfig() *Qwen25VLConfig {
cfg := &Qwen25VLConfig{
// Text
HiddenSize: 3584,
NumHiddenLayers: 28,
IntermediateSize: 18944,
NumAttentionHeads: 28,
NumKeyValueHeads: 4,
VocabSize: 152064,
RMSNormEps: 1e-6,
RopeTheta: 1000000,
MRoPESection: []int32{16, 24, 24},
// Vision
VisionHiddenSize: 1280,
VisionNumLayers: 32,
VisionNumHeads: 16,
VisionIntermSize: 3420,
VisionPatchSize: 14,
VisionOutHiddenSize: 3584,
VisionSpatialMerge: 2,
VisionWindowSize: 112,
VisionFullAttIdx: []int32{7, 15, 23, 31},
// Special tokens
ImageTokenID: 151655,
VisionStartTokenID: 151652,
VisionEndTokenID: 151653,
}
cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
return cfg
}
// Qwen25VL is the Qwen2.5-VL vision-language encoder
type Qwen25VL struct {
Config *Qwen25VLConfig
// Text model
Embedding *mlx.Array
Blocks []*VLTextBlock
FinalNorm *mlx.Array
// Vision tower (optional - nil for text-only models)
VisionPatchEmbed *VisionPatchEmbed
VisionBlocks []*VisionBlock
VisionMerger *VisionMerger
HasVision bool // True if vision tower is loaded
}
// LoadTextOnly loads only the text encoder components (skips vision tower)
// Use this for text-to-image generation where vision components are not needed
func (m *Qwen25VL) LoadTextOnly(path string) error {
return m.load(path, false)
}
// Load loads the vision-language encoder from a directory
// Vision components are loaded if weights exist
func (m *Qwen25VL) Load(path string) error {
return m.load(path, true)
}
// load is the internal loading function
func (m *Qwen25VL) load(path string, loadVision bool) error {
fmt.Println("Loading Qwen2.5-VL encoder...")
cfg := defaultQwen25VLConfig()
m.Config = cfg
weights, err := safetensors.LoadModelWeights(path)
if err != nil {
return fmt.Errorf("weights: %w", err)
}
// Bulk load all weights as bf16
fmt.Print(" Loading weights as bf16... ")
if err := weights.Load(mlx.DtypeBFloat16); err != nil {
return fmt.Errorf("failed to load weights: %w", err)
}
fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
// Load text embedding
fmt.Print(" Loading text embeddings... ")
embedding, err := weights.Get("model.embed_tokens.weight")
if err != nil {
return err
}
m.Embedding = embedding
fmt.Printf("✓ [%v]\n", embedding.Shape())
// Load text blocks
m.Blocks = make([]*VLTextBlock, cfg.NumHiddenLayers)
for i := int32(0); i < cfg.NumHiddenLayers; i++ {
fmt.Printf("\r Loading text blocks... %d/%d", i+1, cfg.NumHiddenLayers)
block, err := newVLTextBlock(weights, int(i), cfg)
if err != nil {
return fmt.Errorf("failed to load text block %d: %w", i, err)
}
m.Blocks[i] = block
}
fmt.Printf("\r Loading text blocks... ✓ [%d blocks] \n", cfg.NumHiddenLayers)
// Load final norm
fmt.Print(" Loading final norm... ")
finalNorm, err := weights.Get("model.norm.weight")
if err != nil {
return err
}
m.FinalNorm = finalNorm
fmt.Println("✓")
// Try to load vision tower (optional)
m.HasVision = false
if loadVision {
if _, err := weights.Get("visual.patch_embed.proj.weight"); err == nil {
fmt.Print(" Loading vision patch embed... ")
m.VisionPatchEmbed, err = newVisionPatchEmbed(weights, cfg)
if err != nil {
return fmt.Errorf("vision patch embed: %w", err)
}
fmt.Println("✓")
m.VisionBlocks = make([]*VisionBlock, cfg.VisionNumLayers)
for i := int32(0); i < cfg.VisionNumLayers; i++ {
fmt.Printf("\r Loading vision blocks... %d/%d", i+1, cfg.VisionNumLayers)
block, err := newVisionBlock(weights, int(i), cfg)
if err != nil {
return fmt.Errorf("failed to load vision block %d: %w", i, err)
}
m.VisionBlocks[i] = block
}
fmt.Printf("\r Loading vision blocks... ✓ [%d blocks] \n", cfg.VisionNumLayers)
fmt.Print(" Loading vision merger... ")
m.VisionMerger, err = newVisionMerger(weights, cfg)
if err != nil {
return fmt.Errorf("vision merger: %w", err)
}
fmt.Println("✓")
m.HasVision = true
} else {
fmt.Println(" (No vision tower - text-only mode)")
}
} else {
fmt.Println(" (Skipping vision tower)")
}
weights.ReleaseAll()
return nil
}
// EncodePrompt encodes a text prompt for image generation (text-only mode)
// Uses the Qwen-Image template and drops the first 34 tokens (system prefix)
func (m *Qwen25VL) EncodePrompt(tok *tokenizer.Tokenizer, prompt string) *mlx.Array {
cfg := m.Config
// Template from Python: prompt_template_encode (for image generation)
template := "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n"
formattedPrompt := fmt.Sprintf(template, prompt)
// Tokenize
tokens := tok.Encode(formattedPrompt, false)
// Create token array
seqLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen})
// Get text embeddings
textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr)
// Compute RoPE
cossin := m.computeTextRoPE(seqLen, 1)
// Forward through ALL text blocks
x := textEmbed
for _, block := range m.Blocks {
x = block.Forward(x, cossin)
}
// Apply final norm
x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)
// Drop first 34 tokens (system prefix)
// prompt_template_encode_start_idx = 34
dropIdx := int32(34)
if x.Shape()[1] > dropIdx {
x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
}
return x
}
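// Worked example (illustrative): the fixed template text before the user prompt
// accounts for the 34 dropped tokens above, so a formatted prompt that encodes
// to L tokens yields an embedding of shape [1, L-34, 3584] (cfg.HiddenSize).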
// EncodePromptWithImage encodes a text prompt with an image
// Returns: embeddings [B, L, hidden_size], mask [B, L], error
func (m *Qwen25VL) EncodePromptWithImage(tok *tokenizer.Tokenizer, prompt string, image *mlx.Array) (*mlx.Array, *mlx.Array, error) {
if !m.HasVision {
return nil, nil, errors.New("EncodePromptWithImage called on text-only model")
}
cfg := m.Config
// Template from Python diffusers pipeline: prompt_template_encode
// Python's _get_qwen_prompt_embeds adds "Picture 1: " before vision tokens
template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>%s<|im_end|>\n<|im_start|>assistant\n"
formattedPrompt := fmt.Sprintf(template, prompt)
// Tokenize
tokens := tok.Encode(formattedPrompt, false)
// Process vision if image provided
var visionEmbeddings *mlx.Array
var numImageTokens int32
var visionH, visionW int32 // Grid dims in patches (before spatial merge)
if image != nil {
visionEmbeddings = m.encodeVision(image)
numImageTokens = visionEmbeddings.Shape()[1]
// Get original grid dimensions from image shape
imgShape := image.Shape()
visionH = imgShape[2] / cfg.VisionPatchSize // Height in patches
visionW = imgShape[3] / cfg.VisionPatchSize // Width in patches
}
// Find image token position and expand
expandedTokens := make([]int32, 0, len(tokens)+int(numImageTokens))
imageTokenPos := int32(-1)
textAfterCount := int32(0)
for i, t := range tokens {
if t == cfg.ImageTokenID {
imageTokenPos = int32(len(expandedTokens))
// Insert placeholder tokens for image
for j := int32(0); j < numImageTokens; j++ {
expandedTokens = append(expandedTokens, cfg.ImageTokenID)
}
// Count remaining tokens after image
textAfterCount = int32(len(tokens) - i - 1)
} else {
expandedTokens = append(expandedTokens, t)
}
}
// Create token array
seqLen := int32(len(expandedTokens))
tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen})
// Get text embeddings
textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]
// Replace image token embeddings with vision embeddings
if visionEmbeddings != nil && imageTokenPos >= 0 {
// Split, replace, concat
before := mlx.Slice(textEmbed, []int32{0, 0, 0}, []int32{1, imageTokenPos, cfg.HiddenSize})
after := mlx.Slice(textEmbed, []int32{0, imageTokenPos + numImageTokens, 0}, []int32{1, seqLen, cfg.HiddenSize})
textEmbed = mlx.Concatenate([]*mlx.Array{before, visionEmbeddings, after}, 1)
}
// Compute RoPE - use multimodal RoPE when image is present
var cossin [2]*mlx.Array
if image != nil && imageTokenPos >= 0 {
cossin = m.ComputeMultimodalRoPE(imageTokenPos, visionH, visionW, textAfterCount, cfg.VisionSpatialMerge)
} else {
cossin = m.computeTextRoPE(seqLen, 1)
}
// Forward through ALL text blocks
// Python uses hidden_states[-1] (LAST layer output, not second-to-last!)
x := textEmbed
for _, block := range m.Blocks {
x = block.Forward(x, cossin)
}
// Apply final norm (Python DOES apply this for the output)
x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)
// Drop first N tokens (system prefix)
// prompt_template_encode_start_idx = 64
dropIdx := int32(64)
if x.Shape()[1] > dropIdx {
x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
}
// Create attention mask (all ones for now)
mask := mlx.Ones(1, x.Shape()[1])
return x, mask, nil
}
// EncodeVision encodes an image through the vision tower (exported for testing)
// image: [B, C, H, W] normalized image tensor
// Returns: [B, num_tokens, hidden_size] vision embeddings
func (m *Qwen25VL) EncodeVision(image *mlx.Array) *mlx.Array {
return m.encodeVision(image)
}
// VisionRegion describes where vision embeddings are inserted in the sequence
type VisionRegion struct {
StartPos int32 // Position in sequence where vision tokens start
NumTokens int32 // Number of vision tokens
GridH int32 // Vision grid height (in patches, after spatial merge)
GridW int32 // Vision grid width (in patches, after spatial merge)
}
// EncodePromptWithImages encodes a text prompt with multiple images
// Returns: embeddings [B, L, hidden_size], mask [B, L], regions []VisionRegion, error
func (m *Qwen25VL) EncodePromptWithImages(tok *tokenizer.Tokenizer, prompt string, images []*mlx.Array) (*mlx.Array, *mlx.Array, []VisionRegion, error) {
if !m.HasVision {
return nil, nil, nil, errors.New("EncodePromptWithImages called on text-only model")
}
if len(images) == 0 {
return nil, nil, nil, errors.New("EncodePromptWithImages called with no images")
}
cfg := m.Config
// Build image prompt prefix: "Picture 1: <vision>...Picture N: <vision>..."
imgPromptTemplate := "Picture %d: <|vision_start|><|image_pad|><|vision_end|>"
imgPrompt := ""
for i := range images {
imgPrompt += fmt.Sprintf(imgPromptTemplate, i+1)
}
// Template from Python diffusers pipeline: prompt_template_encode
template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n%s%s<|im_end|>\n<|im_start|>assistant\n"
formattedPrompt := fmt.Sprintf(template, imgPrompt, prompt)
// Tokenize
tokens := tok.Encode(formattedPrompt, false)
// Process each image through vision tower
visionEmbeddings := make([]*mlx.Array, len(images))
numImageTokens := make([]int32, len(images))
visionGridH := make([]int32, len(images))
visionGridW := make([]int32, len(images))
for i, image := range images {
visionEmbeddings[i] = m.encodeVision(image)
numImageTokens[i] = visionEmbeddings[i].Shape()[1]
// Get original grid dimensions from image shape
imgShape := image.Shape()
visionH := imgShape[2] / cfg.VisionPatchSize // Height in patches
visionW := imgShape[3] / cfg.VisionPatchSize // Width in patches
// After spatial merge, grid is halved
visionGridH[i] = visionH / cfg.VisionSpatialMerge
visionGridW[i] = visionW / cfg.VisionSpatialMerge
}
// Find all image token positions and expand tokens
expandedTokens := make([]int32, 0, len(tokens)+int(sum(numImageTokens)))
imagePositions := make([]int32, 0, len(images)) // Start position for each image's tokens
imageIdx := 0
for _, t := range tokens {
if t == cfg.ImageTokenID {
if imageIdx < len(images) {
imagePositions = append(imagePositions, int32(len(expandedTokens)))
// Insert placeholder tokens for this image
for j := int32(0); j < numImageTokens[imageIdx]; j++ {
expandedTokens = append(expandedTokens, cfg.ImageTokenID)
}
imageIdx++
}
} else {
expandedTokens = append(expandedTokens, t)
}
}
// Create token array
seqLen := int32(len(expandedTokens))
tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen})
// Get text embeddings
textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]
// Replace image token embeddings with vision embeddings
// Build list of segments to concatenate
segments := make([]*mlx.Array, 0, len(images)*2+1)
regions := make([]VisionRegion, len(images))
lastEnd := int32(0)
for i, imgPos := range imagePositions {
// Text segment before this image
if imgPos > lastEnd {
segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, imgPos, cfg.HiddenSize}))
}
// Vision embeddings for this image
segments = append(segments, visionEmbeddings[i])
regions[i] = VisionRegion{
StartPos: imgPos,
NumTokens: numImageTokens[i],
GridH: visionGridH[i],
GridW: visionGridW[i],
}
lastEnd = imgPos + numImageTokens[i]
}
// Remaining text after last image
if lastEnd < seqLen {
segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, seqLen, cfg.HiddenSize}))
}
// Concatenate all segments
textEmbed = mlx.Concatenate(segments, 1)
// Compute RoPE - use multimodal RoPE for multiple images
cossin, err := m.ComputeMultiImageRoPE(imagePositions, visionGridH, visionGridW, numImageTokens, seqLen)
if err != nil {
return nil, nil, nil, fmt.Errorf("computing RoPE: %w", err)
}
// Forward through ALL text blocks
x := textEmbed
for _, block := range m.Blocks {
x = block.Forward(x, cossin)
}
// Apply final norm
x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)
// Drop first N tokens (system prefix)
// prompt_template_encode_start_idx = 64
dropIdx := int32(64)
if x.Shape()[1] > dropIdx {
x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
// Adjust region positions
for i := range regions {
regions[i].StartPos -= dropIdx
}
}
// Create attention mask (all ones)
mask := mlx.Ones(1, x.Shape()[1])
return x, mask, regions, nil
}
// sum returns the sum of the elements of an int32 slice
func sum(arr []int32) int32 {
var s int32
for _, v := range arr {
s += v
}
return s
}
// EncodeTextOnly encodes text tokens through all text blocks (exported for testing)
// tokens: array of token IDs
// Returns: [B, L, hidden_size] text embeddings after all blocks
func (m *Qwen25VL) EncodeTextOnly(tokens []int32) *mlx.Array {
seqLen := int32(len(tokens))
tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen})
// Get text embeddings
textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]
// Compute RoPE
cossin := m.computeTextRoPE(seqLen, 1)
// Forward through ALL text blocks (unlike Encode which stops at second-to-last)
x := textEmbed
for _, block := range m.Blocks {
x = block.Forward(x, cossin)
}
// Apply final norm
x = mlx.RMSNorm(x, m.FinalNorm, m.Config.RMSNormEps)
return x
}
// encodeVision encodes an image through the vision tower
// image: [B, C, H, W] normalized image tensor
// Returns: [B, num_tokens, hidden_size] vision embeddings
func (m *Qwen25VL) encodeVision(image *mlx.Array) *mlx.Array {
cfg := m.Config
// Calculate grid dimensions from image
imgShape := image.Shape()
imgH := imgShape[2]
imgW := imgShape[3]
pH := imgH / cfg.VisionPatchSize // grid height in patches
pW := imgW / cfg.VisionPatchSize // grid width in patches
// Patch embed
x := m.VisionPatchEmbed.Forward(image)
mlx.Eval(x)
// Get window reordering info
winInfo := m.getWindowInfo(pH, pW)
// Compute vision RoPE embeddings (already in 2x2-block order)
posEmb := m.computeVisionRoPE(pH, pW)
shape := x.Shape()
B := shape[0]
L := shape[1] // num patches = pH * pW
D := shape[2]
spatialMergeUnit := winInfo.SpatialMergeUnit
spatialMerge := cfg.VisionSpatialMerge
// Convert patch embed from row-major to 2x2-block order
// Row-major: (0,0), (0,1), (0,2), ..., (1,0), (1,1), ...
// 2x2-block: (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), (1,2), (1,3), ...
llmGridH := pH / spatialMerge
llmGridW := pW / spatialMerge
blockReorderIdx := make([]int32, L)
idx := int32(0)
for hBlock := int32(0); hBlock < llmGridH; hBlock++ {
for wBlock := int32(0); wBlock < llmGridW; wBlock++ {
for dh := int32(0); dh < spatialMerge; dh++ {
for dw := int32(0); dw < spatialMerge; dw++ {
h := hBlock*spatialMerge + dh
w := wBlock*spatialMerge + dw
rowMajorIdx := h*pW + w
blockReorderIdx[idx] = rowMajorIdx
idx++
}
}
}
}
blockIdxArr := mlx.NewArrayInt32(blockReorderIdx, []int32{L})
x = mlx.Take(x, blockIdxArr, 1) // Reorder patches to 2x2-block order
// Window reorder hidden states and RoPE before blocks
// Python: reshape to [L/4, 4, D], reorder dim 0, reshape back
// Reshape x: [B, L, D] -> [B, L/4, 4, D]
x = mlx.Reshape(x, B, L/spatialMergeUnit, spatialMergeUnit, D)
// Reorder using window index
winIdxArr := mlx.NewArrayInt32(winInfo.WindowIndex, []int32{int32(len(winInfo.WindowIndex))})
x = mlx.Take(x, winIdxArr, 1) // Take along axis 1
// Reshape back: [B, L/4, 4, D] -> [B, L, D]
x = mlx.Reshape(x, B, L, D)
// Similarly reorder RoPE: [L, headDim] -> [L/4, 4, headDim] -> reorder -> [L, headDim]
cosShape := posEmb[0].Shape()
ropeL := cosShape[0]
ropeD := cosShape[1]
cos := mlx.Reshape(posEmb[0], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD)
sin := mlx.Reshape(posEmb[1], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD)
cos = mlx.Take(cos, winIdxArr, 0)
sin = mlx.Take(sin, winIdxArr, 0)
cos = mlx.Reshape(cos, ropeL, ropeD)
sin = mlx.Reshape(sin, ropeL, ropeD)
posEmb = [2]*mlx.Array{cos, sin}
// Materialize to prevent freeing during block evaluations
mlx.Eval(x, posEmb[0], posEmb[1])
// Full sequence cu_seqlens for full attention blocks
cuSeqlensFull := []int32{0, L}
// Vision blocks - use window attention except at full attention indices
for i, block := range m.VisionBlocks {
useFullAttention := false
for _, idx := range cfg.VisionFullAttIdx {
if int32(i) == idx {
useFullAttention = true
break
}
}
var cuSeqlens []int32
if useFullAttention {
cuSeqlens = cuSeqlensFull
} else {
cuSeqlens = winInfo.CuWindowSeqlens
}
x = block.Forward(x, posEmb, cuSeqlens)
}
// Spatial merge (2x2 -> 1)
x = m.VisionMerger.ForwardWithDims(x, pH, pW)
// Reverse window reorder after merger
revIdxArr := mlx.NewArrayInt32(winInfo.ReverseIndex, []int32{int32(len(winInfo.ReverseIndex))})
x = mlx.Take(x, revIdxArr, 1)
return x
}
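// Shape walk-through (illustrative), assuming a 448x448 input and the defaults
// above (patch_size=14, spatial_merge=2, window_size=112):
//   pH = pW = 448/14 = 32, so patch embed yields [1, 1024, 1280]
//   the 2x2-block and window reorders permute but keep [1, 1024, 1280]
//   the 16x16 merged-token grid splits into 16 windows of 4x4 merged tokens
//   (64 patches each), so cuWindowSeqlens advances in steps of 64
//   merger: [1, 1024, 1280] -> [1, 256, 5120] -> MLP -> [1, 256, 3584]
//   the reverse reorder restores row-major order of the 256 merged tokens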
// WindowInfo holds window reordering and attention boundary info
type WindowInfo struct {
WindowIndex []int32 // Reordering indices for merged tokens
ReverseIndex []int32 // Reverse reordering indices
CuWindowSeqlens []int32 // Cumulative window boundaries in UNMERGED sequence
SpatialMergeUnit int32 // Number of patches per merged token (4 = 2x2)
}
// getWindowInfo computes window reordering indices and attention boundaries
// pH, pW: patch grid dimensions before 2x2 merge
func (m *Qwen25VL) getWindowInfo(pH, pW int32) *WindowInfo {
cfg := m.Config
spatialMergeUnit := cfg.VisionSpatialMerge * cfg.VisionSpatialMerge // 4
// After 2x2 merge
llmGridH := pH / cfg.VisionSpatialMerge
llmGridW := pW / cfg.VisionSpatialMerge
numTokens := llmGridH * llmGridW
// Window size in merged tokens
// window_size=112, spatial_merge_size=2, patch_size=14
// vit_merger_window_size = 112 / 2 / 14 = 4
vitMergerWindowSize := cfg.VisionWindowSize / cfg.VisionSpatialMerge / cfg.VisionPatchSize
// Calculate padding and number of windows
padH := vitMergerWindowSize - llmGridH%vitMergerWindowSize
if padH == vitMergerWindowSize {
padH = 0
}
padW := vitMergerWindowSize - llmGridW%vitMergerWindowSize
if padW == vitMergerWindowSize {
padW = 0
}
numWindowsH := (llmGridH + padH) / vitMergerWindowSize
numWindowsW := (llmGridW + padW) / vitMergerWindowSize
// Create padded grid with -1 for padding
paddedH := llmGridH + padH
paddedW := llmGridW + padW
grid := make([]int32, paddedH*paddedW)
for i := range grid {
grid[i] = -1
}
for h := int32(0); h < llmGridH; h++ {
for w := int32(0); w < llmGridW; w++ {
grid[h*paddedW+w] = h*llmGridW + w
}
}
// Reorder into windows and track window sizes
windowIndex := make([]int32, 0, numTokens)
windowSizes := make([]int32, 0, numWindowsH*numWindowsW)
ws := vitMergerWindowSize
for wh := int32(0); wh < numWindowsH; wh++ {
for ww := int32(0); ww < numWindowsW; ww++ {
windowStart := len(windowIndex)
// Extract window
for h := int32(0); h < ws; h++ {
for w := int32(0); w < ws; w++ {
idx := (wh*ws+h)*paddedW + (ww*ws + w)
if grid[idx] >= 0 {
windowIndex = append(windowIndex, grid[idx])
}
}
}
windowSize := int32(len(windowIndex) - windowStart)
windowSizes = append(windowSizes, windowSize)
}
}
// Create reverse index (argsort of windowIndex)
reverseIndex := make([]int32, numTokens)
for i, idx := range windowIndex {
reverseIndex[idx] = int32(i)
}
// Compute cumulative sequence lengths in UNMERGED sequence
// Each merged token corresponds to spatialMergeUnit patches
cuWindowSeqlens := make([]int32, len(windowSizes)+1)
cuWindowSeqlens[0] = 0
for i, size := range windowSizes {
cuWindowSeqlens[i+1] = cuWindowSeqlens[i] + size*spatialMergeUnit
}
return &WindowInfo{
WindowIndex: windowIndex,
ReverseIndex: reverseIndex,
CuWindowSeqlens: cuWindowSeqlens,
SpatialMergeUnit: spatialMergeUnit,
}
}
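// Worked example (illustrative): a 280x280 image gives pH = pW = 20, so the
// merged grid is 10x10 and vitMergerWindowSize = 4. Padding is padH = padW = 2,
// yielding 3x3 windows over a 12x12 padded grid; after dropping padding entries
// the window sizes are [16 16 8 16 16 8 8 8 4] merged tokens (100 total), and
// CuWindowSeqlens (4 patches per merged token) is
// [0 64 128 160 224 288 320 352 384 400].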
// ComputeMultiImageRoPE computes M-RoPE for combined text + multiple vision regions + text sequences
// This extends ComputeMultimodalRoPE to handle N images instead of just one.
//
// Parameters:
// - imagePositions: starting position of each image's tokens in the sequence
// - visionGridH, visionGridW: grid dimensions for each image (after spatial merge)
// - numImageTokens: number of tokens for each image
// - totalLen: total sequence length
func (m *Qwen25VL) ComputeMultiImageRoPE(imagePositions []int32, visionGridH, visionGridW, numImageTokens []int32, totalLen int32) ([2]*mlx.Array, error) {
numImages := len(imagePositions)
// Build 3D position IDs: [3, 1, totalLen]
// Dimension 0: temporal, Dimension 1: height, Dimension 2: width
posIDs := make([]float32, 3*totalLen)
// Process sequence in order
stIdx := int32(0) // Running text position counter
seqIdx := int32(0)
for i := 0; i < numImages; i++ {
imgPos := imagePositions[i]
gridH := visionGridH[i]
gridW := visionGridW[i]
numTokens := numImageTokens[i]
// Text segment before this image
for seqIdx < imgPos {
posIDs[0*totalLen+seqIdx] = float32(stIdx)
posIDs[1*totalLen+seqIdx] = float32(stIdx)
posIDs[2*totalLen+seqIdx] = float32(stIdx)
stIdx++
seqIdx++
}
// Vision tokens for this image
// Python uses stIdx as base offset for all position dimensions
for h := int32(0); h < gridH; h++ {
for w := int32(0); w < gridW; w++ {
posIDs[0*totalLen+seqIdx] = float32(stIdx) // temporal: constant = stIdx
posIDs[1*totalLen+seqIdx] = float32(stIdx + h) // height: stIdx + row_index
posIDs[2*totalLen+seqIdx] = float32(stIdx + w) // width: stIdx + col_index
seqIdx++
}
}
// Verify we processed the expected number of tokens
if seqIdx != imgPos+numTokens {
return [2]*mlx.Array{}, fmt.Errorf("mismatch: processed %d but expected %d tokens for image %d", seqIdx-imgPos, numTokens, i)
}
// Update stIdx for next text segment: max(temporal, height, width) + 1
maxVisionPos := stIdx // temporal max
if stIdx+gridH-1 > maxVisionPos {
maxVisionPos = stIdx + gridH - 1
}
if stIdx+gridW-1 > maxVisionPos {
maxVisionPos = stIdx + gridW - 1
}
stIdx = maxVisionPos + 1
}
// Text after last image
for seqIdx < totalLen {
posIDs[0*totalLen+seqIdx] = float32(stIdx)
posIDs[1*totalLen+seqIdx] = float32(stIdx)
posIDs[2*totalLen+seqIdx] = float32(stIdx)
stIdx++
seqIdx++
}
posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen})
return m.computeRoPEFromPositions(posIDsArr, totalLen, 1), nil
}
// computeTextRoPE computes M-RoPE for text-only sequences
func (m *Qwen25VL) computeTextRoPE(L, B int32) [2]*mlx.Array {
// For text-only, all 3 dims use same positions [0, 1, 2, ..., L-1]
posArr := make([]float32, L*3)
for d := 0; d < 3; d++ {
for i := int32(0); i < L; i++ {
posArr[int32(d)*L+i] = float32(i)
}
}
posIDs := mlx.NewArray(posArr, []int32{3, 1, L})
posIDs = mlx.Tile(posIDs, []int32{1, B, 1})
return m.computeRoPEFromPositions(posIDs, L, B)
}
// ComputeMultimodalRoPE computes M-RoPE for combined text + vision + text sequences
// This matches Python's get_rope_index behavior exactly.
// Exported for testing.
//
// Python pattern discovered from testing:
//
// Vision row 1: temporal=stIdx, height=stIdx, width=[stIdx, stIdx+1, ..., stIdx+gridW-1]
// Vision row 2: temporal=stIdx, height=stIdx+1, width=[stIdx, stIdx+1, ..., stIdx+gridW-1]
// Text after:    all dims = max(stIdx, stIdx+gridH-1, stIdx+gridW-1) + 1 + i
func (m *Qwen25VL) ComputeMultimodalRoPE(textBefore, visionH, visionW, textAfter int32, spatialMerge int32) [2]*mlx.Array {
// Vision grid after spatial merge
llmGridH := visionH / spatialMerge
llmGridW := visionW / spatialMerge
visionLen := llmGridH * llmGridW
totalLen := textBefore + visionLen + textAfter
// Build 3D position IDs: [3, 1, totalLen]
// Dimension 0: temporal, Dimension 1: height, Dimension 2: width
posIDs := make([]float32, 3*totalLen)
// Text before vision: all dims same [0, 1, 2, ..., textBefore-1]
for d := 0; d < 3; d++ {
for i := int32(0); i < textBefore; i++ {
posIDs[int32(d)*totalLen+i] = float32(i)
}
}
// Vision tokens: 3D grid positions
// Python uses stIdx (textBefore) as base offset for all position dimensions
stIdx := textBefore
for h := int32(0); h < llmGridH; h++ {
for w := int32(0); w < llmGridW; w++ {
idx := stIdx + h*llmGridW + w
posIDs[0*totalLen+idx] = float32(stIdx) // temporal: constant = stIdx
posIDs[1*totalLen+idx] = float32(stIdx + h) // height: stIdx + row_index
posIDs[2*totalLen+idx] = float32(stIdx + w) // width: stIdx + col_index
}
}
// Text after vision: ALL dimensions continue from max(temporal, height, width) + 1
// max is max(stIdx, stIdx+llmGridH-1, stIdx+llmGridW-1) = stIdx + max(0, llmGridH-1, llmGridW-1)
// Then st_idx = max + 1
maxVisionPos := stIdx // temporal max
if stIdx+llmGridH-1 > maxVisionPos {
maxVisionPos = stIdx + llmGridH - 1
}
if stIdx+llmGridW-1 > maxVisionPos {
maxVisionPos = stIdx + llmGridW - 1
}
textAfterStart := maxVisionPos + 1
for i := int32(0); i < textAfter; i++ {
seqIdx := textBefore + visionLen + i
posIDs[0*totalLen+seqIdx] = float32(textAfterStart + i) // temporal
posIDs[1*totalLen+seqIdx] = float32(textAfterStart + i) // height
posIDs[2*totalLen+seqIdx] = float32(textAfterStart + i) // width
}
posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen})
return m.computeRoPEFromPositions(posIDsArr, totalLen, 1)
}
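// Worked example (illustrative): textBefore=3, a 4x4-patch image (2x2 after
// spatial merge), textAfter=2. Per-token (temporal, height, width) positions:
//   text before: (0,0,0) (1,1,1) (2,2,2)
//   vision:      (3,3,3) (3,3,4) (3,4,3) (3,4,4)
//   text after:  max vision position is 4, so (5,5,5) (6,6,6)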
// computeRoPEFromPositions computes cos/sin from 3D position IDs
// posIDs: [3, B, L] where dim 0 is temporal, 1 is height, 2 is width
func (m *Qwen25VL) computeRoPEFromPositions(posIDs *mlx.Array, L, B int32) [2]*mlx.Array {
cfg := m.Config
half := cfg.HeadDim / 2
// Compute inv_freq
invFreqArr := make([]float32, half)
for i := int32(0); i < half; i++ {
invFreqArr[i] = float32(1.0 / math.Pow(float64(cfg.RopeTheta), 2.0*float64(i)/float64(cfg.HeadDim)))
}
invFreq := mlx.NewArray(invFreqArr, []int32{half})
// Process each position dimension
var cosAll, sinAll []*mlx.Array
for d := int32(0); d < 3; d++ {
// Get positions for this dimension: [B, L]
pos := mlx.Slice(posIDs, []int32{d, 0, 0}, []int32{d + 1, B, L})
pos = mlx.Squeeze(pos, 0) // [B, L]
posExp := mlx.ExpandDims(pos, 2) // [B, L, 1]
invFreqExp := mlx.Reshape(invFreq, 1, 1, half) // [1, 1, half]
freqs := mlx.Mul(posExp, invFreqExp) // [B, L, half]
emb := mlx.Tile(freqs, []int32{1, 1, 2}) // [B, L, D]
cosAll = append(cosAll, mlx.ExpandDims(mlx.Cos(emb), 0))
sinAll = append(sinAll, mlx.ExpandDims(mlx.Sin(emb), 0))
}
cos := mlx.Concatenate(cosAll, 0) // [3, B, L, D]
sin := mlx.Concatenate(sinAll, 0)
return [2]*mlx.Array{cos, sin}
}
// computeVisionRoPE computes RoPE embeddings for vision patches
// pH, pW: grid dimensions in patches
// Returns: [2]*mlx.Array containing (cos, sin) each of shape [numPatches, headDim]
func (m *Qwen25VL) computeVisionRoPE(pH, pW int32) [2]*mlx.Array {
cfg := m.Config
headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads // 80 for 1280/16
halfDim := headDim / 2 // 40
quarterDim := halfDim / 2 // 20
spatialMerge := cfg.VisionSpatialMerge // 2
// Python Qwen2_5_VisionRotaryEmbedding uses dim=head_dim/2=40
// inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim)) -> 20 elements
theta := float64(10000.0)
invFreqArr := make([]float32, quarterDim)
for i := int32(0); i < quarterDim; i++ {
invFreqArr[i] = float32(1.0 / math.Pow(theta, float64(2*i)/float64(halfDim)))
}
invFreq := mlx.NewArray(invFreqArr, []int32{quarterDim})
// Create position IDs matching Python's 2x2 block ordering:
// Python does: reshape(h//2, 2, w//2, 2), permute(0, 2, 1, 3), flatten
// This groups patches by 2x2 merged token blocks
numPatches := pH * pW
hPosArr := make([]float32, numPatches)
wPosArr := make([]float32, numPatches)
// Number of merged token blocks
llmGridH := pH / spatialMerge
llmGridW := pW / spatialMerge
idx := int32(0)
for hBlock := int32(0); hBlock < llmGridH; hBlock++ {
for wBlock := int32(0); wBlock < llmGridW; wBlock++ {
// Within each 2x2 block: (0,0), (0,1), (1,0), (1,1)
for dh := int32(0); dh < spatialMerge; dh++ {
for dw := int32(0); dw < spatialMerge; dw++ {
h := hBlock*spatialMerge + dh
w := wBlock*spatialMerge + dw
hPosArr[idx] = float32(h)
wPosArr[idx] = float32(w)
idx++
}
}
}
}
hPos := mlx.NewArray(hPosArr, []int32{numPatches, 1})
wPos := mlx.NewArray(wPosArr, []int32{numPatches, 1})
invFreqExp := mlx.Reshape(invFreq, 1, quarterDim)
// Compute freqs: [numPatches, quarterDim] for each of h and w
hFreqs := mlx.Mul(hPos, invFreqExp) // [L, 20]
wFreqs := mlx.Mul(wPos, invFreqExp) // [L, 20]
// Concatenate h and w freqs: [numPatches, halfDim] = [L, 40]
freqs := mlx.Concatenate([]*mlx.Array{hFreqs, wFreqs}, 1)
// Double for cos/sin application: [L, 40] -> [L, 80] = [L, headDim]
emb := mlx.Concatenate([]*mlx.Array{freqs, freqs}, 1)
cos := mlx.Cos(emb)
sin := mlx.Sin(emb)
return [2]*mlx.Array{cos, sin}
}
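// Shapes for the defaults above (illustrative): headDim=80, halfDim=40,
// quarterDim=20. For a pH = pW = 4 grid, hPos and wPos are [16, 1], hFreqs and
// wFreqs are [16, 20], freqs is [16, 40], and the returned cos/sin are each
// [16, 80], indexed in the same 2x2-block order as the reordered patches.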
// VLTextBlock is a single Qwen2.5 transformer block (for VL model)
type VLTextBlock struct {
Attention *VLTextAttention
MLP *VLTextMLP
InputLayerNorm *mlx.Array
PostAttnLayerNorm *mlx.Array
NormEps float32
}
// newVLTextBlock creates a text block
func newVLTextBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VLTextBlock, error) {
prefix := fmt.Sprintf("model.layers.%d", layerIdx)
inputNorm, err := weights.Get(prefix + ".input_layernorm.weight")
if err != nil {
return nil, err
}
postAttnNorm, err := weights.Get(prefix + ".post_attention_layernorm.weight")
if err != nil {
return nil, err
}
attention, err := newVLTextAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
mlpLayer, err := newVLTextMLP(weights, prefix)
if err != nil {
return nil, err
}
return &VLTextBlock{
Attention: attention,
MLP: mlpLayer,
InputLayerNorm: inputNorm,
PostAttnLayerNorm: postAttnNorm,
NormEps: cfg.RMSNormEps,
}, nil
}
// Forward applies the block
func (tb *VLTextBlock) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array {
h := mlx.RMSNorm(x, tb.InputLayerNorm, tb.NormEps)
attnOut := tb.Attention.Forward(h, cossin)
x = mlx.Add(x, attnOut)
h = mlx.RMSNorm(x, tb.PostAttnLayerNorm, tb.NormEps)
mlpOut := tb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// VLTextAttention implements Qwen2.5 attention with M-RoPE
type VLTextAttention struct {
QProj *mlx.Array
KProj *mlx.Array
VProj *mlx.Array
OProj *mlx.Array
QBias *mlx.Array
KBias *mlx.Array
VBias *mlx.Array
NHeads int32
NKVHeads int32
HeadDim int32
Scale float32
MRoPESection []int32
}
// newVLTextAttention creates a text attention layer
func newVLTextAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VLTextAttention, error) {
qProj, err := weights.Get(prefix + ".self_attn.q_proj.weight")
if err != nil {
return nil, err
}
kProj, err := weights.Get(prefix + ".self_attn.k_proj.weight")
if err != nil {
return nil, err
}
vProj, err := weights.Get(prefix + ".self_attn.v_proj.weight")
if err != nil {
return nil, err
}
oProj, err := weights.Get(prefix + ".self_attn.o_proj.weight")
if err != nil {
return nil, err
}
qBias, _ := weights.Get(prefix + ".self_attn.q_proj.bias")
kBias, _ := weights.Get(prefix + ".self_attn.k_proj.bias")
vBias, _ := weights.Get(prefix + ".self_attn.v_proj.bias")
return &VLTextAttention{
QProj: mlx.Transpose(qProj, 1, 0),
KProj: mlx.Transpose(kProj, 1, 0),
VProj: mlx.Transpose(vProj, 1, 0),
OProj: mlx.Transpose(oProj, 1, 0),
QBias: qBias,
KBias: kBias,
VBias: vBias,
NHeads: cfg.NumAttentionHeads,
NKVHeads: cfg.NumKeyValueHeads,
HeadDim: cfg.HeadDim,
Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
MRoPESection: cfg.MRoPESection,
}, nil
}
// Forward computes attention
func (attn *VLTextAttention) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
q := mlx.Linear(x, attn.QProj)
if attn.QBias != nil {
q = mlx.Add(q, attn.QBias)
}
k := mlx.Linear(x, attn.KProj)
if attn.KBias != nil {
k = mlx.Add(k, attn.KBias)
}
v := mlx.Linear(x, attn.VProj)
if attn.VBias != nil {
v = mlx.Add(v, attn.VBias)
}
q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
q = mlx.Transpose(q, 0, 2, 1, 3)
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
// Apply M-RoPE
if cossin[0] != nil && cossin[1] != nil {
q = applyMRoPE(q, cossin[0], cossin[1], attn.MRoPESection)
k = applyMRoPE(k, cossin[0], cossin[1], attn.MRoPESection)
}
// Repeat KV for GQA
if attn.NKVHeads < attn.NHeads {
repeats := attn.NHeads / attn.NKVHeads
k = repeatKV(k, repeats)
v = repeatKV(v, repeats)
}
out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
out = mlx.Transpose(out, 0, 2, 1, 3)
out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
return mlx.Linear(out, attn.OProj)
}
// applyMRoPE applies multimodal rotary position embedding (M-RoPE)
func applyMRoPE(x *mlx.Array, cos, sin *mlx.Array, section []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
L := shape[2]
D := shape[3]
half := D / 2
fullSection := make([]int32, len(section))
for i, s := range section {
fullSection[i] = s * 2
}
var cosParts, sinParts []*mlx.Array
offset := int32(0)
for i, size := range fullSection {
posDim := int32(i % 3)
cosSection := mlx.Slice(cos, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size})
sinSection := mlx.Slice(sin, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size})
cosSection = mlx.Squeeze(cosSection, 0)
sinSection = mlx.Squeeze(sinSection, 0)
cosParts = append(cosParts, cosSection)
sinParts = append(sinParts, sinSection)
offset += size
}
cosFlat := mlx.Concatenate(cosParts, 2)
sinFlat := mlx.Concatenate(sinParts, 2)
cosFlat = mlx.Reshape(cosFlat, B, 1, L, D)
sinFlat = mlx.Reshape(sinFlat, B, 1, L, D)
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, H, L, half})
x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, H, L, D})
negX2 := mlx.MulScalar(x2, -1)
rotatedX := mlx.Concatenate([]*mlx.Array{negX2, x1}, 3)
return mlx.Add(mlx.Mul(x, cosFlat), mlx.Mul(rotatedX, sinFlat))
}
// repeatKV repeats key/value heads for GQA
func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
if repeats == 1 {
return x
}
shape := x.Shape()
x = mlx.ExpandDims(x, 2)
x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
}
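// With the defaults above (28 query heads, 4 KV heads, head_dim=128), repeats
// is 7: a K/V tensor of shape [B, 4, L, 128] is expanded to [B, 4, 7, L, 128]
// and reshaped to [B, 28, L, 128] before attention.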
// VLTextMLP implements Qwen2.5 SwiGLU MLP
type VLTextMLP struct {
GateProj *mlx.Array
UpProj *mlx.Array
DownProj *mlx.Array
}
// newVLTextMLP creates a text MLP layer
func newVLTextMLP(weights *safetensors.ModelWeights, prefix string) (*VLTextMLP, error) {
gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight")
if err != nil {
return nil, err
}
upProj, err := weights.Get(prefix + ".mlp.up_proj.weight")
if err != nil {
return nil, err
}
downProj, err := weights.Get(prefix + ".mlp.down_proj.weight")
if err != nil {
return nil, err
}
return &VLTextMLP{
GateProj: mlx.Transpose(gateProj, 1, 0),
UpProj: mlx.Transpose(upProj, 1, 0),
DownProj: mlx.Transpose(downProj, 1, 0),
}, nil
}
// Forward applies the SwiGLU MLP
func (mlp *VLTextMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.Linear(x, mlp.GateProj)
gate = mlx.SiLU(gate)
up := mlx.Linear(x, mlp.UpProj)
h := mlx.Mul(gate, up)
return mlx.Linear(h, mlp.DownProj)
}
// VisionPatchEmbed embeds image patches
type VisionPatchEmbed struct {
ProjWeight *mlx.Array
ProjBias *mlx.Array
PatchSize int32
}
// newVisionPatchEmbed creates a vision patch embed layer
func newVisionPatchEmbed(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionPatchEmbed, error) {
projWeight, err := weights.Get("visual.patch_embed.proj.weight")
if err != nil {
return nil, err
}
projBias, _ := weights.Get("visual.patch_embed.proj.bias")
return &VisionPatchEmbed{
ProjWeight: projWeight,
ProjBias: projBias,
PatchSize: cfg.VisionPatchSize,
}, nil
}
// Forward embeds patches from an image
// image: [B, C, H, W]
// Returns: [B, num_patches, hidden_size]
func (pe *VisionPatchEmbed) Forward(image *mlx.Array) *mlx.Array {
// Qwen2.5-VL uses 3D conv for patch embedding to support video
// Weight shape is [O, I, kT, kH, kW] e.g. [1280, 3, 2, 14, 14]
// For single image, we duplicate the frame to match temporal_patch_size
wShape := pe.ProjWeight.Shape()
if len(wShape) == 5 {
// 3D convolution case
temporalPatchSize := wShape[2] // kT from weight shape
// Add temporal dimension: [B, C, H, W] -> [B, C, 1, H, W]
image = mlx.ExpandDims(image, 2)
// Duplicate frame to match temporal_patch_size (Python does this for single images)
// [B, C, 1, H, W] -> [B, C, T, H, W] where T = temporal_patch_size
if temporalPatchSize > 1 {
image = mlx.Tile(image, []int32{1, 1, temporalPatchSize, 1, 1})
}
// Convert to channels-last: [B, C, T, H, W] -> [B, T, H, W, C]
image = mlx.Transpose(image, 0, 2, 3, 4, 1)
// Weight is [O, I, kT, kH, kW] - keep as-is since patches are now in [I, kT, kH, kW] order
// (extractPatches3DStrided transposes each patch to [C, T, H, W] to match Python)
// Apply 3D conv using manual patch extraction
// Strides: (temporal_patch_size, patch_size, patch_size)
x := conv3DStrided(image, pe.ProjWeight, temporalPatchSize, pe.PatchSize, pe.PatchSize)
if pe.ProjBias != nil {
outC := pe.ProjBias.Dim(0)
bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, 1, outC)
x = mlx.Add(x, bias)
}
// x is [B, T', H', W', C], squeeze T' and flatten spatial
shape := x.Shape()
// T' should be 1 for single image (since we used stride=temporal_patch_size)
x = mlx.Reshape(x, shape[0], shape[2]*shape[3], shape[4])
return x
}
// Original 2D case (fallback)
// Convert to channels-last for Conv2d
image = mlx.Transpose(image, 0, 2, 3, 1) // [B, H, W, C]
// Apply conv with stride=patch_size using manual strided convolution
weight := mlx.Transpose(pe.ProjWeight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
x := conv2DStrided(image, weight, pe.PatchSize)
if pe.ProjBias != nil {
bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, pe.ProjBias.Dim(0))
x = mlx.Add(x, bias)
}
// Flatten patches: [B, pH, pW, C] -> [B, pH*pW, C]
shape := x.Shape()
x = mlx.Reshape(x, shape[0], shape[1]*shape[2], shape[3])
return x
}
// VisionBlock is a single vision transformer block
type VisionBlock struct {
Norm1 *mlx.Array
Norm2 *mlx.Array
Attention *VisionAttention
MLP *VisionMLP
}
// newVisionBlock creates a vision block
func newVisionBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VisionBlock, error) {
prefix := fmt.Sprintf("visual.blocks.%d", layerIdx)
norm1, err := weights.Get(prefix + ".norm1.weight")
if err != nil {
return nil, err
}
norm2, err := weights.Get(prefix + ".norm2.weight")
if err != nil {
return nil, err
}
attention, err := newVisionAttention(weights, prefix, cfg)
if err != nil {
return nil, err
}
mlpLayer, err := newVisionMLP(weights, prefix, cfg)
if err != nil {
return nil, err
}
return &VisionBlock{
Norm1: norm1,
Norm2: norm2,
Attention: attention,
MLP: mlpLayer,
}, nil
}
// Forward applies the vision block
// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil
// cuSeqlens: cumulative sequence lengths for window attention
func (vb *VisionBlock) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array {
// Python uses RMSNorm, not LayerNorm!
h := mlx.RMSNormNoWeight(x, 1e-6)
h = mlx.Mul(h, vb.Norm1)
attnOut := vb.Attention.Forward(h, posEmb, cuSeqlens)
x = mlx.Add(x, attnOut)
h = mlx.RMSNormNoWeight(x, 1e-6)
h = mlx.Mul(h, vb.Norm2)
mlpOut := vb.MLP.Forward(h)
x = mlx.Add(x, mlpOut)
return x
}
// VisionAttention implements vision attention
type VisionAttention struct {
QKVProj *mlx.Array
QKVBias *mlx.Array
OutProj *mlx.Array
OutBias *mlx.Array
NHeads int32
HeadDim int32
Scale float32
}
// newVisionAttention creates a vision attention layer
func newVisionAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionAttention, error) {
qkvProj, err := weights.Get(prefix + ".attn.qkv.weight")
if err != nil {
return nil, err
}
qkvBias, _ := weights.Get(prefix + ".attn.qkv.bias")
outProj, err := weights.Get(prefix + ".attn.proj.weight")
if err != nil {
return nil, err
}
outBias, _ := weights.Get(prefix + ".attn.proj.bias")
headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads
return &VisionAttention{
QKVProj: mlx.Transpose(qkvProj, 1, 0),
QKVBias: qkvBias,
OutProj: mlx.Transpose(outProj, 1, 0),
OutBias: outBias,
NHeads: cfg.VisionNumHeads,
HeadDim: headDim,
Scale: float32(1.0 / math.Sqrt(float64(headDim))),
}, nil
}
// Forward applies vision attention with optional RoPE and window attention
// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil
// cuSeqlens: cumulative sequence lengths for window boundaries
func (attn *VisionAttention) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
qkv := mlx.Linear(x, attn.QKVProj)
if attn.QKVBias != nil {
qkv = mlx.Add(qkv, attn.QKVBias)
}
// Split into Q, K, V
qkv = mlx.Reshape(qkv, B, L, 3, attn.NHeads, attn.HeadDim)
q := mlx.Slice(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, attn.NHeads, attn.HeadDim})
k := mlx.Slice(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, attn.NHeads, attn.HeadDim})
v := mlx.Slice(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, attn.NHeads, attn.HeadDim})
q = mlx.Squeeze(q, 2) // [B, L, H, D]
k = mlx.Squeeze(k, 2)
v = mlx.Squeeze(v, 2)
// Apply RoPE if position embeddings provided
if posEmb[0] != nil && posEmb[1] != nil {
q, k = applyVisionRoPE(q, k, posEmb[0], posEmb[1])
}
q = mlx.Transpose(q, 0, 2, 1, 3) // [B, H, L, D]
k = mlx.Transpose(k, 0, 2, 1, 3)
v = mlx.Transpose(v, 0, 2, 1, 3)
var out *mlx.Array
// Check if we need window attention (more than 1 window)
numWindows := len(cuSeqlens) - 1
if numWindows <= 1 {
// Full attention - single window covering entire sequence
out = mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
} else {
// Window attention - process each window separately
attnOutputs := make([]*mlx.Array, numWindows)
for w := 0; w < numWindows; w++ {
start := cuSeqlens[w]
end := cuSeqlens[w+1]
// Slice Q, K, V for this window: [B, H, winLen, D]
qWin := mlx.Slice(q, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
kWin := mlx.Slice(k, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
vWin := mlx.Slice(v, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
// Compute attention for this window
attnWin := mlx.ScaledDotProductAttention(qWin, kWin, vWin, attn.Scale, false)
attnOutputs[w] = attnWin
}
// Concatenate all window outputs along sequence dimension
out = mlx.Concatenate(attnOutputs, 2)
}
out = mlx.Transpose(out, 0, 2, 1, 3) // [B, L, H, D]
out = mlx.Reshape(out, B, L, D)
out = mlx.Linear(out, attn.OutProj)
if attn.OutBias != nil {
out = mlx.Add(out, attn.OutBias)
}
return out
}
// applyVisionRoPE applies rotary position embedding to Q and K for vision
// q, k: [B, L, H, D], cos, sin: [L, D] (already doubled: D = head_dim)
// Returns: rotated q, k with same shape
// Note: Python does this computation in float32 for numerical stability
func applyVisionRoPE(q, k, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
// Convert to float32 for numerical stability (matches Python)
origDtype := q.Dtype()
q = mlx.AsType(q, mlx.DtypeFloat32)
k = mlx.AsType(k, mlx.DtypeFloat32)
cos = mlx.AsType(cos, mlx.DtypeFloat32)
sin = mlx.AsType(sin, mlx.DtypeFloat32)
// Expand cos/sin to match q/k shape: [L, D] -> [1, L, 1, D]
cos = mlx.ExpandDims(cos, 0)
cos = mlx.ExpandDims(cos, 2)
sin = mlx.ExpandDims(sin, 0)
sin = mlx.ExpandDims(sin, 2)
// rotate_half: split last dim in half and swap with negation
// q_rot = q * cos + rotate_half(q) * sin
qRotated := rotateHalf(q)
kRotated := rotateHalf(k)
qOut := mlx.Add(mlx.Mul(q, cos), mlx.Mul(qRotated, sin))
kOut := mlx.Add(mlx.Mul(k, cos), mlx.Mul(kRotated, sin))
// Convert back to original dtype
qOut = mlx.AsType(qOut, origDtype)
kOut = mlx.AsType(kOut, origDtype)
return qOut, kOut
}
// rotateHalf rotates the last dimension by splitting in half and swapping with negation
// x: [..., D] -> split to [..., D/2] and [..., D/2], then concat(-x2, x1)
func rotateHalf(x *mlx.Array) *mlx.Array {
shape := x.Shape()
lastDim := shape[len(shape)-1]
halfDim := lastDim / 2
// Split into two halves
x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], halfDim})
x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{shape[0], shape[1], shape[2], lastDim})
// Negate x2 and concatenate
x2Neg := mlx.MulScalar(x2, -1.0)
return mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)
}
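// Example (illustrative): for D=4 and x = [a, b, c, d] along the last axis,
// rotateHalf returns [-c, -d, a, b]; applyVisionRoPE then forms
// x*cos + rotateHalf(x)*sin, the standard rotate-half formulation of RoPE.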
// VisionMLP implements vision SwiGLU MLP
type VisionMLP struct {
GateProj *mlx.Array
GateProjBias *mlx.Array
UpProj *mlx.Array
UpProjBias *mlx.Array
DownProj *mlx.Array
DownProjBias *mlx.Array
}
// newVisionMLP creates a vision MLP layer
func newVisionMLP(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionMLP, error) {
gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight")
if err != nil {
return nil, err
}
gateProjBias, _ := weights.Get(prefix + ".mlp.gate_proj.bias")
upProj, err := weights.Get(prefix + ".mlp.up_proj.weight")
if err != nil {
return nil, err
}
upProjBias, _ := weights.Get(prefix + ".mlp.up_proj.bias")
downProj, err := weights.Get(prefix + ".mlp.down_proj.weight")
if err != nil {
return nil, err
}
downProjBias, _ := weights.Get(prefix + ".mlp.down_proj.bias")
return &VisionMLP{
GateProj: mlx.Transpose(gateProj, 1, 0),
GateProjBias: gateProjBias,
UpProj: mlx.Transpose(upProj, 1, 0),
UpProjBias: upProjBias,
DownProj: mlx.Transpose(downProj, 1, 0),
DownProjBias: downProjBias,
}, nil
}
// Forward applies the vision SwiGLU MLP
func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
gate := mlx.Linear(x, m.GateProj)
if m.GateProjBias != nil {
gate = mlx.Add(gate, m.GateProjBias)
}
gate = mlx.SiLU(gate)
up := mlx.Linear(x, m.UpProj)
if m.UpProjBias != nil {
up = mlx.Add(up, m.UpProjBias)
}
h := mlx.Mul(gate, up)
h = mlx.Linear(h, m.DownProj)
if m.DownProjBias != nil {
h = mlx.Add(h, m.DownProjBias)
}
return h
}
// VisionMerger merges spatial patches (2x2 -> 1)
type VisionMerger struct {
MLP0Weight *mlx.Array
MLP0Bias *mlx.Array
MLP2Weight *mlx.Array
MLP2Bias *mlx.Array
LNWeight *mlx.Array
}
// newVisionMerger creates a vision merger
func newVisionMerger(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionMerger, error) {
mlp0Weight, err := weights.Get("visual.merger.mlp.0.weight")
if err != nil {
return nil, err
}
mlp0Bias, _ := weights.Get("visual.merger.mlp.0.bias")
mlp2Weight, err := weights.Get("visual.merger.mlp.2.weight")
if err != nil {
return nil, err
}
mlp2Bias, _ := weights.Get("visual.merger.mlp.2.bias")
lnWeight, _ := weights.Get("visual.merger.ln_q.weight")
return &VisionMerger{
MLP0Weight: mlx.Transpose(mlp0Weight, 1, 0),
MLP0Bias: mlp0Bias,
MLP2Weight: mlx.Transpose(mlp2Weight, 1, 0),
MLP2Bias: mlp2Bias,
LNWeight: lnWeight,
}, nil
}
// Forward merges 2x2 patches into 1 (assumes square grid - use ForwardWithDims for non-square)
func (m *VisionMerger) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
L := shape[1]
side := int32(math.Sqrt(float64(L)))
return m.ForwardWithDims(x, side, side)
}
// ForwardWithDims merges 2x2 patches into 1 with explicit grid dimensions
// After window reordering, every 4 consecutive patches form a 2x2 block, so we
// just reshape [B, L, D] -> [B, L/4, 4*D] without any 2D spatial rearrangement.
func (m *VisionMerger) ForwardWithDims(x *mlx.Array, pH, pW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
L := shape[1]
D := shape[2]
// RMSNorm BEFORE merge (applied to each token with D dimensions)
// Python: ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
if m.LNWeight != nil {
x = mlx.RMSNormNoWeight(x, 1e-6)
x = mlx.Mul(x, m.LNWeight)
}
// After window reordering, consecutive 4 patches belong to a 2x2 block
// Just reshape to [B, L/4, 4*D] - no spatial rearrangement needed
newL := L / 4
x = mlx.Reshape(x, B, newL, 4*D)
// MLP
h := mlx.Linear(x, m.MLP0Weight)
if m.MLP0Bias != nil {
h = mlx.Add(h, m.MLP0Bias)
}
h = mlx.GELU(h)
h = mlx.Linear(h, m.MLP2Weight)
if m.MLP2Bias != nil {
h = mlx.Add(h, m.MLP2Bias)
}
return h
}
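// Shape example (illustrative): for a 32x32 patch grid, x enters as
// [1, 1024, 1280]; ln_q keeps the shape, the reshape gives [1, 256, 5120]
// (4 patches per merged token), and the GELU MLP maps it to [1, 256, 3584],
// the text model's hidden size.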
// LoadQwen25VLFromPath loads the encoder from path
func LoadQwen25VLFromPath(path string) (*Qwen25VL, error) {
m := &Qwen25VL{}
if err := m.Load(filepath.Join(path, "text_encoder")); err != nil {
return nil, err
}
return m, nil
}
// conv2DStrided applies conv with stride > 1 using manual patch extraction
// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
wShape := weight.Shape()
Cout := wShape[0]
kH := wShape[1]
kW := wShape[2]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := extractPatches2DStrided(x, kH, kW, stride)
wFlat := mlx.Reshape(weight, Cout, -1)
patches = mlx.Reshape(patches, B*outH*outW, -1)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outH, outW, Cout)
}
// conv3DStrided applies 3D conv with strides using manual patch extraction
// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
// strideT, strideH, strideW are the strides for each dimension
// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
wShape := weight.Shape()
Cout := wShape[0]
// I := wShape[1]
kT := wShape[2]
kH := wShape[3]
kW := wShape[4]
// Temporal handling: if T < kT, repeat frames along the temporal axis.
// For a single image (T=1, kT=2) the frame is duplicated to T=kT;
// Python Qwen2.5-VL duplicates the frame rather than zero-padding
if T < kT {
// Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
T = kT
}
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
// Extract 3D patches in [C, T, H, W] order to match Python
patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
// patches shape: [B, outT, outH, outW, C*kT*kH*kW]
// Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
return mlx.Reshape(out, B, outT, outH, outW, Cout)
}
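// Worked size example (illustrative): a single 224x224 RGB image arrives as
// [1, 1, 224, 224, 3]; with kT=2 the frame is duplicated to T=2, and with
// strides (2, 14, 14) and a 2x14x14 kernel the output grid is
// outT = (2-2)/2 + 1 = 1 and outH = outW = (224-14)/14 + 1 = 16, i.e. one
// temporal slot of 16x16 patches, each flattened to C*kT*kH*kW = 1176 values.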
// extractPatches3DStrided extracts 3D patches with given strides
// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
T := shape[1]
H := shape[2]
W := shape[3]
C := shape[4]
outT := (T - kT) / strideT + 1
outH := (H - kH) / strideH + 1
outW := (W - kW) / strideW + 1
numPatches := outT * outH * outW
patches := make([]*mlx.Array, numPatches)
idx := 0
for t := int32(0); t < outT; t++ {
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startT := t * strideT
startH := i * strideH
startW := j * strideW
// Extract patch: [B, kT, kH, kW, C]
patch := mlx.Slice(x,
[]int32{0, startT, startH, startW, 0},
[]int32{B, startT + kT, startH + kH, startW + kW, C})
// Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
// Flatten to [B, C*T*H*W]
patch = mlx.Reshape(patch, B, C*kT*kH*kW)
patches[idx] = patch
idx++
}
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
}
// extractPatches2DStrided extracts patches with given stride
func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
shape := x.Shape()
B := shape[0]
H := shape[1]
W := shape[2]
C := shape[3]
outH := (H - kH) / stride + 1
outW := (W - kW) / stride + 1
patches := make([]*mlx.Array, outH*outW)
idx := 0
for i := int32(0); i < outH; i++ {
for j := int32(0); j < outW; j++ {
startH := i * stride
startW := j * stride
patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
patch = mlx.Reshape(patch, B, kH*kW*C)
patches[idx] = patch
idx++
}
}
for i := range patches {
patches[i] = mlx.ExpandDims(patches[i], 1)
}
stacked := mlx.Concatenate(patches, 1)
return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
}
//go:build mlx

// Package qwen_image implements the Qwen-Image diffusion transformer model.
package qwen_image
import (
"context"
"fmt"
"path/filepath"
"time"
"github.com/ollama/ollama/x/imagegen/cache"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
)
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
NegativePrompt string // Empty = no CFG
CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
Width int32 // Image width (default: 1024)
Height int32 // Image height (default: 1024)
Steps int // Denoising steps (default: 30)
Seed int64 // Random seed
Progress ProgressFunc // Optional progress callback
// Layer caching (DeepCache/Learning-to-Cache speedup)
LayerCache bool // Enable layer caching (default: false)
CacheInterval int // Refresh cache every N steps (default: 3)
CacheLayers int // Number of shallow layers to cache (default: 25)
}
// ProgressFunc is called during generation with step progress.
type ProgressFunc func(step, totalSteps int)
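// Illustrative usage of GenerateConfig (the prompt and values are made up for
// the example):
//
//	cfg := &GenerateConfig{
//		Prompt: "a watercolor painting of a lighthouse",
//		Width:  1024,
//		Height: 1024,
//		Steps:  30,
//		Seed:   42,
//		Progress: func(step, total int) {
//			fmt.Printf("step %d/%d\n", step, total)
//		},
//	}
//	// img, err := model.GenerateFromConfig(cfg)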
// Model represents a Qwen-Image diffusion model.
type Model struct {
ModelPath string
Tokenizer *tokenizer.Tokenizer
TextEncoder *Qwen25VL
Transformer *Transformer
VAEDecoder *VAEDecoder
}
// Load loads the Qwen-Image model from a directory.
func (m *Model) Load(modelPath string) error {
fmt.Println("Loading Qwen-Image model...")
start := time.Now()
if mlx.GPUIsAvailable() {
mlx.SetDefaultDeviceGPU()
mlx.EnableCompile()
}
m.ModelPath = modelPath
// Load tokenizer
fmt.Print(" Loading tokenizer... ")
tokenizerPath := filepath.Join(modelPath, "tokenizer")
tok, err := tokenizer.Load(tokenizerPath)
if err != nil {
return fmt.Errorf("tokenizer: %w", err)
}
m.Tokenizer = tok
fmt.Println("✓")
// Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
m.TextEncoder = &Qwen25VL{}
if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
return fmt.Errorf("text encoder: %w", err)
}
mlx.Eval(mlx.Collect(m.TextEncoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load transformer
m.Transformer = &Transformer{}
if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
return fmt.Errorf("transformer: %w", err)
}
mlx.Eval(mlx.Collect(m.Transformer)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
// Load VAE decoder
m.VAEDecoder = &VAEDecoder{}
if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
return fmt.Errorf("VAE decoder: %w", err)
}
mlx.Eval(mlx.Collect(m.VAEDecoder)...)
fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
mem := mlx.MetalGetActiveMemory()
peak := mlx.MetalGetPeakMemory()
fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
time.Since(start).Seconds(),
float64(mem)/(1024*1024*1024),
float64(peak)/(1024*1024*1024))
return nil
}
// Generate creates an image from a prompt.
func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
})
}
// GenerateWithProgress creates an image with progress callback.
func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateWithCFG creates an image with classifier-free guidance.
func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress ProgressFunc) (*mlx.Array, error) {
return m.GenerateFromConfig(&GenerateConfig{
Prompt: prompt,
NegativePrompt: negativePrompt,
CFGScale: cfgScale,
Width: width,
Height: height,
Steps: steps,
Seed: seed,
Progress: progress,
})
}
// GenerateFromConfig generates an image using the unified config struct.
func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
start := time.Now()
result, err := m.generate(cfg)
if err != nil {
return nil, err
}
if cfg.NegativePrompt != "" {
fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
} else {
fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
}
return result, nil
}
// GenerateImage implements model.ImageModel interface.
func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
return m.Generate(prompt, width, height, steps, seed)
}
// generate is the internal denoising pipeline.
func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
// Apply defaults
if cfg.Width <= 0 {
cfg.Width = 1024
}
if cfg.Height <= 0 {
cfg.Height = 1024
}
if cfg.Steps <= 0 {
cfg.Steps = 30
}
if cfg.CFGScale <= 0 {
cfg.CFGScale = 4.0
}
if cfg.CacheInterval <= 0 {
cfg.CacheInterval = 3
}
if cfg.CacheLayers <= 0 {
cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
}
useCFG := cfg.NegativePrompt != ""
tcfg := m.Transformer.Config
latentH := cfg.Height / 8
latentW := cfg.Width / 8
pH := latentH / tcfg.PatchSize
pW := latentW / tcfg.PatchSize
imgSeqLen := pH * pW
// Text encoding
var posEmb, negEmb *mlx.Array
{
posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
if useCFG {
negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
mlx.Keep(posEmb, negEmb)
mlx.Eval(posEmb, negEmb)
} else {
mlx.Keep(posEmb)
mlx.Eval(posEmb)
}
}
// Pad sequences to same length for CFG
txtLen := posEmb.Shape()[1]
if useCFG {
negLen := negEmb.Shape()[1]
if negLen > txtLen {
txtLen = negLen
}
if posEmb.Shape()[1] < txtLen {
posEmb = padSequence(posEmb, txtLen)
}
if negEmb.Shape()[1] < txtLen {
negEmb = padSequence(negEmb, txtLen)
}
mlx.Keep(posEmb, negEmb)
}
// Scheduler
scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
// Init latents [B, C, T, H, W]
var latents *mlx.Array
{
latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
mlx.Eval(latents)
}
// RoPE cache
var ropeCache *RoPECache
{
ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
mlx.Eval(ropeCache.ImgFreqs)
}
// Layer cache for DeepCache/Learning-to-Cache speedup
var stepCache *cache.StepCache
if cfg.LayerCache {
stepCache = cache.NewStepCache(cfg.CacheLayers)
fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
}
// Denoising loop
for i := 0; i < cfg.Steps; i++ {
stepStart := time.Now()
if cfg.Progress != nil {
cfg.Progress(i+1, cfg.Steps)
}
t := scheduler.Timesteps[i]
timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
// Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
latents2D := mlx.Squeeze(latents, 2)
patches := PackLatents(latents2D, tcfg.PatchSize)
var output *mlx.Array
if useCFG {
// True CFG: run twice and combine with norm rescaling
// Note: layer caching with CFG is not supported yet (would need 2 caches)
posOutput := m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
negOutput := m.Transformer.Forward(patches, negEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
diff := mlx.Sub(posOutput, negOutput)
scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
combPred := mlx.Add(negOutput, scaledDiff)
// Norm rescaling: rescale combined prediction to match conditional prediction's norm
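// i.e. out = (n + s*(p - n)) * ||p|| / ||n + s*(p - n)||, with per-token norms
// over the last axis (p = conditional, n = unconditional, s = CFGScale)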
condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
} else if stepCache != nil {
output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
} else {
output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
}
noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
oldLatents := latents
latents = scheduler.Step(noisePred, latents, i)
// Keep cached arrays alive across cleanup
if stepCache != nil {
mlx.Keep(stepCache.Arrays()...)
}
mlx.Eval(latents)
oldLatents.Free()
activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
}
// Free denoising temporaries before VAE decode
posEmb.Free()
if negEmb != nil {
negEmb.Free()
}
ropeCache.ImgFreqs.Free()
ropeCache.TxtFreqs.Free()
if stepCache != nil {
stepCache.Free()
}
// VAE decode (Decode manages its own pools for staged memory)
decoded := m.VAEDecoder.Decode(latents)
latents.Free()
// Post-process: squeeze temporal dim and rescale to [0, 1]
{
decoded = mlx.Squeeze(decoded, 2)
decoded = mlx.AddScalar(decoded, 1.0)
decoded = mlx.DivScalar(decoded, 2.0)
mlx.Eval(decoded)
}
fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
return decoded, nil
}
// padSequence pads a sequence tensor to the target length with zeros
func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
shape := x.Shape()
currentLen := shape[1]
if currentLen >= targetLen {
return x
}
padLen := targetLen - currentLen
// Pad on sequence dimension (axis 1)
return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
}
// LoadPersistent is a convenience wrapper kept for backward compatibility.
// Prefer m := &Model{}; m.Load(path) in new code.
func LoadPersistent(modelPath string) (*Model, error) {
m := &Model{}
if err := m.Load(modelPath); err != nil {
return nil, err
}
return m, nil
}
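// Illustrative end-to-end usage (the model path and prompt are hypothetical):
//
//	m, err := LoadPersistent("/models/qwen-image")
//	if err != nil {
//		log.Fatal(err)
//	}
//	img, err := m.Generate("a foggy harbor at sunrise", 1024, 1024, 30, 0)
//	if err != nil {
//		log.Fatal(err)
//	}
//	_ = img // image tensor rescaled to [0, 1] by generate's post-processing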
//go:build mlx

package qwen_image
import (
"math"
"github.com/ollama/ollama/x/imagegen/mlx"
)
// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
type SchedulerConfig struct {
NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
BaseShift float32 `json:"base_shift"` // 0.5
MaxShift float32 `json:"max_shift"` // 0.9
BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
ShiftTerminal float32 `json:"shift_terminal"` // 0.02
UseDynamicShift bool `json:"use_dynamic_shifting"` // true
}
// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
func DefaultSchedulerConfig() *SchedulerConfig {
return &SchedulerConfig{
NumTrainTimesteps: 1000,
BaseShift: 0.5,
MaxShift: 0.9, // Matches scheduler_config.json
BaseImageSeqLen: 256,
MaxImageSeqLen: 8192,
ShiftTerminal: 0.02,
UseDynamicShift: true,
}
}
// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
type FlowMatchScheduler struct {
Config *SchedulerConfig
Timesteps []float32
Sigmas []float32
NumSteps int
}
// NewFlowMatchScheduler creates a new scheduler
func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
return &FlowMatchScheduler{
Config: cfg,
}
}
// CalculateShift computes the dynamic shift based on image sequence length
// This matches Python's calculate_shift function
func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
b := baseShift - m*float32(baseSeqLen)
mu := float32(imageSeqLen)*m + b
return mu
}
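// Illustrative arithmetic with the default config: a 1024x1024 image has a
// 128x128 latent grid and a 64x64 patch grid, so imageSeqLen = 4096. Then
// m = (0.9-0.5)/(8192-256) ≈ 5.04e-5, b = 0.5 - m*256 ≈ 0.487, and
// mu ≈ 4096*5.04e-5 + 0.487 ≈ 0.69.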
// SetTimesteps sets up the scheduler for the given number of inference steps
// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
// 1. Create sigmas from sigma_max to sigma_min (linspace)
// 2. Apply time_shift with mu (if dynamic shifting)
// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
s.NumSteps = numSteps
// Calculate mu for dynamic shifting
var mu float32
if s.Config.UseDynamicShift {
mu = CalculateShift(
imageSeqLen,
s.Config.BaseImageSeqLen,
s.Config.MaxImageSeqLen,
s.Config.BaseShift,
s.Config.MaxShift,
)
}
// Step 1: Create sigmas from 1.0 to 1/num_steps
// Python (pipeline_qwenimage.py:639):
// sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
// This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
sigmas := make([]float32, numSteps)
sigmaMax := float32(1.0)
sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
if numSteps == 1 {
sigmas[0] = sigmaMax
} else {
for i := 0; i < numSteps; i++ {
sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
}
}
// Step 2: Apply time shift if using dynamic shifting
if s.Config.UseDynamicShift && mu != 0 {
for i := range sigmas {
sigmas[i] = s.timeShift(mu, sigmas[i])
}
}
// Step 3: Apply stretch_shift_to_terminal
if s.Config.ShiftTerminal > 0 {
sigmas = s.stretchShiftToTerminal(sigmas)
}
// Step 4: Append terminal sigma (0) and store
// Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
// before passing to transformer. We skip both steps and just use sigmas directly.
s.Sigmas = make([]float32, numSteps+1)
s.Timesteps = make([]float32, numSteps+1)
for i := 0; i < numSteps; i++ {
s.Sigmas[i] = sigmas[i]
s.Timesteps[i] = sigmas[i]
}
s.Sigmas[numSteps] = 0.0
s.Timesteps[numSteps] = 0.0
}
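// Illustrative schedule (defaults, 30 steps): the raw linspace runs
// 1.000, 0.967, ..., 0.033; with mu ≈ 0.69 (see CalculateShift above) the
// exponential shift pushes intermediate values toward 1 (e.g. 0.5 -> ~0.67),
// and the terminal stretch rescales the sequence so its last entry lands at
// ShiftTerminal = 0.02 before the trailing 0 sigma is appended.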
// stretchShiftToTerminal stretches and shifts the timestep schedule
// so the final value equals shift_terminal (matches Python behavior)
func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
if len(sigmas) == 0 {
return sigmas
}
// one_minus_z = 1 - t
// scale_factor = one_minus_z[-1] / (1 - shift_terminal)
// stretched_t = 1 - (one_minus_z / scale_factor)
lastSigma := sigmas[len(sigmas)-1]
scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
// Handle edge case: if scaleFactor is 0 or near 0, skip stretch
// This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
if scaleFactor < 1e-6 {
return sigmas
}
result := make([]float32, len(sigmas))
for i, t := range sigmas {
oneMinusZ := 1.0 - t
result[i] = 1.0 - (oneMinusZ / scaleFactor)
}
return result
}
// timeShift applies the dynamic time shift (exponential)
// exp(mu) / (exp(mu) + (1/t - 1))
func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
if t <= 0 {
return 0
}
expMu := float32(math.Exp(float64(mu)))
return expMu / (expMu + (1.0/t - 1.0))
}
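// Worked example: with mu ≈ 0.69, exp(mu) ≈ 2.0, so
// timeShift(0.69, 0.5) = 2.0 / (2.0 + (1/0.5 - 1)) = 2.0 / 3.0 ≈ 0.67.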
// Step performs one denoising step
// modelOutput: predicted velocity from the transformer
// sample: current noisy sample
// timestepIdx: current timestep index
func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
// Get current and next sigma
sigma := s.Sigmas[timestepIdx]
sigmaNext := s.Sigmas[timestepIdx+1]
// Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
dt := sigmaNext - sigma
// Upcast to float32 to avoid precision issues (matches Python diffusers)
sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
scaledOutput := mlx.MulScalar(modelOutputF32, dt)
result := mlx.Add(sampleF32, scaledOutput)
// Cast back to original dtype
return mlx.ToBFloat16(result)
}
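// Worked example (illustrative sigmas): with sigma = 0.67 at this step and
// sigmaNext = 0.63 at the next, dt = -0.04 and the update is
// x_next = x + (-0.04) * v, a small Euler step along the predicted velocity.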
// GetTimestep returns the timestep value at the given index
func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
if idx < len(s.Timesteps) {
return s.Timesteps[idx]
}
return 0.0
}
// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
return mlx.RandomNormal(shape, uint64(seed))
}
// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
// This matches how Python diffusers generates noise - directly in packed space.
// Generating in unpacked format and then packing produces different spatial
// correlation structure, which affects model output quality.
func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
shape := []int32{batchSize, seqLen, channels}
return mlx.RandomNormal(shape, uint64(seed))
}
// GetLatentShape returns the latent shape for a given image size
// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
func GetLatentShape(batchSize, height, width int32) []int32 {
latentH := height / 8
latentW := width / 8
return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
}
// GetPatchedLatentShape returns the patchified latent shape
// After patchification: [B, L, C*patch_size^2] where L = (latentH/patch_size) * (latentW/patch_size)
func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
latentH := height / 8
latentW := width / 8
pH := latentH / patchSize
pW := latentW / patchSize
inChannels := 16 * patchSize * patchSize // 64 for patch_size=2
return []int32{batchSize, pH * pW, inChannels}
}
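// Worked example: for a 1024x1024 image, GetLatentShape(1, 1024, 1024) returns
// [1, 16, 1, 128, 128], and GetPatchedLatentShape(1, 1024, 1024, 2) returns
// [1, 64*64, 64] = [1, 4096, 64].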