transforms.go 4.36 KB
Newer Older
1
2
3
4
5
6
7
package sample

import (
	"math"
	"slices"
)

8
9
10
11
12
// softmax converts the values in ts into a probability distribution in place
// and returns ts.
//
// Each value is replaced by exp(value-max)/sum, where max is the largest value
// in ts and sum is the total of the shifted exponentials. Subtracting the
// maximum does not change the mathematical result, but it keeps every
// exponential in (0, 1] so that large logits cannot overflow float32 to +Inf
// and turn the whole distribution into NaN.
func softmax(ts []logit) []logit {
	if len(ts) == 0 {
		return ts
	}

	// Largest logit, used as the shift for numerical stability.
	maxLogit := float32(math.Inf(-1))
	for i := range ts {
		if ts[i].value > maxLogit {
			maxLogit = ts[i].value
		}
	}

	var sum float32
	for i := range ts {
		ts[i].value = float32(math.Exp(float64(ts[i].value - maxLogit)))
		sum += ts[i].value
	}

	for i := range ts {
		ts[i].value /= sum
	}

	return ts
}

22
23
24
25
// temperature rescales the values in ti in place by the sampling temperature t
// and returns ti.
//
// A temperature of exactly 1 returns ti untouched. Otherwise every value is
// shifted by the largest value in ti and divided by t. t is clamped to a
// minimum of 1e-7 so that a zero or negative temperature cannot divide by zero.
func temperature(ti []logit, t float32) []logit {
	if t == 1 {
		return ti
	}

	temp := max(t, 1e-7)
	maxLogit := float32(math.Inf(-1))
	for _, token := range ti {
		if token.value > maxLogit {
			maxLogit = token.value
		}
	}

	// subtracting max logit to avoid under/overflow
	for i := range ti {
		ti[i].value = (ti[i].value - maxLogit) / temp
	}

	return ti
}

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// siftDown restores the min-heap property for the subtree rooted at start by
// iteratively moving a too-large element down the heap.
//
// The heap is represented as an array where for any node at index i:
// - Left child is at index 2i + 1
// - Right child is at index 2i + 2
// - Parent is at index (i-1)/2
//
// Only indices below end are treated as part of the heap; elements at end and
// beyond are never read or moved.
//
// The function compares a node with its children and:
// 1. Finds the smallest value between the node and its children
// 2. If the node is not the smallest, swaps it with its smallest child
// 3. Continues this process down the affected path until the min-heap property is restored
func siftDown(data []logit, start, end int) {
	root := start
	for {
		child := 2*root + 1
		if child >= end {
			break
		}
		// Find smaller child (we want min heap)
		if child+1 < end && data[child+1].value < data[child].value {
			child++
		}
		// Exit if root is already smaller than children
		if data[root].value <= data[child].value {
			break
		}
		// Swap with smaller child and continue
		data[root], data[child] = data[child], data[root]
		root = child
	}
}

75
76
77
78
79
80
81
82
83
84
// topK limits the number of tokens considered to the k highest logits.
//
// When k is not positive or k >= len(ts), ts is returned unchanged and in its
// original order. Otherwise the selection happens in place: the first k slots
// of ts are overwritten with the k largest entries, and the returned slice
// aliases ts[:k], ordered from highest to lowest value.
func topK(ts []logit, k int) []logit {
	// k <= 0 has nothing to select (and would otherwise index an empty heap
	// or slice with a negative bound); k >= len(ts) already keeps everything.
	if k <= 0 || k >= len(ts) {
		return ts
	}

	// Heapify + siftDown - O(nlog(k))
	// Build min-heap of first k elements
	heap := ts[:k]
	for i := k/2 - 1; i >= 0; i-- {
		siftDown(heap, i, k)
	}

	// Process remaining elements - if larger than heap root, replace root
	for i := k; i < len(ts); i++ {
		if ts[i].value > heap[0].value {
			heap[0] = ts[i]
			siftDown(heap, 0, k)
		}
	}

	// A heap is only partially ordered, so reversing its array does not give
	// a sorted slice. Sort the k survivors explicitly, highest value first.
	slices.SortFunc(heap, func(a, b logit) int {
		switch {
		case a.value > b.value:
			return -1
		case a.value < b.value:
			return 1
		default:
			return 0
		}
	})

	return heap
}

// topP limits tokens to the shortest leading prefix of ts whose cumulative
// value exceeds p.
//
// Values are accumulated in slice order, and the token that pushes the running
// sum past p is included in the result. If p is exactly 1, or the sum never
// exceeds p, ts is returned unchanged. The result aliases ts.
//
// NOTE(review): presumably ts holds probabilities ordered from most to least
// likely, so that the prefix is the most probable set; verify against callers.
func topP(ts []logit, p float32) []logit {
	if p == 1.0 {
		return ts
	}

	// Find cutoff index where cumulative sum exceeds p
	var sum float32
	for i, t := range ts {
		sum += t.value
		if sum > p {
			return ts[:i+1]
		}
	}

	return ts
}

120
121
122
123
124
// minP limits tokens to those whose value is at least p times the largest
// value in ts.
//
// The filter runs in place: surviving tokens are compacted to the front of ts
// in their original relative order, and the returned slice aliases ts. With
// p equal to 1 only the tokens tied for the maximum value survive.
func minP(ts []logit, p float32) []logit {
	maxProb := float32(math.Inf(-1))
	for _, token := range ts {
		if token.value > maxProb {
			maxProb = token.value
		}
	}

	threshold := maxProb * p

	// Filter tokens in-place
	validTokens := ts[:0]
	for _, token := range ts {
		if token.value >= threshold {
			validTokens = append(validTokens, token)
		}
	}

	return validTokens
}
146

147
148
149
150
151
152
// sortLogits sorts tokens in place by value, from highest to lowest.
//
// A plain comparison sort is used so that the result is fully ordered.
// Bucketing on a quantised score only orders tokens across buckets and leaves
// tokens that share a bucket in arbitrary order. Simplifying this function
// was tracked in https://github.com/ollama/ollama/issues/9584.
//
// The sort is not stable: tokens with equal values may end up in any order.
func sortLogits(tokens []logit) {
	slices.SortFunc(tokens, func(a, b logit) int {
		switch {
		case a.value > b.value:
			return -1
		case a.value < b.value:
			return 1
		default:
			return 0
		}
	})
}