"vscode:/vscode.git/clone" did not exist on "e36dfaa8de2d9a9fa67eeed5ce64fd5949916c99"
Commit d773b7d6 authored by Jesse Gross's avatar Jesse Gross Committed by Jesse Gross
Browse files

backend: API to support full precision matmul

Most tensor backends try to optimize performance by using a lower
precision for matmuls. However, some operations (such as kq) on
some models are sensitive to this and require full precision.
parent 4d4463b2
......@@ -66,6 +66,7 @@ type Tensor interface {
Add(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
Softmax(ctx Context) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
......
......@@ -421,6 +421,15 @@ func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
mul := C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t)
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
t: mul,
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {
......
......@@ -80,7 +80,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := k.Mulmat(ctx, q)
kq := k.MulmatFullPrec(ctx, q)
kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
kq = kq.Softmax(ctx)
......
......@@ -37,7 +37,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, mas
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scores := key.Mulmat(ctx, query)
scores := key.MulmatFullPrec(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
if mask != nil {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment