ggml_test.go 2.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
package ggml

import (
	"errors"
	"os"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
)

// setup writes a minimal GGUF file (whose only KV is
// general.architecture="test") into a per-test temporary directory, loads it
// with ml.NewBackend, and returns the context obtained from
// b.NewContext().Input(). The context and backend are closed via tb.Cleanup
// (context first, then backend).
func setup(tb testing.TB) ml.Context {
	tb.Helper()

	f, err := os.CreateTemp(tb.TempDir(), "*.bin")
	if err != nil {
		tb.Fatal(err)
	}

	if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
		// Release the handle so TempDir cleanup can remove the file.
		f.Close()
		tb.Fatal(err)
	}

	// Close (and check) before the backend reopens the file by name, so a
	// failed final write/close is reported instead of silently ignored.
	if err := f.Close(); err != nil {
		tb.Fatal(err)
	}

	b, err := ml.NewBackend(f.Name(), ml.BackendParams{})
	if err != nil {
		tb.Fatal(err)
	}
	// Cleanups run LIFO: the context registered below is closed before b.
	tb.Cleanup(func() { b.Close() })

	ctx := b.NewContext().Input()
	tb.Cleanup(func() { ctx.Close() })

	return ctx
}

// TestInferShape checks that inferShape fills in a single -1 dimension in
// place, based on a 2x3x4 F32 tensor, and that it panics with the expected
// message for invalid shapes (multiple -1s, indivisible sizes, zero dims).
func TestInferShape(t *testing.T) {
	cases := []struct {
		name  string
		input []int
		want  []int
		err   error
	}{
		{
			name:  "no inferred shape",
			input: []int{2, 3, 4},
			want:  []int{2, 3, 4},
		},
		{
			name:  "infer begin",
			input: []int{-1, 3, 4},
			want:  []int{2, 3, 4},
		},
		{
			name:  "infer mid",
			input: []int{2, -1, 4},
			want:  []int{2, 3, 4},
		},
		{
			name:  "infer end",
			input: []int{2, 3, -1},
			want:  []int{2, 3, 4},
		},
		{
			name:  "too many inferred dims",
			input: []int{-1, 3, -1},
			err:   errors.New("only one dimension can be inferred"),
		},
		{
			name:  "infer gather",
			input: []int{2, -1},
			want:  []int{2, 12},
		},
		{
			name:  "infer gather all",
			input: []int{-1},
			want:  []int{24},
		},
		{
			name:  "infer split",
			input: []int{2, -1, 3, 2},
			want:  []int{2, 2, 3, 2},
		},
		{
			name:  "indivisible infer",
			input: []int{2, -1, 2, 4},
			err:   errors.New("cannot infer dimension"),
		},
		{
			name:  "infer zero dim",
			input: []int{2, 0, 4},
			err:   errors.New("dimension cannot be zero"),
		},
	}

	ctx := setup(t)
	tensor, ok := ctx.Empty(ml.DTypeF32, 2, 3, 4).(*Tensor)
	if !ok {
		t.Fatal("expected *Tensor")
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			// Run inferShape and capture any panic value so it can be
			// compared against tt.err rather than aborting the subtest.
			r := func() (p any) {
				defer func() { p = recover() }()
				inferShape(tensor, tt.input)
				return nil
			}()

			switch {
			case r == nil && tt.err == nil:
				// inferShape rewrites tt.input in place; compare the result.
				if diff := cmp.Diff(tt.want, tt.input); diff != "" {
					t.Errorf("%s: shape mismatch (-want +got):\n%s", tt.name, diff)
				}
			case r != nil && tt.err == nil:
				t.Errorf("unexpected panic: %v", r)
			case r == nil && tt.err != nil:
				t.Errorf("expected panic but did not get one: %v", tt.err)
			default:
				// Accept string and error panic values; any other type, or a
				// different message, is a failure instead of being ignored.
				var msg string
				switch v := r.(type) {
				case string:
					msg = v
				case error:
					msg = v.Error()
				default:
					t.Errorf("expected panic %q but got %v (%T)", tt.err.Error(), r, r)
					return
				}
				if msg != tt.err.Error() {
					t.Errorf("expected panic %q but got %q", tt.err.Error(), msg)
				}
			}
		})
	}
}