sample.go 1.2 KB
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
package sample

import (
	"math"
	"slices"

	"gonum.org/v1/gonum/floats"
	"gonum.org/v1/gonum/stat/sampleuv"
)

// Sampler is one stage of a sampling pipeline driven by Sample. Each stage
// receives the slice produced by the previous stage and returns the slice to
// hand to the next one, or a non-nil error to abort the whole pipeline.
type Sampler interface {
	Sample(values []float64) ([]float64, error)
}

// Temperature is a Sampler that divides every score by the temperature value.
// If the result is later passed through Softmax, values above 1 flatten the
// distribution and values below 1 sharpen it.
type Temperature float64

// temperatureError is returned by Temperature.Sample for a temperature that is
// zero, negative, or NaN. The offending temperature is the error's value.
type temperatureError float64

func (e temperatureError) Error() string {
	return "temperature must be positive"
}

// Sample divides each element of t by s in place and returns t.
//
// A zero temperature would turn every score into ±Inf or NaN, a negative one
// would invert the ranking, and NaN would poison every entry. Sample rejects
// these with a temperatureError instead of silently producing a broken
// distribution. On error, t is left unmodified and nil is returned.
func (s Temperature) Sample(t []float64) ([]float64, error) {
	if s <= 0 || math.IsNaN(float64(s)) {
		return nil, temperatureError(s)
	}

	floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t)))
	return t, nil
}

// softmax is the Sampler returned by Softmax.
type softmax struct{}

// Softmax returns a Sampler that turns a slice of scores into a probability
// distribution, replacing each entry x_i with exp(x_i) / Σ_j exp(x_j).
func Softmax() Sampler {
	return softmax{}
}

// Sample normalizes t in place and returns it; it never returns an error.
//
// The largest score is subtracted before exponentiating. The shift cancels out
// in the ratio, but it keeps math.Exp from overflowing to +Inf. It also makes
// the largest term exactly 1, so the sum is at least 1 and the division is
// safe. If any entry is NaN or +Inf, or every entry is -Inf, the result
// contains NaN. An empty slice is returned unchanged.
func (softmax) Sample(t []float64) ([]float64, error) {
	if len(t) == 0 {
		// slices.Max panics on an empty slice.
		return t, nil
	}

	shift := slices.Max(t)

	var sum float64
	for i, v := range t {
		t[i] = math.Exp(v - shift)
		sum += t[i]
	}

	for i := range t {
		t[i] /= sum
	}

	return t, nil
}

// TopK is a Sampler parameterized by a count k. It is presumably intended to
// keep only the k highest-scoring entries (TODO confirm).
//
// NOTE(review): not implemented yet. Sample currently ignores k and returns its
// input unchanged.
type TopK int

// Sample returns logits as-is and never fails; see the note on TopK.
func (TopK) Sample(logits []float64) ([]float64, error) {
	return logits, nil
}

// TopP is a Sampler parameterized by a threshold p. It is presumably intended
// for nucleus sampling, keeping the smallest set of entries whose cumulative
// probability reaches p (TODO confirm).
//
// NOTE(review): not implemented yet. Sample currently ignores p and returns its
// input unchanged.
type TopP float32

// Sample returns logits as-is and never fails; see the note on TopP.
func (TopP) Sample(logits []float64) ([]float64, error) {
	return logits, nil
}

// MinP is a Sampler parameterized by a threshold p. It is presumably intended
// to drop entries whose probability falls below p times the top probability
// (TODO confirm).
//
// NOTE(review): not implemented yet. Sample currently ignores p and returns its
// input unchanged.
type MinP float32

// Sample returns logits as-is and never fails; see the note on MinP.
func (MinP) Sample(logits []float64) ([]float64, error) {
	return logits, nil
}

// weighed is the Sampler returned by Weighed.
type weighed struct{}

// Weighed returns a Sampler that picks a single index of its input at random.
// Each index is chosen with probability proportional to that entry's weight,
// using gonum's sampleuv.Weighted.
func Weighed() Sampler {
	return weighed{}
}

// Sample treats each element of t as a weight and draws one index from t. It
// returns that index as a one-element slice, converted to float64. If
// sampleuv cannot take an index (for example when the weights sum to zero),
// t is returned unchanged. An empty t is also returned unchanged.
//
// NOTE(review): the weights are presumably non-negative (e.g. Softmax output).
// Negative weights are not checked here; confirm that callers guarantee this.
func (weighed) Sample(t []float64) ([]float64, error) {
	if len(t) == 0 {
		// Guard before calling gonum. Per its implementation, Weighted.Take
		// reads the root of its internal heap without a length check, so an
		// empty input would panic with an index out of range.
		return t, nil
	}

	// A nil source makes sampleuv use its default random source.
	w := sampleuv.NewWeighted(t, nil)
	if idx, ok := w.Take(); ok {
		return []float64{float64(idx)}, nil
	}

	return t, nil
}

// Sample threads values through samplers in order. Each stage receives the
// slice returned by the previous one, and some stages (e.g. Temperature)
// modify it in place. With no samplers, values is returned as-is. The first
// error aborts the pipeline: Sample then returns nil and that error unchanged.
func Sample(values []float64, samplers ...Sampler) ([]float64, error) {
	current := values
	for i := range samplers {
		next, err := samplers[i].Sample(current)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return current, nil
}