Commit 46cb5aaa authored by Jiali Duan, committed by Facebook GitHub Bot

Omit _check_valid_rotation_matrix by default

Summary:
According to the profiler trace in D40326775, _check_valid_rotation_matrix is slow because the aten::all_close operation and _safe_det_3x3 are bottlenecks. This change disables the check by default; it runs only when the environment variable PYTORCH3D_CHECK_ROTATION_MATRICES is set to 1.
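
For callers who still need the validation, here is a minimal sketch of opting back in. It assumes, as in the diff below, that the flag is read each time a `Rotate` transform is constructed:

```python
import os

from pytorch3d.transforms import Transform3d, random_rotations

# Opt back into the (slower) rotation-matrix validation; it is now off by default.
os.environ["PYTORCH3D_CHECK_ROTATION_MATRICES"] = "1"

R = random_rotations(4)      # (4, 3, 3) batch of valid rotation matrices
t = Transform3d().rotate(R)  # _check_valid_rotation_matrix runs because the flag is "1"
```

Setting the variable in the shell (e.g. `PYTORCH3D_CHECK_ROTATION_MATRICES=1 python your_script.py`) works the same way.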

Comparison after applying the change:
```
Function              get_world_to_view (ms)   transform_points (ms)   specular (ms)
before                12.751                   18.577                  21.384
after (% of before)   4.432 (34.7%)            9.248 (49.8%)           11.507 (53.8%)
```

Profiling trace:
https://pxl.cl/2h687
More details in https://docs.google.com/document/d/1kfhEQfpeQToikr5OH9ZssM39CskxWoJ2p8DO5-t6eWk/edit?usp=sharing

Reviewed By: kjchalup

Differential Revision: D40442503

fbshipit-source-id: 954b58de47de235c9d93af441643c22868b547d0
parent 8339cf26
pytorch3d/transforms/transform3d.py
```diff
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import math
+import os
 import warnings
 from typing import List, Optional, Union
@@ -636,7 +637,10 @@ class Rotate(Transform3d):
             msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
             raise ValueError(msg % repr(R.shape))
         R = R.to(device=device_, dtype=dtype)
-        _check_valid_rotation_matrix(R, tol=orthogonal_tol)
+        if os.environ.get("PYTORCH3D_CHECK_ROTATION_MATRICES", "0") == "1":
+            # Note: aten::all_close in the check is computationally slow, so we
+            # only run the check when PYTORCH3D_CHECK_ROTATION_MATRICES is on.
+            _check_valid_rotation_matrix(R, tol=orthogonal_tol)
         N = R.shape[0]
         mat = torch.eye(4, dtype=dtype, device=device_)
         mat = mat.view(1, 4, 4).repeat(N, 1, 1)
```
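
For context, a rough sketch of the kind of validation that is now gated (illustrative only, not the library's verbatim implementation; `torch.linalg.det` stands in for `_safe_det_3x3`): the check confirms each 3x3 matrix is orthonormal, which is the `aten::all_close` cost, and has determinant close to 1.

```python
import torch


def check_valid_rotation_sketch(R: torch.Tensor, tol: float = 1e-7) -> bool:
    """Illustrative stand-in for _check_valid_rotation_matrix; R has shape (N, 3, 3)."""
    N = R.shape[0]
    eye = torch.eye(3, dtype=R.dtype, device=R.device).expand(N, 3, 3)
    # Orthonormality: R @ R^T should equal the identity. This allclose is the
    # aten::all_close cost called out in the summary.
    orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
    # Proper rotation (no reflection or scaling): det(R) should be +1. The real
    # check uses _safe_det_3x3; torch.linalg.det is used here for the sketch.
    dets = torch.linalg.det(R)
    no_distortion = torch.allclose(dets, torch.ones_like(dets))
    return orthogonal and no_distortion
```

`torch.allclose` returns a Python bool, so on GPU inputs it forces a device-to-host sync in the middle of the hot path, which is consistent with the bottleneck reported in the summary.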
tests/test_transforms.py
```diff
@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 import math
+import os
 import unittest
+from unittest import mock
 import torch
 from pytorch3d.transforms import random_rotations
@@ -191,7 +192,25 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(points_out, points_out_expected))
         self.assertTrue(torch.allclose(normals_out, normals_out_expected))
 
-    def test_rotate(self):
+    @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "1"}, clear=True)
+    def test_rotate_check_rot_valid_on(self):
+        R = so3_exp_map(torch.randn((1, 3)))
+        t = Transform3d().rotate(R)
+        points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
+            1, 3, 3
+        )
+        normals = torch.tensor(
+            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
+        ).view(1, 3, 3)
+        points_out = t.transform_points(points)
+        normals_out = t.transform_normals(normals)
+        points_out_expected = torch.bmm(points, R)
+        normals_out_expected = torch.bmm(normals, R)
+        self.assertTrue(torch.allclose(points_out, points_out_expected))
+        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
+
+    @mock.patch.dict(os.environ, {"PYTORCH3D_CHECK_ROTATION_MATRICES": "0"}, clear=True)
+    def test_rotate_check_rot_valid_off(self):
         R = so3_exp_map(torch.randn((1, 3)))
         t = Transform3d().rotate(R)
         points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
```