compile.go 4.74 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
//go:build mlx

package mlx

/*
#include "mlx/c/mlx.h"
#include <stdlib.h>

// Forward declaration for Go callback
extern int goClosureCallback(mlx_vector_array* res, mlx_vector_array input, void* payload);

// Destructor for payload (Go handle)
extern void goClosureDestructor(void* payload);
*/
import "C"
import (
	"runtime/cgo"
	"sync"
	"unsafe"
)

// Closure-callback tracking state. closureCallbackMu guards every read and
// write of inClosureCallback.
var (
	// inClosureCallback is true while goClosureCallback is executing.
	inClosureCallback bool
	closureCallbackMu sync.Mutex
)

// InClosureCallback reports whether a closure callback is currently
// executing. The flag is package-wide, not per goroutine, and is read under
// closureCallbackMu.
func InClosureCallback() bool {
	closureCallbackMu.Lock()
	active := inClosureCallback
	closureCallbackMu.Unlock()
	return active
}

// CompiledFunc is a compiled MLX function that can be called efficiently.
// All intermediate arrays during execution stay inside MLX - only inputs
// and outputs cross the Go boundary.
//
// A CompiledFunc holds two C-side closures; call Free when it is no longer
// needed so both are released.
type CompiledFunc struct {
	// closure is the uncompiled closure wrapping the Go callback and its
	// cgo.Handle payload (created in CompileShapeless).
	closure C.mlx_closure
	// compiled is the result of running mlx_compile on closure; this is the
	// closure that Call applies.
	compiled C.mlx_closure
}

// ClosureFunc is the signature for functions that can be compiled.
// It takes a slice of input arrays and returns a slice of output arrays.
//
// When invoked through goClosureCallback, the inputs are wrappers created by
// borrowArray (no GC cleanup attached), and every returned array is appended
// to the C result vector, so the returned slice must not contain nil entries.
type ClosureFunc func(inputs []*Array) []*Array

// Compile turns a Go function into an optimized MLX closure with shapeless
// mode disabled; it is shorthand for CompileShapeless(fn, false).
//
// The function is traced once during compilation, then subsequent calls
// run the optimized graph without creating Go intermediate arrays.
// The returned value must be released with Free.
//
// Example:
//
//	compiled := mlx.Compile(func(inputs []*mlx.Array) []*mlx.Array {
//	    a, b := inputs[0], inputs[1]
//	    c := mlx.Add(a, b)
//	    d := mlx.Mul(c, c)
//	    return []*mlx.Array{d}
//	})
//	defer compiled.Free()
//
//	result := compiled.Call(x, y)[0]
func Compile(fn ClosureFunc) *CompiledFunc {
	const shapeless = false
	return CompileShapeless(fn, shapeless)
}

// CompileShapeless compiles fn into an MLX closure, optionally in shapeless
// mode. If shapeless=true, the function works for any input shape after
// tracing.
//
// The returned CompiledFunc keeps both the raw closure and its compiled form;
// release them with Free.
func CompileShapeless(fn ClosureFunc, shapeless bool) *CompiledFunc {
	// Registering fn as a cgo.Handle keeps it reachable while C holds the
	// payload; goClosureDestructor deletes the handle again.
	// NOTE(review): the runtime/cgo docs advise against converting a Handle
	// to unsafe.Pointer directly - confirm this is acceptable here.
	payload := cgo.NewHandle(fn)

	result := new(CompiledFunc)

	// Wrap the exported Go callback, its payload and the payload destructor
	// in an MLX closure.
	result.closure = C.mlx_closure_new_func_payload(
		(*[0]byte)(C.goClosureCallback),
		unsafe.Pointer(payload),
		(*[0]byte)(C.goClosureDestructor),
	)

	// mlx_compile writes the compiled closure into its first argument.
	optimized := C.mlx_closure_new()
	C.mlx_compile(&optimized, result.closure, C.bool(shapeless))
	result.compiled = optimized

	return result
}

// Call applies the compiled closure to inputs and returns its outputs as
// newly wrapped arrays (one newArray per element of the result vector).
//
// Every element of inputs must be non-nil. Any status reported by
// mlx_closure_apply is not inspected here.
func (cf *CompiledFunc) Call(inputs ...*Array) []*Array {
	// Marshal the Go inputs into a temporary C vector.
	args := C.mlx_vector_array_new()
	for i := range inputs {
		C.mlx_vector_array_append_value(args, inputs[i].c)
	}

	// Run the compiled closure; the argument vector is no longer needed
	// once the call returns.
	results := C.mlx_vector_array_new()
	C.mlx_closure_apply(&results, cf.compiled, args)
	C.mlx_vector_array_free(args)

	// Wrap each element of the result vector in a Go Array.
	outputs := make([]*Array, int(C.mlx_vector_array_size(results)))
	for i := range outputs {
		var elem C.mlx_array
		C.mlx_vector_array_get(&elem, results, C.size_t(i))
		outputs[i] = newArray(elem)
	}
	C.mlx_vector_array_free(results)

	return outputs
}

// CallEval runs the compiled function via Call, passes all resulting arrays
// to Eval, and returns them.
func (cf *CompiledFunc) CallEval(inputs ...*Array) []*Array {
	results := cf.Call(inputs...)
	Eval(results...)
	return results
}

// Free releases the C-side resources held by cf: the compiled closure first,
// then the raw closure it was built from. cf must not be used afterwards.
func (cf *CompiledFunc) Free() {
	for _, closure := range [...]C.mlx_closure{cf.compiled, cf.closure} {
		C.mlx_closure_free(closure)
	}
}

// borrowArray returns an Array that merely refers to the given C array; no
// GC cleanup is registered for it. Intended for arrays this package does not
// own (e.g., borrowed references in callbacks).
func borrowArray(array C.mlx_array) *Array {
	borrowed := new(Array)
	borrowed.c = array
	return borrowed
}

// goClosureCallback is the C-callable trampoline registered with
// mlx_closure_new_func_payload. It looks up the ClosureFunc stored behind the
// cgo.Handle in payload, calls it with the arrays from input, and stores the
// returned arrays in a fresh vector written to *res. It always returns 0.
//
//export goClosureCallback
func goClosureCallback(res *C.mlx_vector_array, input C.mlx_vector_array, payload unsafe.Pointer) C.int {
	// Raise the package-wide flag for the duration of the callback so that
	// AddCleanup is disabled while it runs, and lower it again on return.
	markActive := func(active bool) {
		closureCallbackMu.Lock()
		inClosureCallback = active
		closureCallbackMu.Unlock()
	}
	markActive(true)
	defer markActive(false)

	// The payload is the cgo.Handle that carries the user's Go function.
	fn := cgo.Handle(payload).Value().(ClosureFunc)

	// Expose the C input vector as a Go slice. borrowArray is used because
	// MLX owns these arrays, so no cleanup must be attached to them.
	args := make([]*Array, int(C.mlx_vector_array_size(input)))
	for i := range args {
		var elem C.mlx_array
		C.mlx_vector_array_get(&elem, input, C.size_t(i))
		args[i] = borrowArray(elem)
	}

	// Run the user function.
	results := fn(args)

	// Publish the results through a new C vector.
	*res = C.mlx_vector_array_new()
	for i := range results {
		C.mlx_vector_array_append_value(*res, results[i].c)
	}

	return 0
}

// goClosureDestructor is the payload destructor registered with
// mlx_closure_new_func_payload. It deletes the cgo.Handle carried in payload,
// which drops the handle's reference to the Go function.
//
//export goClosureDestructor
func goClosureDestructor(payload unsafe.Pointer) {
	cgo.Handle(payload).Delete()
}