"research/deeplab/deprecated/segmentation_dataset.py" did not exist on "bdcfdd30306a7df694fb281cd24884769009d03e"
transforms.go 4.43 KB
Newer Older
1
2
3
package sample

import (
4
	"container/heap"
5
6
7
8
	"math"
	"slices"
)

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
// tokenHeap is a min-heap of tokens ordered by value. It satisfies
// heap.Interface; keeping the smallest element at the root lets a caller
// retain the k largest tokens seen so far by evicting the root.
type tokenHeap []token

// Len reports the number of tokens currently held.
func (h tokenHeap) Len() int { return len(h) }

// Less orders by ascending value, which makes the heap a min-heap.
func (h tokenHeap) Less(i, j int) bool { return h[i].value < h[j].value }

// Swap exchanges the tokens at positions i and j.
func (h tokenHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a token, to the end of the underlying slice.
func (h *tokenHeap) Push(x any) {
	item := x.(token)
	*h = append(*h, item)
}

// Pop removes the last element of the underlying slice and returns it.
func (h *tokenHeap) Pop() any {
	last := len(*h) - 1
	item := (*h)[last]
	*h = (*h)[:last]
	return item
}

28
29
30
31
32
33
34
35
36
37
38
39
// temperature divides the logits in ts by temp and converts them to
// probabilities with a softmax. ts is modified in place and also returned.
//
// temp is clamped to a minimum of 1e-7 so that a zero or negative temperature
// does not cause a division by zero.
func temperature(ts []token, temp float32) []token {
	// Subtracting the largest logit before exponentiating keeps every
	// exponent <= 0, so math.Exp cannot overflow.
	maxLogit := float32(math.Inf(-1))
	for i := range ts {
		if ts[i].value > maxLogit {
			maxLogit = ts[i].value
		}
	}

	temp = max(temp, 1e-7)

	// Compute exp((x - max) / temp) and accumulate the normalizer.
	var sum float32
	for i := range ts {
		ts[i].value = float32(math.Exp(float64((ts[i].value - maxLogit) / temp)))
		sum += ts[i].value
	}

	// Normalize so the values sum to 1.
	for i := range ts {
		ts[i].value /= sum
	}

	return ts
}

54
// topK limits the number of tokens considered to the k highest logits
55
func topK(ts []token, k int) []token {
56
	if k >= len(ts) {
57
		sortLogits(ts)
58
59
		return ts
	}
60

61
62
63
64
65
66
	// Initialize min-heap with first k elements
	h := make(tokenHeap, k)
	copy(h, ts[:k])
	heap.Init(&h)

	// Process remaining elements
67
	for i := k; i < len(ts); i++ {
68
69
70
		if ts[i].value > h[0].value {
			heap.Pop(&h)
			heap.Push(&h, ts[i])
71
		}
72
73
	}

74
75
76
77
78
	// Convert heap to sorted slice in descending order
	result := make([]token, k)
	for i := k - 1; i >= 0; i-- {
		result[i] = heap.Pop(&h).(token)
	}
79

80
	return result
81
82
83
}

// topP limits tokens to those with cumulative probability p
84
func topP(ts []token, p float32) []token {
85
86
	if p == 1.0 {
		return ts
87
88
	}

89
90
91
92
93
94
95
	// Find cutoff index where cumulative sum exceeds p
	var sum float32
	for i, t := range ts {
		sum += t.value
		if sum > float32(p) {
			ts = ts[:i+1]
			return ts
96
97
98
		}
	}

99
	return ts
100
101
}

102
// minP limits tokens to those with cumulative probability p
103
func minP(ts []token, p float32) []token {
104
105
106
	if p == 1.0 {
		return ts
	}
107

108
109
110
111
112
	maxProb := float32(math.Inf(-1))
	for _, token := range ts {
		if token.value > maxProb {
			maxProb = token.value
		}
113
114
	}

115
	threshold := maxProb * float32(p)
116

117
118
119
120
121
	// Filter tokens in-place
	validTokens := ts[:0]
	for i, token := range ts {
		if token.value >= threshold {
			validTokens = append(validTokens, ts[i])
122
123
124
		}
	}

125
126
127
	ts = validTokens
	return ts
}
128

129
130
131
132
// partialSortLogits moves the n highest-valued tokens to the front of ts using
// quickselect, sorts that prefix in descending order, and returns it.
//
// ts is reordered in place and the returned slice aliases it. n is capped at
// len(ts); a non-positive n yields an empty slice without touching ts.
func partialSortLogits(ts []token, n int) []token {
	if n <= 0 {
		// Nothing to select; also avoids slicing with a negative bound.
		return ts[:0]
	}
	if n > len(ts) {
		n = len(ts)
	}

	lo, hi := 0, len(ts)-1
	target := n - 1

	// Quickselect: narrow [lo, hi] until the element that belongs at index
	// target is in place, leaving every larger element to its left.
	for lo < hi {
		// Use the middle element as the pivot and park it at the end.
		mid := lo + (hi-lo)/2
		ts[mid], ts[hi] = ts[hi], ts[mid]
		pivotValue := ts[hi].value

		// Elements >= pivot are gathered on the left; store marks where the
		// next such element goes.
		store := lo
		for i := lo; i < hi; i++ {
			if ts[i].value >= pivotValue {
				ts[store], ts[i] = ts[i], ts[store]
				store++
			}
		}

		// Move the pivot to its final position.
		ts[hi], ts[store] = ts[store], ts[hi]

		if store == target {
			break
		}
		if store < target {
			lo = store + 1 // target lies in the right part
		} else {
			hi = store - 1 // target lies in the left part
		}
	}

	// Only the selected prefix needs a full sort, in descending order.
	slices.SortFunc(ts[:n], func(a, b token) int {
		if a.value > b.value {
			return -1
		}
		if a.value < b.value {
			return 1
		}
		return 0
	})

	return ts[:n]
}
184

185
186
187
188
189
190
191
// sortLogits partially sorts ts in place via partialSortLogits so that its
// highest-valued tokens come first in descending order.
//
// The number of tokens requested is sqrt(len(ts)) + 1, clamped to the range
// [100, 1000]. This is a heuristic that avoids a full sort while still
// ordering enough tokens for sampling; elements beyond that prefix are not
// guaranteed to be sorted.
func sortLogits(ts []token) {
	n := int(math.Sqrt(float64(len(ts)))) + 1

	// Request at least 100 tokens and at most 1000.
	n = min(max(n, 100), 1000)

	partialSortLogits(ts, n)
}