Commit 4e9f7a9d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Simplify rotation multiplication code

parent 1eb4fcdd
......@@ -34,49 +34,25 @@ def rot_matmul(
Returns:
The product ab
"""
row_1 = torch.stack(
[
a[..., 0, 0] * b[..., 0, 0]
+ a[..., 0, 1] * b[..., 1, 0]
+ a[..., 0, 2] * b[..., 2, 0],
a[..., 0, 0] * b[..., 0, 1]
+ a[..., 0, 1] * b[..., 1, 1]
+ a[..., 0, 2] * b[..., 2, 1],
a[..., 0, 0] * b[..., 0, 2]
+ a[..., 0, 1] * b[..., 1, 2]
+ a[..., 0, 2] * b[..., 2, 2],
],
dim=-1,
)
row_2 = torch.stack(
[
a[..., 1, 0] * b[..., 0, 0]
+ a[..., 1, 1] * b[..., 1, 0]
+ a[..., 1, 2] * b[..., 2, 0],
a[..., 1, 0] * b[..., 0, 1]
+ a[..., 1, 1] * b[..., 1, 1]
+ a[..., 1, 2] * b[..., 2, 1],
a[..., 1, 0] * b[..., 0, 2]
+ a[..., 1, 1] * b[..., 1, 2]
+ a[..., 1, 2] * b[..., 2, 2],
],
dim=-1,
)
row_3 = torch.stack(
[
a[..., 2, 0] * b[..., 0, 0]
+ a[..., 2, 1] * b[..., 1, 0]
+ a[..., 2, 2] * b[..., 2, 0],
a[..., 2, 0] * b[..., 0, 1]
+ a[..., 2, 1] * b[..., 1, 1]
+ a[..., 2, 2] * b[..., 2, 1],
a[..., 2, 0] * b[..., 0, 2]
+ a[..., 2, 1] * b[..., 1, 2]
+ a[..., 2, 2] * b[..., 2, 2],
],
dim=-1,
)
def row_mul(i):
return torch.stack(
[
a[..., i, 0] * b[..., 0, 0]
+ a[..., i, 1] * b[..., 1, 0]
+ a[..., i, 2] * b[..., 2, 0],
a[..., i, 0] * b[..., 0, 1]
+ a[..., i, 1] * b[..., 1, 1]
+ a[..., i, 2] * b[..., 2, 1],
a[..., i, 0] * b[..., 0, 2]
+ a[..., i, 1] * b[..., 1, 2]
+ a[..., i, 2] * b[..., 2, 2],
],
dim=-1,
)
row_1 = row_mul(0)
row_2 = row_mul(1)
row_3 = row_mul(2)
return torch.stack([row_1, row_2, row_3], dim=-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment