normalization.go 460 Bytes
Newer Older
Michael Yang's avatar
Michael Yang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
package nn

import (
	"github.com/ollama/ollama/ml"
)

// LayerNorm holds the parameter tensors for a layer-normalization step.
// Forward passes Weight and Bias, in that order, to ml.Tensor.LayerNorm;
// conventionally they are the per-feature scale and shift (confirm against
// the ml backend's LayerNorm contract).
//
// The gguf struct tags presumably name the tensor each field is populated
// from when the model is loaded — NOTE(review): verify against the loader.
type LayerNorm struct {
	Weight ml.Tensor `gguf:"weight"`
	Bias   ml.Tensor `gguf:"bias"`
}

// Forward returns the result of normalizing t with this layer's parameters.
// It is a thin wrapper: ctx, m.Weight, m.Bias and eps are handed unchanged to
// t.LayerNorm, and whatever tensor that returns is returned as-is. eps is
// presumably the small stabilizing constant of layer norm — confirm with the
// ml backend.
func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
	weight, bias := m.Weight, m.Bias
	normalized := t.LayerNorm(ctx, weight, bias, eps)
	return normalized
}

// RMSNorm holds the parameter tensor for an RMS-normalization step.
// Forward passes Weight to ml.Tensor.RMSNorm; there is no bias term.
// Conventionally Weight is the per-feature scale (confirm against the ml
// backend's RMSNorm contract).
//
// The gguf struct tag presumably names the tensor the field is populated
// from when the model is loaded — NOTE(review): verify against the loader.
type RMSNorm struct {
	Weight ml.Tensor `gguf:"weight"`
}

// Forward returns the result of RMS-normalizing t with this layer's weight.
// It is a thin wrapper: ctx, m.Weight and eps are handed unchanged to
// t.RMSNorm, and the tensor that call produces is returned directly. eps is
// presumably the small stabilizing constant of RMS norm — confirm with the
// ml backend.
func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
	scale := m.Weight
	normalized := t.RMSNorm(ctx, scale, eps)
	return normalized
}