pooling_test.go 1.71 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package pooling_test

import (
	"bytes"
	"os"
	"slices"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/ollama/ollama/discover"
	fsggml "github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/backend/ggml"
	"github.com/ollama/ollama/ml/nn/pooling"
)

// setup writes a minimal GGUF model to a temporary file and loads it with the
// ggml backend, returning the backend for the caller to Close.
//
// The model declares architecture "test" with a block count of 1 and a single
// one-element tensor "blk.0.weight" backed by 4 zero bytes. When
// discover.GetGPUInfo reports at least one GPU, layers 0..n-1 are assigned to
// the first reported GPU; otherwise no GPU layers are requested.
func setup(tb testing.TB, n int) ml.Backend {
	tb.Helper()

	f, err := os.CreateTemp(tb.TempDir(), "*.bin")
	if err != nil {
		tb.Fatal(err)
	}

	if err := fsggml.WriteGGUF(f, fsggml.KV{
		"general.architecture": "test",
		"test.block_count":     uint32(1),
	}, []*fsggml.Tensor{
		{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
	}); err != nil {
		// Best-effort close; the write error is the one worth reporting.
		_ = f.Close()
		tb.Fatal(err)
	}

	// Close (and check) before loading: an error on the final Close of a
	// written file would otherwise be silently dropped, and the backend
	// should open a file that is fully written and no longer held open here.
	if err := f.Close(); err != nil {
		tb.Fatal(err)
	}

	var gpuLayers ml.GPULayersList
	if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
		gpuLayers = append(gpuLayers, ml.GPULayers{
			ID: gpus[0].ID,
			// Layers 0..n-1, in order.
			Layers: slices.Collect(func(yield func(int) bool) {
				for i := range n {
					if !yield(i) {
						return
					}
				}
			}),
		})
	}

	b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
	if err != nil {
		tb.Fatal(err)
	}

	return b
}

// TestForward runs each pooling type over a tensor built from Arange(0, 16)
// reshaped to (8, 2) and compares the computed floats with the expected
// output. The expectations encode two rows of eight values, 0..7 and 8..15:
// CLS selects the first, Last selects the second, and Mean averages them.
func TestForward(t *testing.T) {
	expectations := map[pooling.Type][]float32{
		pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
		pooling.TypeCLS:  {0, 1, 2, 3, 4, 5, 6, 7},
		pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
	}

	for poolType, expected := range expectations {
		t.Run(poolType.String(), func(t *testing.T) {
			backend := setup(t, 99)
			defer backend.Close()

			ctx := backend.NewContext()
			defer ctx.Close()

			input := ctx.Input().Arange(0, 16, 1, ml.DTypeF32)
			input = input.Reshape(ctx, 8, 2)
			pooled := poolType.Forward(ctx, input)

			ctx.Forward(pooled).Compute(pooled)

			got := pooled.Floats()
			if diff := cmp.Diff(expected, got); diff != "" {
				t.Error(diff)
			}
		})
	}
}