transforms.go 2.8 KB
Newer Older
1
2
3
package sample

import (
4
	"container/heap"
5
6
7
8
	"math"
	"slices"
)

9
10
11
12
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
type tokenHeap []token

func (h tokenHeap) Len() int           { return len(h) }
ParthSareen's avatar
ParthSareen committed
13
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
14
15
16
17
18
19
20
21
22
23
24
25
26
27
func (h tokenHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *tokenHeap) Push(x any) {
	*h = append(*h, x.(token))
}

func (h *tokenHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[0 : n-1]
	return x
}

28
// temperature applies scaling to the logits
29
func temperature(ts []token, temp float32) []token {
30
31
32
33
34
35
36
37
38
39
	// Ensure temperature clipping near 0 to avoid numerical instability
	temp = max(temp, 1e-7)
	for i := range ts {
		ts[i].value = ts[i].value / temp
	}
	return ts
}

// softmax converts the values in ts from logits into probabilities in place,
// computing exp(v - max) / sum(exp(v - max)) for every value v, and returns ts.
// Subtracting the largest logit first leaves the result mathematically
// unchanged but keeps every exponent <= 0, so math.Exp cannot overflow.
func softmax(ts []token) []token {
	peak := float32(math.Inf(-1))
	for i := range ts {
		if ts[i].value > peak {
			peak = ts[i].value
		}
	}

	var total float32
	for i := range ts {
		e := float32(math.Exp(float64(ts[i].value - peak)))
		ts[i].value = e
		total += e
	}

	for i := range ts {
		ts[i].value /= total
	}
	return ts
}

63
// topK limits the number of tokens considered to the k highest logits
64
func topK(ts []token, k int) []token {
ParthSareen's avatar
ParthSareen committed
65
66
67
68
69
70
71
72
73
74
75
	if k >= len(ts) || k <= 0 {
		slices.SortFunc(ts, func(a, b token) int {
			switch {
			case a.value < b.value:
				return 1
			case a.value > b.value:
				return -1
			default:
				return 0
			}
		})
76
77
		return ts
	}
78

79
80
81
82
83
84
	// Initialize min-heap with first k elements
	h := make(tokenHeap, k)
	copy(h, ts[:k])
	heap.Init(&h)

	// Process remaining elements
85
	for i := k; i < len(ts); i++ {
86
87
88
		if ts[i].value > h[0].value {
			heap.Pop(&h)
			heap.Push(&h, ts[i])
89
		}
90
91
	}

92
	// Convert heap to sorted slice in descending order
ParthSareen's avatar
ParthSareen committed
93
	result := make([]token, len(h))
94
95
96
	for i := k - 1; i >= 0; i-- {
		result[i] = heap.Pop(&h).(token)
	}
97

98
	return result
99
100
101
}

// topP limits tokens to those with cumulative probability p
102
func topP(ts []token, p float32) []token {
103
104
	if p == 1.0 {
		return ts
105
106
	}

107
108
109
110
111
112
113
	// Find cutoff index where cumulative sum exceeds p
	var sum float32
	for i, t := range ts {
		sum += t.value
		if sum > float32(p) {
			ts = ts[:i+1]
			return ts
114
115
116
		}
	}

117
	return ts
118
119
}

120
// minP limits tokens to those with cumulative probability p
121
func minP(ts []token, p float32) []token {
122
123
124
	if p == 1.0 {
		return ts
	}
125

126
127
128
129
130
	maxProb := float32(math.Inf(-1))
	for _, token := range ts {
		if token.value > maxProb {
			maxProb = token.value
		}
131
132
	}

133
	threshold := maxProb * float32(p)
134

135
136
137
138
139
	// Filter tokens in-place
	validTokens := ts[:0]
	for i, token := range ts {
		if token.value >= threshold {
			validTokens = append(validTokens, ts[i])
140
141
142
		}
	}

143
144
145
	ts = validTokens
	return ts
}