Commit 7e986cfb authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

Avoid torch.square

Summary: Fix axis_angle conversions where I used torch.square which doesn't work with pytorch 1.4

Reviewed By: nikhilaravi

Differential Revision: D24451546

fbshipit-source-id: ba26f7dad5fa991f0a8f7d3d09ee7151163aecf4
parent c93c4dd7
......@@ -469,7 +469,7 @@ def axis_angle_to_quaternion(axis_angle):
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
......@@ -503,7 +503,7 @@ def quaternion_to_axis_angle(quaternions):
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
return quaternions[..., 1:] / sin_half_angles_over_angles
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment