Commit 253aff64 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Further simplify rotation multiplication code

parent 4e9f7a9d
......@@ -50,10 +50,14 @@ def rot_matmul(
dim=-1,
)
row_1 = row_mul(0)
row_2 = row_mul(1)
row_3 = row_mul(2)
return torch.stack([row_1, row_2, row_3], dim=-2)
return torch.stack(
[
row_mul(0),
row_mul(1),
row_mul(2),
],
dim=-2
)
def rot_vec_mul(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment