embedding.go 241 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
package nn

import "github.com/ollama/ollama/x/ml"

// Embedding is a layer whose only parameter is the Weight tensor; Forward
// reads from it with TakeAxes along axis 0 (presumably a table lookup —
// confirm against ml.Tensor.TakeAxes).
type Embedding struct {
	// Weight is the layer's table tensor. The gguf struct tag associates it
	// with the tensor key "weight".
	Weight ml.Tensor `gguf:"weight"`
}

// Forward returns the result of m.Weight.TakeAxes(ctx, hiddenState, 0).
// It gathers from the weight tensor along axis 0 using hiddenState as the
// selector.
//
// NOTE(review): hiddenState is presumably a tensor of token IDs / indices
// rather than an activation — confirm against callers and TakeAxes' contract.
func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
	// Axis of m.Weight that hiddenState selects along.
	const lookupAxis = 0

	gathered := m.Weight.TakeAxes(ctx, hiddenState, lookupAxis)
	return gathered
}