backend.go 12.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
package ml

import (
	"fmt"
	"log/slog"
	"os"

	"github.com/ollama/ollama/fs"
)

// Backend is implemented by each compute runtime (selected by name via
// OLLAMA_BACKEND in NewBackend, e.g. "mlx") and provides model
// configuration, weight lookup, and compute-context creation.
// Commented-out methods appear to be planned or deferred API surface.
type Backend interface {
	// Close frees all memory associated with this backend
	// Close()

	// Load(ctx context.Context, progress func(float32)) error

	// BackendMemory returns the memory allocations that were made for this model
	// BackendMemory() BackendMemory

	// Config returns the model configuration.
	Config() fs.Config
	// Get looks up a model weight tensor by name. Behavior for a missing
	// name is backend-defined — presumably nil; confirm per backend.
	Get(name string) Tensor
	// NewContext creates a fresh context for building compute graphs.
	NewContext() Context
	// NewContextSize(size int) Context

	// Enumerate the devices available for inference via this backend
	// BackendDevices() []DeviceInfo
}

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
	// CacheConfig reports the cache layout/padding requirements of this backend.
	CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
	// CachePadding specifies the multiple for the number of tokens of cache history
	// that will be returned from cache Get for k, v and mask. The capacity of the
	// cache itself will also be increased to a multiple of this size if needed.
	CachePadding int

	// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
	// and returns the permuted version via Get. This uses the cache copy operation
	// to avoid a Contiguous call on the permuted tensor.
	PermutedV bool

	// MaskDType specifies the data type for generating the mask. If unset it will
	// default to DTypeF32.
	MaskDType DType

	// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
	// Any position that does not correspond to an actual token will be filled with -Inf.
	MaskBatchPadding int
}

// BackendParams controls how the backend loads and executes models.
// It is passed to the constructor registered via RegisterBackend.
type BackendParams struct {
	// AllocMemory causes the backend to allocate memory for the model. If
	// false, this is only being used for discovering the required amount of
	// memory and cannot load the model for running.
	AllocMemory bool

	// NumThreads sets the number of threads to use if running on the CPU
	NumThreads int

	// GPULayers is the set of layers to offload to GPUs
	GPULayers GPULayersList

	// FlashAttention indicates that we should use a fused flash attention kernel
	FlashAttention bool
}

// backends maps a backend name (e.g. "mlx") to its constructor. Entries are
// added via RegisterBackend and looked up by NewBackend.
var backends = make(map[string]func(string, BackendParams) (Backend, error))

// RegisterBackend registers the constructor for a named backend.
// It panics on duplicate registration, which indicates a programmer
// error (e.g. two packages claiming the same name).
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
	if _, ok := backends[name]; ok {
		// Include the name so the offending registration is identifiable.
		panic("backend: backend already registered: " + name)
	}

	backends[name] = f
}

// NewBackend constructs the backend selected by the OLLAMA_BACKEND
// environment variable (defaulting to "mlx" when unset) and asks it to
// load the model at modelPath.
//
// It returns an error naming the requested backend if no constructor
// has been registered under that name.
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
	be := os.Getenv("OLLAMA_BACKEND")
	if be == "" {
		be = "mlx"
		slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
	}
	slog.Info("Loading new engine", "backend", be)
	backend, ok := backends[be]
	if !ok {
		// %q surfaces empty or whitespace-only misconfiguration clearly.
		return nil, fmt.Errorf("unsupported backend %q", be)
	}

	return backend(modelPath, params)
}

// Context builds and executes compute graphs of Tensor operations for a
// single backend. Tensors are created through the Context and evaluated
// via Forward/Compute. Commented-out methods appear to be planned or
// deferred API surface.
type Context interface {
	// Empty allocates an uninitialized tensor of the given dtype and shape.
	Empty(dtype DType, shape ...int) Tensor
	// Zeros allocates a zero-filled tensor of the given dtype and shape.
	Zeros(dtype DType, shape ...int) Tensor
	// FromBytes(dtype DType, s []byte, shape ...int) Tensor
	// FromFloats creates a tensor with the given shape from float32 data.
	FromFloats(s []float32, shape ...int) Tensor
	// FromInts creates a tensor with the given shape from int32 data.
	FromInts(s []int32, shape ...int) Tensor
	// RandomNormal samples a tensor from a normal distribution with mean
	// loc and standard deviation scale; key presumably seeds the RNG — confirm.
	RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor

	// Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
	// NOTE(review): (start, stop] as documented is unusual — typical arange
	// semantics are [start, stop); confirm against the backend implementation.
	Arange(start, stop, step float32, dtype DType) Tensor

	// Forward records the given tensors as graph outputs to be evaluated.
	Forward(...Tensor) Context

	// SetBatchSize provides a hint on the batch size to optimize processing
	// Uses heuristics if not set
	// SetBatchSize(int)

	// Compute executes the recorded graph, materializing the given tensors.
	Compute(...Tensor)
	// ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun

	// Reserve is analogous to Compute but rather than executing a
	// graph, simply preallocates memory. Typically called with a
	// worst case graph to ensure all resources are available for
	// future inference.
	// Reserve()

	// MaxGraphNodes() int
	// Close releases resources held by this context.
	Close()

	// Input returns a context appropriate for creating tensors that are
	// inputs to the model (which includes things like output locations)
	Input() Context

	// Layer returns a context appropriate for creating intermediate tensors
	Layer(int) Context

	// Load a tensor from "filename" safetensors file, and compare with the input tensor
	// Returns error if the shape is inconsistent, or similarity measures are below 99%
	CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
}

