Commit 0c02ae90 authored by Roman Shapovalov's avatar Roman Shapovalov Committed by Facebook GitHub Bot
Browse files

Adding utility methods to TensorProperties

Summary:
Context: in the code we are releasing with CO3D dataset, we use  `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build.

Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed.

Reviewed By: bottler

Differential Revision: D29659529

fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
parent fa44a055
...@@ -15,7 +15,8 @@ Device = Union[str, torch.device] ...@@ -15,7 +15,8 @@ Device = Union[str, torch.device]
def make_device(device: Device) -> torch.device: def make_device(device: Device) -> torch.device:
""" """
Makes an actual torch.device object from the device specified as Makes an actual torch.device object from the device specified as
either a string or torch.device object. either a string or torch.device object. If the device is `cuda` without
a specific index, the index of the current device is assigned.
Args: Args:
device: Device (as str or torch.device) device: Device (as str or torch.device)
...@@ -23,7 +24,12 @@ def make_device(device: Device) -> torch.device: ...@@ -23,7 +24,12 @@ def make_device(device: Device) -> torch.device:
Returns: Returns:
A matching torch.device object A matching torch.device object
""" """
return torch.device(device) if isinstance(device, str) else device device = torch.device(device) if isinstance(device, str) else device
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
# If cuda but with no index, then the current cuda device is indicated.
# In that case, we fix to that device
device = torch.device(f"cuda:{torch.cuda.current_device()}")
return device
def get_device(x, device: Optional[Device] = None) -> torch.device: def get_device(x, device: Optional[Device] = None) -> torch.device:
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
import copy import copy
import inspect import inspect
import warnings import warnings
from typing import Any, Union from typing import Any, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -174,6 +174,12 @@ class TensorProperties(nn.Module): ...@@ -174,6 +174,12 @@ class TensorProperties(nn.Module):
setattr(self, k, v.to(device_)) setattr(self, k, v.to(device_))
return self return self
def cpu(self) -> "TensorProperties":
return self.to("cpu")
def cuda(self, device: Optional[int] = None) -> "TensorProperties":
return self.to(f"cuda:{device}" if device is not None else "cuda")
def clone(self, other) -> "TensorProperties": def clone(self, other) -> "TensorProperties":
""" """
Update the tensor properties of other with the cloned properties of self. Update the tensor properties of other with the cloned properties of self.
......
...@@ -709,9 +709,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): ...@@ -709,9 +709,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(cpu_device, mesh.device) self.assertEqual(cpu_device, mesh.device)
self.assertIs(mesh, converted_mesh) self.assertIs(mesh, converted_mesh)
cuda_device = torch.device("cuda") cuda_device = torch.device("cuda:0")
converted_mesh = mesh.to("cuda") converted_mesh = mesh.to("cuda:0")
self.assertEqual(cuda_device, converted_mesh.device) self.assertEqual(cuda_device, converted_mesh.device)
self.assertEqual(cpu_device, mesh.device) self.assertEqual(cpu_device, mesh.device)
self.assertIsNot(mesh, converted_mesh) self.assertIsNot(mesh, converted_mesh)
......
...@@ -39,7 +39,17 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase): ...@@ -39,7 +39,17 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0)) example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))
device = torch.device("cuda:0") device = torch.device("cuda:0")
new_example = example.to(device=device) new_example = example.to(device=device)
self.assertTrue(new_example.device == device) self.assertEqual(new_example.device, device)
example_cpu = example.cpu()
self.assertEqual(example_cpu.device, torch.device("cpu"))
example_gpu = example.cuda()
self.assertEqual(example_gpu.device.type, "cuda")
self.assertIsNotNone(example_gpu.device.index)
example_gpu1 = example.cuda(1)
self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
def test_clone(self): def test_clone(self):
# Check clone method # Check clone method
......
...@@ -22,7 +22,7 @@ from pytorch3d.structures.meshes import Meshes ...@@ -22,7 +22,7 @@ from pytorch3d.structures.meshes import Meshes
class TestShader(TestCaseMixin, unittest.TestCase): class TestShader(TestCaseMixin, unittest.TestCase):
def test_to(self): def test_to(self):
cpu_device = torch.device("cpu") cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda") cuda_device = torch.device("cuda:0")
R, T = look_at_view_transform() R, T = look_at_view_transform()
......
...@@ -50,9 +50,9 @@ class TestTransform(TestCaseMixin, unittest.TestCase): ...@@ -50,9 +50,9 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
self.assertEqual(torch.float32, t.dtype) self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cpu_t) self.assertIsNot(t, cpu_t)
cuda_device = torch.device("cuda") cuda_device = torch.device("cuda:0")
cuda_t = t.to("cuda") cuda_t = t.to("cuda:0")
self.assertEqual(cuda_device, cuda_t.device) self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device) self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype) self.assertEqual(torch.float32, cuda_t.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment