Commit cd41d73f authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add shape assertion to rotation code

parent 253aff64
......@@ -74,9 +74,7 @@ def rot_vec_mul(
Returns:
[*, 3] rotated coordinates
"""
x = t[..., 0]
y = t[..., 1]
z = t[..., 2]
x, y, z = torch.unbind(t, dim=-1)
return torch.stack(
[
r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment