Unverified commit 2dfb7441, authored by Jeffrey Morgan and committed by GitHub
Browse files

model: fix rotary embeddings for ministral 3 (#13432)

parent 1eb5e759
...@@ -29,24 +29,13 @@ type TextOptions struct { ...@@ -29,24 +29,13 @@ type TextOptions struct {
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor { func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
var ropeOpts []func(*rope.Options) var ropeOpts []func(*rope.Options)
if o.ropeType == "yarn" { if o.ropeType == "yarn" {
getMscale := func(scale, mscale float64) float64 {
if scale <= 1.0 {
return 1.0
}
return 0.1*mscale*math.Log(scale) + 1.0
}
var attnFactor float32
if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 { if o.ropeMscale != 0 && o.ropeMscaleAllDim != 0 {
attnFactor = float32(getMscale(float64(o.ropeScale), float64(o.ropeMscale)) / getMscale(float64(o.ropeScale), float64(o.ropeMscaleAllDim))) ropeOpts = append(ropeOpts, rope.WithAttentionFactor(1.0/float32(0.1*math.Log(float64(o.ropeScale))+1.0)))
} else {
attnFactor = float32(getMscale(float64(o.ropeScale), 1))
} }
ropeOpts = append(ropeOpts, ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings), rope.WithOriginalContextLength(o.ropeOrigPosEmbeddings),
rope.WithExtrapolationFactor(o.ropeExtrapolation), rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithAttentionFactor(attnFactor),
rope.WithBetaFast(o.ropeBetaFast), rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow), rope.WithBetaSlow(o.ropeBetaSlow),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment