rope.go 1.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
package rope

import "github.com/ollama/ollama/x/ml"

// Options collects the optional parameters accepted by the RoPE function.
// Callers normally build it by applying the With* option functions in this
// package rather than filling fields directly.
//
// NOTE(review): Type is treated as a set of mode flags by the setters
// (WithTypeNeoX assigns 2; WithMRoPE ORs in 1<<3; WithInterleaveMRoPE ORs in
// 1<<3|1<<5). Presumably these mirror the backend's rope mode constants —
// confirm before adding new values.
type Options struct {
	// Type holds the rope mode flags described above.
	Type int

	// Factors is an optional custom rope factors tensor, set by WithFactors.
	Factors ml.Tensor

	// YaRN groups the YaRN parameters. OriginalContextLength,
	// ExtrapolationFactor and AttentionFactor have dedicated With* setters.
	YaRN struct {
		OriginalContextLength int
		ExtrapolationFactor   float32
		AttentionFactor       float32
		BetaFast              float32
		BetaSlow              float32
	}

	// MRoPE groups the multi-section RoPE parameters.
	MRoPE struct {
		// Sections is stored exactly as passed to WithMRoPE or
		// WithInterleaveMRoPE (the slice is not copied).
		Sections []int
	}
}

// WithTypeNeoX returns an option that selects the NeoX RoPE variant by
// setting Options.Type to 2.
//
// NOTE(review): unlike WithMRoPE and WithInterleaveMRoPE, this assigns Type
// outright instead of OR-ing a bit in, so applying it after either of those
// options discards their flags (MRoPE.Sections is left unchanged). Confirm
// whether that ordering dependence is intended.
func WithTypeNeoX() func(*Options) {
	const neoX = 2
	return func(o *Options) {
		o.Type = neoX
	}
}

// WithFactors returns an option that installs a custom rope factors tensor
// into Options.Factors.
//
// A nil factors argument makes the option a no-op, so any Factors value
// configured by an earlier option is kept.
// NOTE(review): the check compares the interface value with nil; a non-nil
// interface wrapping a nil pointer would still be stored.
func WithFactors(factors ml.Tensor) func(*Options) {
	return func(o *Options) {
		if factors == nil {
			return
		}
		o.Factors = factors
	}
}

// WithOriginalContextLength returns an option that sets the YaRN original
// context length (Options.YaRN.OriginalContextLength) to n.
func WithOriginalContextLength(n int) func(*Options) {
	apply := func(o *Options) {
		o.YaRN.OriginalContextLength = n
	}
	return apply
}

// WithExtrapolationFactor returns an option that sets the YaRN extrapolation
// factor (Options.YaRN.ExtrapolationFactor).
func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
	return func(o *Options) {
		yarn := &o.YaRN
		yarn.ExtrapolationFactor = extrapolationFactor
	}
}

// WithAttentionFactor returns an option that sets the YaRN attention factor
// (Options.YaRN.AttentionFactor).
func WithAttentionFactor(attentionFactor float32) func(*Options) {
	return func(o *Options) {
		yarn := &o.YaRN
		yarn.AttentionFactor = attentionFactor
	}
}

// WithMRoPE returns an option that enables multi-section RoPE.
//
// It ORs bit 3 (value 8) into Options.Type, preserving any bits already set,
// and records sections in Options.MRoPE.Sections. The slice is stored
// without copying, so the option shares its backing array with the caller.
func WithMRoPE(sections []int) func(*Options) {
	const mropeFlag = 1 << 3
	return func(o *Options) {
		o.Type |= mropeFlag
		o.MRoPE.Sections = sections
	}
}

// WithInterleaveMRoPE returns an option that enables interleaved
// multi-section RoPE.
//
// It ORs bits 3 and 5 (8 | 32 = 40) into Options.Type, preserving any bits
// already set, and records sections in Options.MRoPE.Sections without
// copying the slice.
// NOTE(review): 40 presumably matches the backend's interleaved-MRoPE mode
// value — confirm.
func WithInterleaveMRoPE(sections []int) func(*Options) {
	const (
		mropeFlag      = 1 << 3
		interleaveFlag = 1 << 5
	)
	return func(o *Options) {
		o.Type |= mropeFlag | interleaveFlag
		o.MRoPE.Sections = sections
	}
}