Commit 29417d1f authored by RWL's avatar RWL Committed by Facebook GitHub Bot
Browse files

NaN (divide by zero) fix for issue #561 and #790 (#891)

Summary:
https://github.com/facebookresearch/pytorch3d/issues/561
https://github.com/facebookresearch/pytorch3d/issues/790
Divide by zero fix (NaN fix).  When perspective_correct=True, BarycentricPerspectiveCorrectionForward and BarycentricPerspectiveCorrectionBackward in ../csrc/utils/geometry_utils.cuh are called.  The denominator (denom) values should not be allowed to go to zero. I'm able to resolve this issue locally with this PR and submit it for the team's review.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/891

Reviewed By: patricklabatut

Differential Revision: D31829695

Pulled By: bottler

fbshipit-source-id: a3517b8362f6e60d48c35731258d8ce261b1d912
parent 57b9c729
...@@ -177,7 +177,7 @@ __device__ inline float3 BarycentricPerspectiveCorrectionForward( ...@@ -177,7 +177,7 @@ __device__ inline float3 BarycentricPerspectiveCorrectionForward(
const float w0_top = bary.x * z1 * z2; const float w0_top = bary.x * z1 * z2;
const float w1_top = z0 * bary.y * z2; const float w1_top = z0 * bary.y * z2;
const float w2_top = z0 * z1 * bary.z; const float w2_top = z0 * z1 * bary.z;
const float denom = w0_top + w1_top + w2_top; const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon);
const float w0 = w0_top / denom; const float w0 = w0_top / denom;
const float w1 = w1_top / denom; const float w1 = w1_top / denom;
const float w2 = w2_top / denom; const float w2 = w2_top / denom;
...@@ -208,7 +208,7 @@ BarycentricPerspectiveCorrectionBackward( ...@@ -208,7 +208,7 @@ BarycentricPerspectiveCorrectionBackward(
const float w0_top = bary.x * z1 * z2; const float w0_top = bary.x * z1 * z2;
const float w1_top = z0 * bary.y * z2; const float w1_top = z0 * bary.y * z2;
const float w2_top = z0 * z1 * bary.z; const float w2_top = z0 * z1 * bary.z;
const float denom = w0_top + w1_top + w2_top; const float denom = fmaxf(w0_top + w1_top + w2_top, kEpsilon);
// Now do backward pass // Now do backward pass
const float grad_denom_top = const float grad_denom_top =
......
...@@ -198,7 +198,7 @@ inline vec3<T> BarycentricPerspectiveCorrectionForward( ...@@ -198,7 +198,7 @@ inline vec3<T> BarycentricPerspectiveCorrectionForward(
const T w0_top = bary.x * z1 * z2; const T w0_top = bary.x * z1 * z2;
const T w1_top = bary.y * z0 * z2; const T w1_top = bary.y * z0 * z2;
const T w2_top = bary.z * z0 * z1; const T w2_top = bary.z * z0 * z1;
const T denom = w0_top + w1_top + w2_top; const T denom = std::max<T>(w0_top + w1_top + w2_top, kEpsilon);
const T w0 = w0_top / denom; const T w0 = w0_top / denom;
const T w1 = w1_top / denom; const T w1 = w1_top / denom;
const T w2 = w2_top / denom; const T w2 = w2_top / denom;
...@@ -229,7 +229,7 @@ inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward( ...@@ -229,7 +229,7 @@ inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
const T w0_top = bary.x * z1 * z2; const T w0_top = bary.x * z1 * z2;
const T w1_top = bary.y * z0 * z2; const T w1_top = bary.y * z0 * z2;
const T w2_top = bary.z * z0 * z1; const T w2_top = bary.z * z0 * z1;
const T denom = w0_top + w1_top + w2_top; const T denom = std::max<T>(w0_top + w1_top + w2_top, kEpsilon);
// Now do backward pass // Now do backward pass
const T grad_denom_top = const T grad_denom_top =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment