package nn

import (
	"fmt"

	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
)

// Attention implements scaled dot-product attention for transformer models:
// Attention(Q, K, V) = softmax(QK^T/√d_k)V
//
// Parameters:
//   - ctx: Context for tensor operations
//   - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
//   - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
//   - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
//   - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
//	Attention output with shape [d_v, heads, seq_len_q]
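//
// A minimal usage sketch from a caller's side (sa, headDim, numHeads,
// numKVHeads, seqLen, hiddenState, and cache are illustrative names, not
// part of this package):
//
//	q := sa.Query.Forward(ctx, hiddenState).Reshape(ctx, headDim, numHeads, seqLen)
//	k := sa.Key.Forward(ctx, hiddenState).Reshape(ctx, headDim, numKVHeads, seqLen)
//	v := sa.Value.Forward(ctx, hiddenState).Reshape(ctx, headDim, numKVHeads, seqLen)
//	out := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)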
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
}

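// AttentionWithSinks is like Attention but also forwards a sinks tensor
// (attention sinks, for models that use them) to the underlying attention
// implementation.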
func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
}

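// AttentionWithVMLA is the most general variant: sinks and vmla are both
// optional and may be nil. When present, vmla is an additional projection
// applied to the attention output (typically the value up-projection used
// with multi-head latent attention).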
func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
	ctx.Forward(query)
	if key != nil && value != nil {
		if query.Dim(0) != key.Dim(0) {
			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
		}

		if key.Dim(1) != value.Dim(1) {
			panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
		}

		if key.Dim(2) != value.Dim(2) {
			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
		}

		ctx.Forward(key, value)
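		// Write this step's key/value into the cache so later tokens can attend to them.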
		if cache != nil {
			cache.Put(ctx, key, value)
		}
	} else if cache == nil {
		panic("key & value tensors must be provided if cache is nil")
	}

	var mask ml.Tensor
	if cache != nil {
		key, value, mask = cache.Get(ctx)
	}

	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
		cacheConfigApplied := cache != nil
		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, sinks, vmla, scale, cacheConfigApplied)
	} else {
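		// Fallback: no fused kernel, so compute attention explicitly.
		// Rearrange the tensors so the matmuls below batch over heads.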
		query = query.Permute(ctx, 0, 2, 1, 3)
		key = key.Permute(ctx, 0, 2, 1, 3)
		value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

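		// Attention scores: QK^T at full precision, then scale, optional mask, and softmax.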
		kq := key.MulmatFullPrec(ctx, query)

		kq = kq.Scale(ctx, scale)
		if mask != nil {
			kq = kq.Add(ctx, mask)
		}
		kq = kq.Softmax(ctx)

		kqv := value.Mulmat(ctx, kq)

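		// Apply the optional vmla projection to the attention output.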
		if vmla != nil {
			kqv = vmla.Mulmat(ctx, kqv)
		}

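		// Restore the [d_v, heads, seq_len_q] layout expected by callers.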
		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
	}
}