Commit 6c3fe952 authored by Jeremy Reizenstein's avatar Jeremy Reizenstein Committed by Facebook GitHub Bot
Browse files

PLY TexturesVertex loading

Summary:
Include TexturesVertex colors when loading and saving Meshes to PLY files.

A couple of other improvements to the internals of ply_io, including using `None` instead of empty tensors for some missing data.

Reviewed By: gkioxari

Differential Revision: D27765260

fbshipit-source-id: b9857dc777c244b9d7d6643b608596d31435ecda
parent 097b0ef2
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds from pytorch3d.structures import Meshes, Pointclouds
from .pluggable_formats import ( from .pluggable_formats import (
...@@ -66,7 +67,7 @@ class _PlyElementType: ...@@ -66,7 +67,7 @@ class _PlyElementType:
def __init__(self, name: str, count: int): def __init__(self, name: str, count: int):
self.name = name self.name = name
self.count = count self.count = count
self.properties = [] self.properties: List[_Property] = []
def add_property( def add_property(
self, name: str, data_type: str, list_size_type: Optional[str] = None self, name: str, data_type: str, list_size_type: Optional[str] = None
...@@ -142,7 +143,7 @@ class _PlyHeader: ...@@ -142,7 +143,7 @@ class _PlyHeader:
if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]: if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
raise ValueError("Invalid file header.") raise ValueError("Invalid file header.")
seen_format = False seen_format = False
self.elements = [] self.elements: List[_PlyElementType] = []
self.obj_info = [] self.obj_info = []
while True: while True:
line = f.readline() line = f.readline()
...@@ -891,8 +892,8 @@ def _get_verts( ...@@ -891,8 +892,8 @@ def _get_verts(
def _load_ply( def _load_ply(
f, *, path_manager: PathManager, return_vertex_colors: bool = False f, *, path_manager: PathManager
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
""" """
Load the data from a .ply file. Load the data from a .ply file.
...@@ -903,12 +904,11 @@ def _load_ply( ...@@ -903,12 +904,11 @@ def _load_ply(
ply format, then a text stream is not supported. ply format, then a text stream is not supported.
It is easiest to use a binary stream in all cases. It is easiest to use a binary stream in all cases.
path_manager: PathManager for loading if f is a str. path_manager: PathManager for loading if f is a str.
return_vertex_colors: whether to return vertex colors.
Returns: Returns:
verts: FloatTensor of shape (V, 3). verts: FloatTensor of shape (V, 3).
faces: None or LongTensor of vertex indices, shape (F, 3). faces: None or LongTensor of vertex indices, shape (F, 3).
vertex_colors: None or FloatTensor of shape (V, 3), only if requested vertex_colors: None or FloatTensor of shape (V, 3).
""" """
header, elements = _load_ply_raw(f, path_manager=path_manager) header, elements = _load_ply_raw(f, path_manager=path_manager)
...@@ -950,16 +950,17 @@ def _load_ply( ...@@ -950,16 +950,17 @@ def _load_ply(
if faces is not None: if faces is not None:
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts.shape[0])
if return_vertex_colors: return verts, faces, vertex_colors
return verts, faces, vertex_colors
return verts, faces, None
def load_ply( def load_ply(
f, *, path_manager: Optional[PathManager] = None f, *, path_manager: Optional[PathManager] = None
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Load the data from a .ply file. Load the verts and faces from a .ply file.
Note that the preferred way to load data from such a file
is to use the IO.load_mesh and IO.load_pointcloud functions,
which can read more of the data.
Example .ply file format: Example .ply file format:
...@@ -1016,8 +1017,8 @@ def _save_ply( ...@@ -1016,8 +1017,8 @@ def _save_ply(
*, *,
verts: torch.Tensor, verts: torch.Tensor,
faces: Optional[torch.LongTensor], faces: Optional[torch.LongTensor],
verts_normals: torch.Tensor, verts_normals: Optional[torch.Tensor],
verts_colors: torch.Tensor, verts_colors: Optional[torch.Tensor],
ascii: bool, ascii: bool,
decimal_places: Optional[int] = None, decimal_places: Optional[int] = None,
) -> None: ) -> None:
...@@ -1029,16 +1030,16 @@ def _save_ply( ...@@ -1029,16 +1030,16 @@ def _save_ply(
verts: FloatTensor of shape (V, 3) giving vertex coordinates. verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces. faces: LongTensor of shape (F, 3) giving faces.
verts_normals: FloatTensor of shape (V, 3) giving vertex normals. verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
ascii: (bool) whether to use the ascii ply format. ascii: (bool) whether to use the ascii ply format.
decimal_places: Number of decimal places for saving if ascii=True. decimal_places: Number of decimal places for saving if ascii=True.
""" """
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3) assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
if faces is not None: assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3) assert verts_normals is None or (
assert not len(verts_normals) or (
verts_normals.dim() == 2 and verts_normals.size(1) == 3 verts_normals.dim() == 2 and verts_normals.size(1) == 3
) )
assert not len(verts_colors) or ( assert verts_colors is None or (
verts_colors.dim() == 2 and verts_colors.size(1) == 3 verts_colors.dim() == 2 and verts_colors.size(1) == 3
) )
...@@ -1052,11 +1053,11 @@ def _save_ply( ...@@ -1052,11 +1053,11 @@ def _save_ply(
f.write(b"property float x\n") f.write(b"property float x\n")
f.write(b"property float y\n") f.write(b"property float y\n")
f.write(b"property float z\n") f.write(b"property float z\n")
if verts_normals.numel() > 0: if verts_normals is not None:
f.write(b"property float nx\n") f.write(b"property float nx\n")
f.write(b"property float ny\n") f.write(b"property float ny\n")
f.write(b"property float nz\n") f.write(b"property float nz\n")
if verts_colors.numel() > 0: if verts_colors is not None:
f.write(b"property float red\n") f.write(b"property float red\n")
f.write(b"property float green\n") f.write(b"property float green\n")
f.write(b"property float blue\n") f.write(b"property float blue\n")
...@@ -1069,7 +1070,13 @@ def _save_ply( ...@@ -1069,7 +1070,13 @@ def _save_ply(
warnings.warn("Empty 'verts' provided") warnings.warn("Empty 'verts' provided")
return return
vert_data = torch.cat((verts, verts_normals, verts_colors), dim=1).detach().numpy() verts_tensors = [verts]
if verts_normals is not None:
verts_tensors.append(verts_normals)
if verts_colors is not None:
verts_tensors.append(verts_colors)
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
if ascii: if ascii:
if decimal_places is None: if decimal_places is None:
float_str = "%f" float_str = "%f"
...@@ -1085,7 +1092,7 @@ def _save_ply( ...@@ -1085,7 +1092,7 @@ def _save_ply(
vert_data.tofile(f) vert_data.tofile(f)
if faces is not None: if faces is not None:
faces_array = faces.detach().numpy() faces_array = faces.detach().cpu().numpy()
_check_faces_indices(faces, max_index=verts.shape[0]) _check_faces_indices(faces, max_index=verts.shape[0])
...@@ -1125,12 +1132,6 @@ def save_ply( ...@@ -1125,12 +1132,6 @@ def save_ply(
""" """
verts_normals = (
torch.tensor([], dtype=torch.float32, device=verts.device)
if verts_normals is None
else verts_normals
)
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3): if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
...@@ -1143,16 +1144,18 @@ def save_ply( ...@@ -1143,16 +1144,18 @@ def save_ply(
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)." message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message) raise ValueError(message)
if len(verts_normals) and not ( if (
verts_normals.dim() == 2 verts_normals is not None
and verts_normals.size(1) == 3 and len(verts_normals)
and verts_normals.size(0) == verts.size(0) and not (
verts_normals.dim() == 2
and verts_normals.size(1) == 3
and verts_normals.size(0) == verts.size(0)
)
): ):
message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)." message = "Argument 'verts_normals' should either be empty or of shape (num_verts, 3)."
raise ValueError(message) raise ValueError(message)
verts_colors = torch.FloatTensor([])
if path_manager is None: if path_manager is None:
path_manager = PathManager() path_manager = PathManager()
with _open_file(f, path_manager, "wb") as f: with _open_file(f, path_manager, "wb") as f:
...@@ -1161,7 +1164,7 @@ def save_ply( ...@@ -1161,7 +1164,7 @@ def save_ply(
verts=verts, verts=verts,
faces=faces, faces=faces,
verts_normals=verts_normals, verts_normals=verts_normals,
verts_colors=verts_colors, verts_colors=None,
ascii=ascii, ascii=ascii,
decimal_places=decimal_places, decimal_places=decimal_places,
) )
...@@ -1182,8 +1185,19 @@ class MeshPlyFormat(MeshFormatInterpreter): ...@@ -1182,8 +1185,19 @@ class MeshPlyFormat(MeshFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces = load_ply(f=path, path_manager=path_manager) verts, faces, verts_colors = _load_ply(f=path, path_manager=path_manager)
mesh = Meshes(verts=[verts.to(device)], faces=[faces.to(device)]) if faces is None:
faces = torch.zeros(0, 3, dtype=torch.int64)
textures = None
if include_textures and verts_colors is not None:
textures = TexturesVertex([verts_colors.to(device)])
mesh = Meshes(
verts=[verts.to(device)],
faces=[faces.to(device)],
textures=textures,
)
return mesh return mesh
def save( def save(
...@@ -1201,14 +1215,30 @@ class MeshPlyFormat(MeshFormatInterpreter): ...@@ -1201,14 +1215,30 @@ class MeshPlyFormat(MeshFormatInterpreter):
# TODO: normals are not saved. We only want to save them if they already exist. # TODO: normals are not saved. We only want to save them if they already exist.
verts = data.verts_list()[0] verts = data.verts_list()[0]
faces = data.faces_list()[0] faces = data.faces_list()[0]
save_ply(
f=path, if isinstance(data.textures, TexturesVertex):
verts=verts, mesh_verts_colors = data.textures.verts_features_list()[0]
faces=faces, n_colors = mesh_verts_colors.shape[1]
ascii=binary is False, if n_colors == 3:
decimal_places=decimal_places, verts_colors = mesh_verts_colors
path_manager=path_manager, else:
) warnings.warn(
f"Texture will not be saved as it has {n_colors} colors, not 3."
)
verts_colors = None
else:
verts_colors = None
with _open_file(path, path_manager, "wb") as f:
_save_ply(
f=f,
verts=verts,
faces=faces,
verts_colors=verts_colors,
verts_normals=None,
ascii=binary is False,
decimal_places=decimal_places,
)
return True return True
...@@ -1226,14 +1256,12 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): ...@@ -1226,14 +1256,12 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
if not endswith(path, self.known_suffixes): if not endswith(path, self.known_suffixes):
return None return None
verts, faces, features = _load_ply( verts, faces, features = _load_ply(f=path, path_manager=path_manager)
f=path, path_manager=path_manager, return_vertex_colors=True
)
verts = verts.to(device) verts = verts.to(device)
if features is None: if features is not None:
pointcloud = Pointclouds(points=[verts]) features = [features.to(device)]
else:
pointcloud = Pointclouds(points=[verts], features=[features.to(device)]) pointcloud = Pointclouds(points=[verts], features=features)
return pointcloud return pointcloud
def save( def save(
...@@ -1249,13 +1277,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter): ...@@ -1249,13 +1277,14 @@ class PointcloudPlyFormat(PointcloudFormatInterpreter):
return False return False
points = data.points_list()[0] points = data.points_list()[0]
features = data.features_list()[0] features = data.features_packed()
with _open_file(path, path_manager, "wb") as f: with _open_file(path, path_manager, "wb") as f:
_save_ply( _save_ply(
f=f, f=f,
verts=points, verts=points,
verts_colors=features, verts_colors=features,
verts_normals=torch.FloatTensor([]), verts_normals=None,
faces=None, faces=None,
ascii=binary is False, ascii=binary is False,
decimal_places=decimal_places, decimal_places=decimal_places,
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import itertools
import struct import struct
import unittest import unittest
from io import BytesIO, StringIO from io import BytesIO, StringIO
...@@ -12,7 +13,8 @@ from common_testing import TestCaseMixin ...@@ -12,7 +13,8 @@ from common_testing import TestCaseMixin
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.io import IO from pytorch3d.io import IO
from pytorch3d.io.ply_io import load_ply, save_ply from pytorch3d.io.ply_io import load_ply, save_ply
from pytorch3d.structures import Pointclouds from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils import torus from pytorch3d.utils import torus
...@@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase): ...@@ -189,6 +191,57 @@ class TestMeshPlyIO(TestCaseMixin, unittest.TestCase):
): ):
io.load_mesh(f3.name) io.load_mesh(f3.name)
def test_save_too_many_colors(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
vert_colors = torch.rand((4, 7))
texture_with_seven_colors = TexturesVertex(verts_features=[vert_colors])
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture_with_seven_colors,
)
io = IO()
msg = "Texture will not be saved as it has 7 colors, not 3."
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
with self.assertWarnsRegex(UserWarning, msg):
io.save_mesh(mesh.cuda(), f.name)
def test_save_load_meshes(self):
verts = torch.tensor(
[[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
vert_colors = torch.rand_like(verts)
texture = TexturesVertex(verts_features=[vert_colors])
for do_textures in itertools.product([True, False]):
mesh = Meshes(
verts=[verts],
faces=[faces],
textures=texture if do_textures else None,
)
device = torch.device("cuda:0")
io = IO()
with NamedTemporaryFile(mode="w", suffix=".ply") as f:
io.save_mesh(mesh.cuda(), f.name)
f.flush()
mesh2 = io.load_mesh(f.name, device=device)
self.assertEqual(mesh2.device, device)
mesh2 = mesh2.cpu()
self.assertClose(mesh2.verts_padded(), mesh.verts_padded())
self.assertClose(mesh2.faces_padded(), mesh.faces_padded())
if do_textures:
self.assertIsInstance(mesh2.textures, TexturesVertex)
self.assertClose(mesh2.textures.verts_features_list()[0], vert_colors)
else:
self.assertIsNone(mesh2.textures)
def test_save_ply_invalid_shapes(self): def test_save_ply_invalid_shapes(self):
# Invalid vertices shape # Invalid vertices shape
with self.assertRaises(ValueError) as error: with self.assertRaises(ValueError) as error:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment