Commit 9d1de41b authored by Michael Yang
Browse files

clamp glu/linear

parent 9679520e
...@@ -189,11 +189,16 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts * ...@@ -189,11 +189,16 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2)) hiddenStates = hiddenStates.Reshape(ctx, 2, hiddenStates.Dim(0)/2, hiddenStates.Dim(1), hiddenStates.Dim(2))
dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)} dimStride := []int{hiddenStates.Dim(0) / 2, hiddenStates.Stride(1), hiddenStates.Dim(1), hiddenStates.Stride(2), hiddenStates.Dim(2), hiddenStates.Stride(3), hiddenStates.Dim(3)}
hiddenStates = hiddenStates.View(ctx, 0, dimStride...).
Contiguous(ctx).
QuickGELU(ctx).
Mul(ctx, hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...).Add(ctx, one))
glu := hiddenStates.View(ctx, 0, dimStride...)
glu = glu.Contiguous(ctx)
glu = glu.Clamp(ctx, float32(math.Inf(-1)), 7.0)
glu = glu.QuickGELU(ctx)
linear := hiddenStates.View(ctx, hiddenStates.Stride(0), dimStride...)
linear = linear.Clamp(ctx, -7.0, 7.0)
hiddenStates = glu.Mul(ctx, linear.Add(ctx, one))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)) hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3))
experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment