lazy.go 1.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
package gguf

import (
	"encoding/binary"
	"iter"
	"log/slog"
)

// lazy is a count-prefixed sequence of T whose elements are read on demand.
//
// newLazy reads the element count up front and wraps the per-element reader
// in an iter.Pull iterator; every element it reads is appended to values so
// that later iterations can replay it without reading again. Nothing here
// synchronizes access, so concurrent use from several goroutines is not
// supported by this type itself.
type lazy[T any] struct {
	// count is the number of elements declared by the little-endian uint64
	// prefix that newLazy reads from the file.
	count  uint64
	// next pulls the following element; its bool result is false once the
	// sequence has ended (all elements read, stopped, or a read failed).
	next   func() (T, bool)
	// stop terminates the pull iterator early.
	stop   func()
	// values caches every element read so far, in read order.
	values []T

	// successFunc is called when all values have been successfully read.
	successFunc func() error
}

// newLazy reads a little-endian uint64 element count from f.reader and
// returns a lazy that produces up to that many elements by calling fn once
// per element, driven through an iter.Pull iterator.
//
// Elements are read only when they are pulled with next, and each one is
// appended to values before it is handed out. If fn fails, the error is
// logged and the sequence ends early. successFunc, if set, runs only after
// all count elements have been read: either when next is called past the
// last element, or when stop is called right after it. An error returned by
// successFunc is logged.
//
// The only error newLazy itself returns is a failure to read the count.
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
	it := lazy[T]{}
	if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
		return nil, err
	}

	// Deliberately not pre-sized from count: the count comes straight from
	// the file and may be corrupt, so let the slice grow as values arrive.
	it.values = make([]T, 0)
	it.next, it.stop = iter.Pull(func(yield func(T) bool) {
		for i := range it.count {
			t, err := fn()
			if err != nil {
				slog.Error("error reading tensor", "index", i, "error", err)
				return
			}

			it.values = append(it.values, t)
			if !yield(t) {
				// stop was called. Fall through so that a stop issued right
				// after the final element still finalizes below.
				break
			}
		}

		// Honor the documented contract: finalize only when every declared
		// element was read, not when the consumer stopped part-way through.
		if it.successFunc == nil || uint64(len(it.values)) != it.count {
			return
		}
		if err := it.successFunc(); err != nil {
			slog.Error("error completing lazy read", "count", it.count, "error", err)
		}
	})

	return &it, nil
}

// Values returns an iterator over the elements alone, discarding the
// indices that All reports. It forwards each element from All to the
// consumer and stops All as soon as the consumer stops.
func (g *lazy[T]) Values() iter.Seq[T] {
	return func(yield func(T) bool) {
		pairs := g.All()
		pairs(func(_ int, elem T) bool {
			return yield(elem)
		})
	}
}

// All returns an iterator over (index, element) pairs for the sequence.
// Elements already cached in values are replayed without calling next.
// Past the cache, each element is obtained from next; newLazy's iterator
// appends every element it pulls to values. Iteration ends after count
// elements, when the consumer stops, or as soon as next reports that no
// element is available.
func (g *lazy[T]) All() iter.Seq2[int, T] {
	return func(yield func(int, T) bool) {
		total := int(g.count)
		for idx := 0; idx < total; idx++ {
			var elem T
			if idx < len(g.values) {
				elem = g.values[idx]
			} else {
				var ok bool
				if elem, ok = g.next(); !ok {
					return
				}
			}
			if !yield(idx, elem) {
				return
			}
		}
	}
}

// rest keeps calling next until it reports that no element is available,
// draining whatever the iterator has not produced yet. It reports whether
// at least one element was pulled during this call.
func (g *lazy[T]) rest() (collected bool) {
	for {
		if _, ok := g.next(); !ok {
			return collected
		}
		collected = true
	}
}