"tools/vscode:/vscode.git/clone" did not exist on "39b6343dcf6e5aa23bbfe1121113d0771d2bc8a2"
Commit 603ceefa authored by Michael Yang's avatar Michael Yang Committed by Michael Yang
Browse files

refactor rope

change to a flatter directory structure and group the options with the
function

update models to call rope in one place
parent e082d60a
......@@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
key := sa.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
value := sa.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment