Commit 253aff64 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Further simplify rotation multiplication code

parent 4e9f7a9d
...@@ -50,10 +50,14 @@ def rot_matmul( ...@@ -50,10 +50,14 @@ def rot_matmul(
dim=-1, dim=-1,
) )
row_1 = row_mul(0) return torch.stack(
row_2 = row_mul(1) [
row_3 = row_mul(2) row_mul(0),
return torch.stack([row_1, row_2, row_3], dim=-2) row_mul(1),
row_mul(2),
],
dim=-2
)
def rot_vec_mul( def rot_vec_mul(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment