pooling.go 713 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package pooling

import (
	"fmt"

	"github.com/ollama/ollama/ml"
)

// Type selects how Pooling reduces per-token hidden states.
//
// NOTE(review): the numbering presumably mirrors an external pooling_type
// encoding (e.g. llama.cpp / GGUF metadata), so existing values should not be
// reordered or renumbered — confirm before changing.
type Type uint32

const (
	// TypeNone (0): Pooling returns the hidden states unchanged.
	TypeNone Type = iota
	// TypeMean (1): Pooling averages the hidden states via Mean.
	TypeMean
	// TypeCLS (2): Pooling returns the leading Dim(0) elements, presumably
	// the first (CLS) token's vector.
	TypeCLS
	// TypeLast (3): declared, but Pooling does not implement it yet.
	TypeLast
	// TypeRank (4): declared, but Pooling does not implement it yet.
	TypeRank

	// Sentinel values outside the iota range above; Pooling panics on both.
	// TypeUnspecified presumably means "no pooling type configured" — confirm.
	//
	// NOTE(review): unlike the constants above, these have no explicit type,
	// so they are untyped integer constants rather than Type. They convert
	// implicitly where a Type (or uint32) is expected, but `x := TypeUnknown`
	// infers int, which overflows on 32-bit platforms.
	TypeUnknown     = 0xFFFFFFFE
	TypeUnspecified = 0xFFFFFFFF
)

// Pooling reduces hiddenStates according to poolingType and returns the
// resulting tensor.
//
// NOTE(review): the cases below assume hiddenStates has the hidden size in
// dim 0 and the token dimension in dim 1 — confirm against callers.
//
//   - TypeNone returns hiddenStates unchanged.
//   - TypeMean swaps dims 0 and 1, applies Mean, and swaps them back,
//     presumably averaging over tokens.
//   - TypeCLS returns a view of the first Dim(0) elements, i.e. the first
//     token's vector under the assumed layout.
//
// Pooling panics for TypeLast and TypeRank, which are not implemented yet,
// and for any other value (including TypeUnknown and TypeUnspecified). The
// panic message includes the numeric type so the offending configuration
// can be identified.
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
	switch poolingType {
	case TypeNone:
		return hiddenStates
	case TypeMean:
		// Contiguous after each Permute presumably materialises the permuted
		// view so Mean (and later consumers) see a dense layout — confirm.
		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
	case TypeCLS:
		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
	case TypeLast, TypeRank:
		// Known pooling types that have no implementation yet.
		panic(fmt.Sprintf("pooling: type %d is not implemented", poolingType))
	default:
		// Sentinels (TypeUnknown, TypeUnspecified) and out-of-range values.
		panic(fmt.Sprintf("pooling: unknown pooling type %d", poolingType))
	}
}