// RoPEOptions holds optional parameters for Tensor.RoPE, populated via
// the WithRoPE* functional options below.
type RoPEOptions struct {
	// Base overrides the rotary frequency base; nil presumably leaves the
	// backend default — confirm against the backend implementation.
	Base  *float32
	// Freqs optionally supplies precomputed rotation frequencies.
	Freqs Tensor
}

// WithRoPEBase returns a RoPE option that overrides the frequency base.
func WithRoPEBase(base float32) func(*RoPEOptions) {
	// Copy into a local so the returned closure points at stable storage.
	b := base
	return func(o *RoPEOptions) {
		o.Base = &b
	}
}

// WithRoPEFreqs returns a RoPE option that supplies precomputed rotation
// frequencies instead of having the backend derive them from the base.
func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
	return func(o *RoPEOptions) { o.Freqs = freqs }
}

// Tensor is an n-dimensional array handle produced and consumed by a
// Context. Operations record graph nodes; values are materialized by
// Context.Compute. Commented-out methods appear to be planned or
// deferred API surface retained for reference.
type Tensor interface {
	// ToString renders the tensor for debugging.
	ToString() string
	// RoPE applies rotary position embedding; see RoPEOptions for extras.
	RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
	// ScaledDotProductAttention is the fused attention op; see the
	// commented reference implementation below this interface.
	ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
	TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
	// TakeAxes(ctx Context, axes int, indicies ...int) Tensor

	// Shape/stride introspection.
	Dim(n int) int
	Stride(n int) int

	Shape() []int
	DType() DType
	// Cast(ctx Context, dtype DType) Tensor

	// Host-side data extraction.
	// Bytes() []byte
	Floats() []float32
	Ints() []int32

	// FromBytes([]byte)
	// FromFloats([]float32)
	// FromInts([]int32)

	// Elementwise arithmetic.
	Add(ctx Context, t2 Tensor) Tensor
	Sub(ctx Context, t2 Tensor) Tensor
	// Mul(ctx Context, t2 Tensor) Tensor
	// Div(ctx Context, t2 Tensor) Tensor

	// Reductions over the given axes; keepDims retains reduced dimensions as size 1.
	Max(ctx Context, axes []int, keepDims bool) Tensor
	Min(ctx Context, axes []int, keepDims bool) Tensor

	Matmul(ctx Context, a2 Tensor) Tensor
	// Mulmat(ctx Context, t2 Tensor) Tensor
	// MulmatFullPrec(ctx Context, t2 Tensor) Tensor
	// MulmatID(ctx Context, t2, ids Tensor) Tensor
	// AddID(ctx Context, t2, ids Tensor) Tensor

	// Normalization and scaling.
	Softmax(ctx Context) Tensor
	L2Norm(ctx Context, eps float32) Tensor
	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
	Scale(ctx Context, s float64) Tensor
	// SumRows(ctx Context) Tensor

	// Convolution/pooling; stride/padding/dilation are given per spatial dimension.
	AvgPool2D(ctx Context, k, s int, p float32) Tensor
	Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
	Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor

	// IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

	// Sin(ctx Context) Tensor
	// Cos(ctx Context) Tensor
	// Tanh(ctx Context) Tensor
	// Activations; the optional up tensor presumably fuses a gated
	// (e.g. GLU-style) multiply — confirm against the backend.
	GELU(ctx Context, up ...Tensor) Tensor
	// QuickGELU(ctx Context, up ...Tensor) Tensor
	// SILU(ctx Context, up ...Tensor) Tensor
	// RELU(ctx Context, up ...Tensor) Tensor
	// Sigmoid(ctx Context) Tensor

	// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
	// SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor

	// Layout transforms.
	Reshape(ctx Context, shape ...int) Tensor
	AsStrided(ctx Context, shape, strides []int, offset int) Tensor
	Transpose(ctx Context, shape ...int) Tensor
	Contiguous(ctx Context, allowColMajor bool) Tensor

	// Pad(ctx Context, shape ...int) Tensor

	// Stack(ctx Context, dim int, s ...Tensor) Tensor

	// Repeat repeats the tensor n times along dimension dim
	// Repeat(ctx Context, dim, n int) Tensor
	// Concat(ctx Context, t2 Tensor, dim int) Tensor
	// Rows(ctx Context, t2 Tensor) Tensor

	// TODO these probably aren't actually needed - false starts on trying to wire up cache
	// SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
	// SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
	// PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor

	Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor

	Copy(ctx Context, t2 Tensor) Tensor
	// Duplicate(ctx Context) Tensor

	// Slice(ctx Context, dim, low, high, step int) Tensor
	// Chunk(ctx Context, dim int, size int) []Tensor
	// ChunkSections(ctx Context, dim int, sections ...int) []Tensor

	// TopK(ctx Context, k int) Tensor
	// Argsort(ctx Context) Tensor
	// Mean(ctx Context) Tensor
	// Variance(ctx Context) Tensor
	// Stddev(ctx Context) Tensor
	// Sqr(ctx Context) Tensor
	// Sqrt(ctx Context) Tensor

	// Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}

// ScaledDotProductAttention implements a fused attention
// operation equivalent to the following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
//
//	if mask != nil {
//		kq = kq.Add(ctx, mask)
//	}
//
// kq = kq.Softmax(ctx)
//
// kqv := value.Mulmat(ctx, kq)
// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
// type ScaledDotProductAttention interface {
// 	ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
// }

// type number interface {
// 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
// 		~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
// 		~float32 | ~float64 |
// 		~complex64 | ~complex128
// }

// func mul[T number](s ...T) T {
// 	p := T(1)
// 	for _, v := range s {
// 		p *= v
// 	}

// 	return p
// }

// type DumpOptions func(*dumpOptions)

// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
// func DumpWithPrecision(n int) DumpOptions {
// 	return func(opts *dumpOptions) {
// 		opts.Precision = n
// 	}
// }

// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
// // beginning and end of each dimension will be printed.
// func DumpWithThreshold(n int) DumpOptions {
// 	return func(opts *dumpOptions) {
// 		opts.Threshold = n
// 	}
// }

// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
// func DumpWithEdgeItems(n int) DumpOptions {
// 	return func(opts *dumpOptions) {
// 		opts.EdgeItems = n
// 	}
// }

// type dumpOptions struct {
// 	Precision, Threshold, EdgeItems int
// }

// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
// 	opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
// 	for _, optsFunc := range optsFuncs {
// 		optsFunc(&opts)
// 	}

// 	if mul(t.Shape()...) <= opts.Threshold {
// 		opts.EdgeItems = math.MaxInt
// 	}

// 	switch t.DType() {
// 	case DTypeFloat32:
// 		return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
// 			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// 		})
// 	case DTypeFloat16: // TODO other types...
// 		f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
// 		f32 = t.Copy(ctx, f32)
// 		return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
// 			return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
// 		})
// 	case DTypeInt32:
// 		return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
// 			return strconv.FormatInt(int64(i), 10)
// 		})
// 	default:
// 		return "<unsupported>"
// 	}
// }

// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
// 	if t.Bytes() == nil {
// 		ctx.Compute(t)
// 	}

// 	s := make(S, mul(t.Shape()...))
// 	if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
// 		panic(err)
// 	}

// 	shape := t.Shape()
// 	slices.Reverse(shape)

// 	var sb strings.Builder
// 	var f func([]int, int)
// 	f = func(dims []int, stride int) {
// 		prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
// 		sb.WriteString("[")
// 		defer func() { sb.WriteString("]") }()
// 		for i := 0; i < dims[0]; i++ {
// 			if i >= items && i < dims[0]-items {
// 				sb.WriteString("..., ")
// 				// skip to next printable element
// 				skip := dims[0] - 2*items
// 				if len(dims) > 1 {
// 					stride += mul(append(dims[1:], skip)...)
// 					fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
// 				}
// 				i += skip - 1
// 			} else if len(dims) > 1 {
// 				f(dims[1:], stride)
// 				stride += mul(dims[1:]...)
// 				if i < dims[0]-1 {
// 					fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
// 				}
// 			} else {
// 				text := fn(s[stride+i])
// 				if len(text) > 0 && text[0] != '-' {
// 					sb.WriteString(" ")
// 				}

// 				sb.WriteString(text)
// 				if i < dims[0]-1 {
// 					sb.WriteString(", ")
// 				}
// 			}
// 		}
// 	}
// 	f(shape, 0)

// 	return sb.String()
// }

// DType identifies the element type of a Tensor.
type DType int

// Supported tensor element types.
const (
	DTypeBool DType = iota
	DTypeUint8
	DTypeUint16
	DTypeUint32
	DTypeUint64
	DTypeInt8
	DTypeInt16
	DTypeInt32
	DTypeInt64
	DTypeFloat16
	DTypeFloat32
	DTypeFloat64
	DTypeBfloat16
	DTypeComplex64
)

// SamplingMode selects the interpolation method for resampling
// operations (see the commented-out Tensor.Interpolate above).
type SamplingMode int

const (
	SamplingModeNearest SamplingMode = iota
	SamplingModeBilinear
)