transforms.go 2.69 KB
Newer Older
1
2
3
package sample

import (
4
	"container/heap"
5
6
7
8
	"math"
	"slices"
)

9
10
11
12
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
type tokenHeap []token

func (h tokenHeap) Len() int           { return len(h) }
ParthSareen's avatar
ParthSareen committed
13
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
14
15
16
17
18
19
20
21
22
23
24
25
26
27
func (h tokenHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *tokenHeap) Push(x any) {
	*h = append(*h, x.(token))
}

func (h *tokenHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[0 : n-1]
	return x
}

28
29
30
31
32
33
34
35
36
37
38
39
// temperature divides every logit in ts by temp and converts the result into
// probabilities with a softmax. The values in ts are overwritten in place and
// the same slice is returned.
//
// temp is clamped to at least 1e-7, so zero or negative temperatures do not
// divide by zero; a near-zero temperature drives the distribution toward a
// one-hot on the largest logit.
func temperature(ts []token, temp float32) []token {
	// Shift by the largest logit before exponentiating so exp cannot
	// overflow; the shift cancels out during normalization. The explicit
	// comparison (rather than the max builtin) skips NaN values.
	peak := float32(math.Inf(-1))
	for i := range ts {
		if ts[i].value > peak {
			peak = ts[i].value
		}
	}

	scale := max(temp, 1e-7)

	var total float32
	for i := range ts {
		e := float32(math.Exp(float64((ts[i].value - peak) / scale)))
		ts[i].value = e
		total += e
	}

	for i := range ts {
		ts[i].value /= total
	}
	return ts
}

54
// topK limits the number of tokens considered to the k highest logits
55
func topK(ts []token, k int) []token {
56
	if k >= len(ts) {
57
		sortLogits(ts)
58
59
		return ts
	}
60

61
62
63
64
65
66
	// Initialize min-heap with first k elements
	h := make(tokenHeap, k)
	copy(h, ts[:k])
	heap.Init(&h)

	// Process remaining elements
67
	for i := k; i < len(ts); i++ {
68
69
70
		if ts[i].value > h[0].value {
			heap.Pop(&h)
			heap.Push(&h, ts[i])
71
		}
72
73
	}

74
	// Convert heap to sorted slice in descending order
ParthSareen's avatar
ParthSareen committed
75
	result := make([]token, len(h))
76
77
78
	for i := k - 1; i >= 0; i-- {
		result[i] = heap.Pop(&h).(token)
	}
79

80
	return result
81
82
83
}

// topP limits tokens to those with cumulative probability p
84
func topP(ts []token, p float32) []token {
85
86
	if p == 1.0 {
		return ts
87
88
	}

89
90
91
92
93
94
95
	// Find cutoff index where cumulative sum exceeds p
	var sum float32
	for i, t := range ts {
		sum += t.value
		if sum > float32(p) {
			ts = ts[:i+1]
			return ts
96
97
98
		}
	}

99
	return ts
100
101
}

102
// minP limits tokens to those with cumulative probability p
103
func minP(ts []token, p float32) []token {
104
105
106
	if p == 1.0 {
		return ts
	}
107

108
109
110
111
112
	maxProb := float32(math.Inf(-1))
	for _, token := range ts {
		if token.value > maxProb {
			maxProb = token.value
		}
113
114
	}

115
	threshold := maxProb * float32(p)
116

117
118
119
120
121
	// Filter tokens in-place
	validTokens := ts[:0]
	for i, token := range ts {
		if token.value >= threshold {
			validTokens = append(validTokens, ts[i])
122
123
124
		}
	}

125
126
127
	ts = validTokens
	return ts
}
128

ParthSareen's avatar
ParthSareen committed
129
130
131
132
133
// sortLogits sorts the tokens in descending order of logits
func sortLogits(ts []token) {
	slices.SortFunc(ts, func(a, b token) int {
		switch {
		case a.value < b.value:
134
			return 1
ParthSareen's avatar
ParthSareen committed
135
136
137
138
		case a.value > b.value:
			return -1
		default:
			return 0
139
140
		}
	})
141
}