Commit 7e986cfb authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

Avoid torch.square

Summary: Fix axis_angle conversions where I used torch.square which doesn't work with pytorch 1.4

Reviewed By: nikhilaravi

Differential Revision: D24451546

fbshipit-source-id: ba26f7dad5fa991f0a8f7d3d09ee7151163aecf4
parent c93c4dd7
...@@ -469,7 +469,7 @@ def axis_angle_to_quaternion(axis_angle): ...@@ -469,7 +469,7 @@ def axis_angle_to_quaternion(axis_angle):
# for x small, sin(x/2) is about x/2 - (x/2)^3/6 # for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48 # so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = ( sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48 0.5 - (angles[small_angles] * angles[small_angles]) / 48
) )
quaternions = torch.cat( quaternions = torch.cat(
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
...@@ -503,7 +503,7 @@ def quaternion_to_axis_angle(quaternions): ...@@ -503,7 +503,7 @@ def quaternion_to_axis_angle(quaternions):
# for x small, sin(x/2) is about x/2 - (x/2)^3/6 # for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48 # so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = ( sin_half_angles_over_angles[small_angles] = (
0.5 - torch.square(angles[small_angles]) / 48 0.5 - (angles[small_angles] * angles[small_angles]) / 48
) )
return quaternions[..., 1:] / sin_half_angles_over_angles return quaternions[..., 1:] / sin_half_angles_over_angles
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment