transforms.go 2.64 KB
Newer Older
1
2
3
package sample

import (
4
	"container/heap"
5
6
7
8
	"math"
	"slices"
)

9
10
11
12
// tokenHeap implements heap.Interface and holds tokens as a min-heap to track k largest elements
type tokenHeap []token

func (h tokenHeap) Len() int           { return len(h) }
ParthSareen's avatar
ParthSareen committed
13
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }
14
15
16
17
18
19
20
21
22
23
24
25
26
27
func (h tokenHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

func (h *tokenHeap) Push(x any) {
	*h = append(*h, x.(token))
}

func (h *tokenHeap) Pop() any {
	old := *h
	n := len(old)
	x := old[n-1]
	*h = old[0 : n-1]
	return x
}

28
// temperature applies scaling to the logits
29
func temperature(ts []token, temp float32) {
30
31
32
33
34
35
36
37
	// Ensure temperature clipping near 0 to avoid numerical instability
	temp = max(temp, 1e-7)
	for i := range ts {
		ts[i].value = ts[i].value / temp
	}
}

// softmax applies normalization to the logits
38
func softmax(ts []token) {
39
40
41
42
43
44
45
46
	// Find max logit for numerical stability
	maxLogit := float32(math.Inf(-1))
	for _, t := range ts {
		if t.value > maxLogit {
			maxLogit = t.value
		}
	}

47
	// Compute exp(x - max)
48
49
	var sum float32
	for i, v := range ts {
50
		ts[i].value = float32(math.Exp(float64(v.value - maxLogit)))
51
		sum += ts[i].value
52
53
	}

54
	// exp(x - max) / sum(exp(x - max))
55
56
	for i := range ts {
		ts[i].value /= sum
57
58
59
	}
}

60
// topK limits the number of tokens considered to the k highest logits
61
func topK(ts []token, k int) []token {
ParthSareen's avatar
ParthSareen committed
62
63
64
65
66
67
68
69
70
71
72
	if k >= len(ts) || k <= 0 {
		slices.SortFunc(ts, func(a, b token) int {
			switch {
			case a.value < b.value:
				return 1
			case a.value > b.value:
				return -1
			default:
				return 0
			}
		})
73
74
		return ts
	}
75

76
77
78
79
80
81
	// Initialize min-heap with first k elements
	h := make(tokenHeap, k)
	copy(h, ts[:k])
	heap.Init(&h)

	// Process remaining elements
82
	for i := k; i < len(ts); i++ {
83
84
85
		if ts[i].value > h[0].value {
			heap.Pop(&h)
			heap.Push(&h, ts[i])
86
		}
87
88
	}

89
	// Convert heap to sorted slice in descending order
ParthSareen's avatar
ParthSareen committed
90
	result := make([]token, len(h))
91
92
93
	for i := k - 1; i >= 0; i-- {
		result[i] = heap.Pop(&h).(token)
	}
94

95
	return result
96
97
98
}

// topP limits tokens to those with cumulative probability p
99
// requires ts to be sorted in descending order of probabilities
100
func topP(ts []token, p float32) []token {
101
102
	if p == 1.0 {
		return ts
103
104
	}

105
106
107
108
109
	// Find cutoff index where cumulative sum exceeds p
	var sum float32
	for i, t := range ts {
		sum += t.value
		if sum > float32(p) {
110
			return ts[:i+1]
111
112
113
		}
	}

114
	return ts
115
116
}

117
118
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
119
func minP(ts []token, p float32) []token {
120
	maxProb := ts[0].value
121

122
	threshold := maxProb * p
123

124
125
126
	for i, t := range ts {
		if t.value < threshold {
			return ts[:i]
127
128
		}
	}
129
130
	return ts
}