rope.go 751 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// Package fast provides implementations of fast (fused) operations for increased performance.
package fast

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/rope"
)

// fastRoPE is satisfied by tensor types that carry their own (fused) rotary
// positional embedding implementation. The package-level RoPE function
// type-asserts its input tensor against this interface and delegates to it.
type fastRoPE interface {
	RoPE(
		ctx ml.Context,
		positionIDs ml.Tensor,
		dim int,
		base, scale float32,
		options ...func(*rope.Options),
	) ml.Tensor
}

// RoPE applies rotary positional embedding to tensor `t`.
//
// The work is delegated to t's own RoPE method. To qualify, t must implement
// fastRoPE. The arguments positions, dim, base, scale and options are
// forwarded unchanged.
//
// RoPE panics if t does not implement fastRoPE.
func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor {
	impl, supported := t.(fastRoPE)
	if !supported {
		panic("RoPE not implemented for this tensor type")
	}

	return impl.RoPE(ctx, positions, dim, base, scale, options...)
}