logprob.go 1.87 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package common

import (
	"math"
	"sort"

	"github.com/ollama/ollama/llm"
)

// TokenDecoderFunc turns a single vocabulary token ID into the text it
// represents. CalculateLogprobs calls it once for the selected token and once
// for every entry it reports in the top-K list.
type TokenDecoderFunc func(id int) string

// CalculateLogprobs converts raw logits to log probabilities and reports the
// log probability of selectedToken together with, optionally, the topK most
// likely tokens.
//
// Log probabilities use the numerically stable log-softmax
//
//	logprob[i] = (logit[i] - max) - log(sum_j exp(logit[j] - max))
//
// with the exponential sum accumulated in float64.
//
// Parameters:
//   - logits: raw, unnormalized scores indexed by token ID.
//   - selectedToken: the token ID whose log probability is reported; it must
//     be a valid index into logits.
//   - topK: how many of the most likely tokens to list in TopLogprobs; values
//     <= 0 disable the list and values above len(logits) are clamped.
//   - decoder: converts token IDs to text.
//
// It returns a single-element slice describing the selected token, or nil
// when logits is empty or selectedToken is outside [0, len(logits)).
// TopLogprobs is ordered from most to least likely; tokens with equal log
// probabilities are ordered by ascending token ID so the output is
// deterministic.
func CalculateLogprobs(logits []float32, selectedToken int, topK int, decoder TokenDecoderFunc) []llm.Logprob {
	if len(logits) == 0 {
		return nil
	}
	// Guard the index below: an invalid token ID would otherwise panic.
	if selectedToken < 0 || selectedToken >= len(logits) {
		return nil
	}

	// Step 1: convert logits to log probabilities using a numerically stable
	// softmax (shift by the max logit so exp never overflows).
	maxLogit := logits[0]
	for _, logit := range logits[1:] {
		if logit > maxLogit {
			maxLogit = logit
		}
	}

	var sumExp float64
	for _, logit := range logits {
		sumExp += math.Exp(float64(logit - maxLogit))
	}
	logSumExp := float32(math.Log(sumExp))

	logProbs := make([]float32, len(logits))
	for i, logit := range logits {
		logProbs[i] = (logit - maxLogit) - logSumExp
	}

	// Step 2: describe the selected token.
	result := llm.Logprob{
		TokenLogprob: llm.TokenLogprob{
			Token:   decoder(selectedToken),
			Logprob: float64(logProbs[selectedToken]),
		},
	}

	// Step 3: if requested, list the topK most likely tokens.
	if topK > 0 {
		ids := topKTokenIDs(logProbs, min(topK, len(logProbs)))
		topLogprobs := make([]llm.TokenLogprob, len(ids))
		for i, id := range ids {
			topLogprobs[i] = llm.TokenLogprob{
				Token:   decoder(id),
				Logprob: float64(logProbs[id]),
			}
		}
		result.TopLogprobs = topLogprobs
	}

	return []llm.Logprob{result}
}

// topKTokenIDs returns the indices of the k largest entries of logProbs,
// ordered best first with ties broken by ascending index. It keeps a bounded
// min-heap of size k, so it needs O(n log k) time and O(k) extra space rather
// than sorting all n entries. The caller must ensure 1 <= k <= len(logProbs).
func topKTokenIDs(logProbs []float32, k int) []int {
	// ranksBelow reports whether token a is ordered after (is worse than) token b.
	ranksBelow := func(a, b int) bool {
		if logProbs[a] != logProbs[b] {
			return logProbs[a] < logProbs[b]
		}
		return a > b
	}

	// kept is a min-heap under ranksBelow: kept[0] is the weakest token retained.
	kept := make([]int, 0, k)

	siftDown := func(i int) {
		for {
			weakest := i
			if l := 2*i + 1; l < len(kept) && ranksBelow(kept[l], kept[weakest]) {
				weakest = l
			}
			if r := 2*i + 2; r < len(kept) && ranksBelow(kept[r], kept[weakest]) {
				weakest = r
			}
			if weakest == i {
				return
			}
			kept[i], kept[weakest] = kept[weakest], kept[i]
			i = weakest
		}
	}

	for id := range logProbs {
		switch {
		case len(kept) < k:
			kept = append(kept, id)
			// Sift the new entry up to restore the heap property.
			for c := len(kept) - 1; c > 0; {
				p := (c - 1) / 2
				if !ranksBelow(kept[c], kept[p]) {
					break
				}
				kept[c], kept[p] = kept[p], kept[c]
				c = p
			}
		case ranksBelow(kept[0], id):
			// id beats the weakest retained token: replace it.
			kept[0] = id
			siftDown(0)
		}
	}

	// Order the k survivors best first.
	sort.Slice(kept, func(i, j int) bool { return ranksBelow(kept[j], kept[i]) })
	return kept
}