Commit 4e9f7a9d authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Simplify rotation multiplication code

parent 1eb4fcdd
...@@ -34,49 +34,25 @@ def rot_matmul( ...@@ -34,49 +34,25 @@ def rot_matmul(
Returns: Returns:
The product ab The product ab
""" """
row_1 = torch.stack( def row_mul(i):
[ return torch.stack(
a[..., 0, 0] * b[..., 0, 0] [
+ a[..., 0, 1] * b[..., 1, 0] a[..., i, 0] * b[..., 0, 0]
+ a[..., 0, 2] * b[..., 2, 0], + a[..., i, 1] * b[..., 1, 0]
a[..., 0, 0] * b[..., 0, 1] + a[..., i, 2] * b[..., 2, 0],
+ a[..., 0, 1] * b[..., 1, 1] a[..., i, 0] * b[..., 0, 1]
+ a[..., 0, 2] * b[..., 2, 1], + a[..., i, 1] * b[..., 1, 1]
a[..., 0, 0] * b[..., 0, 2] + a[..., i, 2] * b[..., 2, 1],
+ a[..., 0, 1] * b[..., 1, 2] a[..., i, 0] * b[..., 0, 2]
+ a[..., 0, 2] * b[..., 2, 2], + a[..., i, 1] * b[..., 1, 2]
], + a[..., i, 2] * b[..., 2, 2],
dim=-1, ],
) dim=-1,
row_2 = torch.stack( )
[
a[..., 1, 0] * b[..., 0, 0]
+ a[..., 1, 1] * b[..., 1, 0]
+ a[..., 1, 2] * b[..., 2, 0],
a[..., 1, 0] * b[..., 0, 1]
+ a[..., 1, 1] * b[..., 1, 1]
+ a[..., 1, 2] * b[..., 2, 1],
a[..., 1, 0] * b[..., 0, 2]
+ a[..., 1, 1] * b[..., 1, 2]
+ a[..., 1, 2] * b[..., 2, 2],
],
dim=-1,
)
row_3 = torch.stack(
[
a[..., 2, 0] * b[..., 0, 0]
+ a[..., 2, 1] * b[..., 1, 0]
+ a[..., 2, 2] * b[..., 2, 0],
a[..., 2, 0] * b[..., 0, 1]
+ a[..., 2, 1] * b[..., 1, 1]
+ a[..., 2, 2] * b[..., 2, 1],
a[..., 2, 0] * b[..., 0, 2]
+ a[..., 2, 1] * b[..., 1, 2]
+ a[..., 2, 2] * b[..., 2, 2],
],
dim=-1,
)
row_1 = row_mul(0)
row_2 = row_mul(1)
row_3 = row_mul(2)
return torch.stack([row_1, row_2, row_3], dim=-2) return torch.stack([row_1, row_2, row_3], dim=-2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment