loader.go 4.51 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
//go:build mlx

package safetensors

import (
	"fmt"
	"reflect"
	"strings"

	"github.com/ollama/ollama/x/imagegen/mlx"
)

// LoadModule populates the struct pointed to by dst with tensors from weights,
// driven by reflection over `weight` struct tags.
//
// A tag has the form `weight:"path[,optional]"`:
//   - path is appended to prefix (dot-separated) to form the weight name
//   - optional suppresses the error when the weight is missing
//   - "-" excludes the field from loading
//
// Untagged fields are handled by type:
//   - a struct pointer is loaded recursively under the current prefix
//   - a *mlx.Array is skipped, since it is taken to be a computed field
//
// A slice of struct pointers is loaded element by element, with .0, .1, .2...
// appended to its path. The caller must pre-allocate the slice to the correct
// length.
//
// dst must be a non-nil pointer to a struct; anything else yields an error.
// Every required weight that could not be loaded is collected and reported in
// a single error listing all of the missing paths.
//
// Example:
//
//	type Attention struct {
//	    QProj *nn.Linear  `weight:"self_attn.q_proj"`
//	    KProj *nn.Linear  `weight:"self_attn.k_proj"`
//	    Cache *mlx.Array  // no tag = skipped (computed field)
//	}
//
//	err := LoadModule(&attn, weights, "model.layers.0")
func LoadModule(dst any, weights *ModelWeights, prefix string) error {
	target := reflect.ValueOf(dst)
	// Kind is checked first so IsNil is only ever called on a pointer value.
	if usable := target.Kind() == reflect.Ptr && !target.IsNil(); !usable {
		return fmt.Errorf("LoadModule: dst must be a non-nil pointer")
	}

	structVal := target.Elem()
	if kind := structVal.Kind(); kind != reflect.Struct {
		return fmt.Errorf("LoadModule: dst must be a pointer to struct, got %v", kind)
	}

	// loadStruct records the full path of each missing required weight here.
	var missing []string
	loadStruct(structVal, weights, prefix, &missing, false)

	if len(missing) == 0 {
		return nil
	}
	return fmt.Errorf("LoadModule: missing weights:\n  %s", strings.Join(missing, "\n  "))
}

// loadStruct walks the settable fields of struct value v and fills them from
// weights, appending the full path of every required weight that could not be
// loaded to *errs.
//
// prefix is the dotted weight path of v itself; each field's tag path is
// joined onto it. parentOptional is inherited by every field, so a missing
// weight below an optional struct is not reported.
//
// Field handling:
//   - unexported (non-settable) fields and fields tagged "-" are skipped
//   - a tagged *mlx.Array is loaded via weights.GetTensor; an untagged one is
//     treated as a computed field and left untouched
//   - an untagged pointer to struct is allocated if nil and loaded under the
//     current prefix
//   - a tagged pointer to struct is allocated if nil and loaded under the
//     joined path; when it is optional and no weight exists under that path
//     the field is left as it is
//   - a slice of struct pointers is handed to loadSlice; a slice of
//     *mlx.Array is left untouched
//   - every other kind is ignored
//
// NOTE: the optional flag is not forwarded to slice elements, because
// loadSlice takes no such argument.
func loadStruct(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string, parentOptional bool) {
	// Loop-invariant type handles, resolved once per call.
	arrayPtrType := reflect.TypeOf((*mlx.Array)(nil))
	arrayType := arrayPtrType.Elem()

	t := v.Type()
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldVal := v.Field(i)

		// Skip unexported fields
		if !fieldVal.CanSet() {
			continue
		}

		tag, hasTag := field.Tag.Lookup("weight")
		if tag == "-" {
			continue
		}

		// Split "path[,opt...]". Options are compared exactly so that only the
		// documented "optional" keyword enables optional loading.
		weightPath, opts, _ := strings.Cut(tag, ",")
		optional := parentOptional
		for _, opt := range strings.Split(opts, ",") {
			if strings.TrimSpace(opt) == "optional" {
				optional = true
			}
		}

		// With no tag weightPath is empty, so fullPath equals prefix.
		fullPath := joinPath(prefix, weightPath)

		switch fieldVal.Kind() {
		case reflect.Ptr:
			elemType := fieldVal.Type().Elem()

			switch {
			case elemType == arrayType:
				if !hasTag {
					continue // no tag on *mlx.Array = computed field, skip
				}
				arr, err := weights.GetTensor(fullPath)
				if err != nil {
					if !optional {
						*errs = append(*errs, fullPath)
					}
					continue
				}
				fieldVal.Set(reflect.ValueOf(arr))

			case elemType.Kind() != reflect.Struct:
				// Pointers to non-struct types carry no weights.

			case !hasTag:
				// Untagged struct pointer: always allocate and recurse with
				// the current prefix.
				if fieldVal.IsNil() {
					fieldVal.Set(reflect.New(elemType))
				}
				loadStruct(fieldVal.Elem(), weights, prefix, errs, optional)

			default:
				// Tagged struct pointer: leave an optional sub-module alone
				// when nothing exists under its path.
				if optional && !hasWeightsWithPrefix(weights, fullPath) {
					continue
				}
				if fieldVal.IsNil() {
					fieldVal.Set(reflect.New(elemType))
				}
				loadStruct(fieldVal.Elem(), weights, fullPath, errs, optional)
			}

		case reflect.Slice:
			elemType := fieldVal.Type().Elem()
			// *mlx.Array is itself a pointer to a struct; exclude it so that
			// tensor slices are not mistaken for slices of sub-modules.
			if elemType.Kind() == reflect.Ptr && elemType != arrayPtrType && elemType.Elem().Kind() == reflect.Struct {
				loadSlice(fieldVal, weights, fullPath, errs)
			}
		}
	}
}

// hasWeightsWithPrefix reports whether weights holds at least one tensor
// located at or below prefix: a tensor named exactly prefix, or one whose name
// starts with prefix followed by a dot. Matching on whole dot-separated
// segments keeps "layers.1" from matching "layers.10.weight".
//
// An empty prefix denotes the root path, so it matches any tensor.
func hasWeightsWithPrefix(weights *ModelWeights, prefix string) bool {
	dotted := prefix + "." // invariant across the loop
	for _, name := range weights.ListTensors() {
		if prefix == "" || name == prefix || strings.HasPrefix(name, dotted) {
			return true
		}
	}
	return false
}

// loadSlice loads weights into each element of v, a slice of struct pointers.
//
// Element i is loaded under prefix joined with i (for example "layers.0").
// The path is built with joinPath, so an empty prefix yields "0" rather than
// ".0". Nil elements are replaced with a newly allocated struct before
// loading. The slice length is never changed, so v must already have the
// intended length.
//
// Elements are always loaded as required: loadStruct is called with
// parentOptional set to false, so missing weights are appended to *errs.
func loadSlice(v reflect.Value, weights *ModelWeights, prefix string, errs *[]string) {
	elemStructType := v.Type().Elem().Elem()

	for i := 0; i < v.Len(); i++ {
		elem := v.Index(i)
		if elem.IsNil() {
			elem.Set(reflect.New(elemStructType))
		}
		loadStruct(elem.Elem(), weights, joinPath(prefix, fmt.Sprint(i)), errs, false)
	}
}

// joinPath concatenates prefix and suffix with a single dot between them.
// When either side is empty the other is returned unchanged, so no leading or
// trailing dot is ever produced; two empty inputs give the empty string.
func joinPath(prefix, suffix string) string {
	switch {
	case prefix == "":
		return suffix
	case suffix == "":
		return prefix
	default:
		return prefix + "." + suffix
	}
}