Commit d57daa6f authored by Patrick Labatut's avatar Patrick Labatut Committed by Facebook GitHub Bot
Browse files

Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
parent eb512ffd
...@@ -8,6 +8,7 @@ TODO: python 3.8 when pytorch 1.4. ...@@ -8,6 +8,7 @@ TODO: python 3.8 when pytorch 1.4.
""" """
import os.path import os.path
import jinja2 import jinja2
import yaml import yaml
...@@ -45,9 +46,7 @@ def workflow_pair( ...@@ -45,9 +46,7 @@ def workflow_pair(
): ):
w = [] w = []
base_workflow_name = ( base_workflow_name = f"{prefix}binary_linux_{btype}_py{python_version}_{cu_version}"
f"{prefix}binary_linux_{btype}_py{python_version}_{cu_version}"
)
w.append( w.append(
generate_base_workflow( generate_base_workflow(
...@@ -94,9 +93,7 @@ def generate_base_workflow( ...@@ -94,9 +93,7 @@ def generate_base_workflow(
return {f"binary_linux_{btype}": d} return {f"binary_linux_{btype}": d}
def generate_upload_workflow( def generate_upload_workflow(*, base_workflow_name, btype, cu_version, filter_branch):
*, base_workflow_name, btype, cu_version, filter_branch
):
d = { d = {
"name": f"{base_workflow_name}_upload", "name": f"{base_workflow_name}_upload",
"context": "org-member", "context": "org-member",
......
...@@ -22,6 +22,7 @@ from recommonmark.states import DummyStateMachine ...@@ -22,6 +22,7 @@ from recommonmark.states import DummyStateMachine
from sphinx.builders.html import StandaloneHTMLBuilder from sphinx.builders.html import StandaloneHTMLBuilder
from sphinx.ext.autodoc import between from sphinx.ext.autodoc import between
# Monkey patch to fix recommonmark 0.4 doc reference issues. # Monkey patch to fix recommonmark 0.4 doc reference issues.
orig_run_role = DummyStateMachine.run_role orig_run_role = DummyStateMachine.run_role
...@@ -154,9 +155,7 @@ html_theme_options = {"collapse_navigation": True} ...@@ -154,9 +155,7 @@ html_theme_options = {"collapse_navigation": True}
def url_resolver(url): def url_resolver(url):
if ".html" not in url: if ".html" not in url:
url = url.replace("../", "") url = url.replace("../", "")
return ( return "https://github.com/facebookresearch/pytorch3d/blob/master/" + url
"https://github.com/facebookresearch/pytorch3d/blob/master/" + url
)
else: else:
if DEPLOY: if DEPLOY:
return "http://pytorch3d.readthedocs.io/" + url return "http://pytorch3d.readthedocs.io/" + url
...@@ -188,9 +187,7 @@ def setup(app): ...@@ -188,9 +187,7 @@ def setup(app):
# Register a sphinx.ext.autodoc.between listener to ignore everything # Register a sphinx.ext.autodoc.between listener to ignore everything
# between lines that contain the word IGNORE # between lines that contain the word IGNORE
app.connect( app.connect("autodoc-process-docstring", between("^.*IGNORE.*$", exclude=True))
"autodoc-process-docstring", between("^.*IGNORE.*$", exclude=True)
)
app.add_transform(AutoStructify) app.add_transform(AutoStructify)
return app return app
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Render a coloured point cloud\n", "# Render a colored point cloud\n",
"\n", "\n",
"This tutorial shows how to:\n", "This tutorial shows how to:\n",
"- set up a renderer \n", "- set up a renderer \n",
...@@ -84,7 +84,7 @@ ...@@ -84,7 +84,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"### Load a point cloud and corresponding colours\n", "### Load a point cloud and corresponding colors\n",
"\n", "\n",
"Load a `.ply` file and create a **Point Cloud** object. \n", "Load a `.ply` file and create a **Point Cloud** object. \n",
"\n", "\n",
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .camera_visualization import ( from .camera_visualization import get_camera_wireframe, plot_camera_scene, plot_cameras
get_camera_wireframe,
plot_camera_scene,
plot_cameras,
)
from .plot_image_grid import image_grid from .plot_image_grid import image_grid
...@@ -34,13 +34,9 @@ def image_grid( ...@@ -34,13 +34,9 @@ def image_grid(
cols = 1 cols = 1
gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {} gridspec_kw = {"wspace": 0.0, "hspace": 0.0} if fill else {}
fig, axarr = plt.subplots( fig, axarr = plt.subplots(rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9))
rows, cols, gridspec_kw=gridspec_kw, figsize=(15, 9)
)
bleed = 0 bleed = 0
fig.subplots_adjust( fig.subplots_adjust(left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed))
left=bleed, bottom=bleed, right=(1 - bleed), top=(1 - bleed)
)
for ax, im in zip(axarr.ravel(), images): for ax, im in zip(axarr.ravel(), images):
if rgb: if rgb:
......
...@@ -4,4 +4,5 @@ ...@@ -4,4 +4,5 @@
from .obj_io import load_obj, load_objs_as_meshes, save_obj from .obj_io import load_obj, load_objs_as_meshes, save_obj
from .ply_io import load_ply, save_ply from .ply_io import load_ply, save_ply
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
"""This module implements utility functions for loading and saving meshes.""" """This module implements utility functions for loading and saving meshes."""
import numpy as np
import os import os
import pathlib import pathlib
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from typing import List, Optional from typing import List, Optional
import numpy as np
import torch import torch
from fvcore.common.file_io import PathManager from fvcore.common.file_io import PathManager
from PIL import Image from PIL import Image
from pytorch3d.structures import Meshes, Textures, join_meshes from pytorch3d.structures import Meshes, Textures, join_meshes
...@@ -51,9 +51,7 @@ def _read_image(file_name: str, format=None): ...@@ -51,9 +51,7 @@ def _read_image(file_name: str, format=None):
# Faces & Aux type returned from load_obj function. # Faces & Aux type returned from load_obj function.
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx") _Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
_Aux = namedtuple( _Aux = namedtuple("Properties", "normals verts_uvs material_colors texture_images")
"Properties", "normals verts_uvs material_colors texture_images"
)
def _format_faces_indices(faces_indices, max_index): def _format_faces_indices(faces_indices, max_index):
...@@ -247,9 +245,7 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): ...@@ -247,9 +245,7 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
image = list(tex_maps.values())[0].to(device)[None] image = list(tex_maps.values())[0].to(device)[None]
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image) tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
mesh = Meshes( mesh = Meshes(verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex)
verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex
)
mesh_list.append(mesh) mesh_list.append(mesh)
if len(mesh_list) == 1: if len(mesh_list) == 1:
return mesh_list[0] return mesh_list[0]
...@@ -308,9 +304,7 @@ def _parse_face( ...@@ -308,9 +304,7 @@ def _parse_face(
# Subdivide faces with more than 3 vertices. See comments of the # Subdivide faces with more than 3 vertices. See comments of the
# load_obj function for more details. # load_obj function for more details.
for i in range(len(face_verts) - 2): for i in range(len(face_verts) - 2):
faces_verts_idx.append( faces_verts_idx.append((face_verts[0], face_verts[i + 1], face_verts[i + 2]))
(face_verts[0], face_verts[i + 1], face_verts[i + 2])
)
if len(face_normals) > 0: if len(face_normals) > 0:
faces_normals_idx.append( faces_normals_idx.append(
(face_normals[0], face_normals[i + 1], face_normals[i + 2]) (face_normals[0], face_normals[i + 1], face_normals[i + 2])
...@@ -367,8 +361,7 @@ def _load(f_obj, data_dir, load_textures=True): ...@@ -367,8 +361,7 @@ def _load(f_obj, data_dir, load_textures=True):
tx = [float(x) for x in line.split()[1:3]] tx = [float(x) for x in line.split()[1:3]]
if len(tx) != 2: if len(tx) != 2:
raise ValueError( raise ValueError(
"Texture %s does not have 2 values. Line: %s" "Texture %s does not have 2 values. Line: %s" % (str(tx), str(line))
% (str(tx), str(line))
) )
verts_uvs.append(tx) verts_uvs.append(tx)
elif line.startswith("vn "): elif line.startswith("vn "):
...@@ -397,17 +390,13 @@ def _load(f_obj, data_dir, load_textures=True): ...@@ -397,17 +390,13 @@ def _load(f_obj, data_dir, load_textures=True):
# Repeat for normals and textures if present. # Repeat for normals and textures if present.
if len(faces_normals_idx) > 0: if len(faces_normals_idx) > 0:
faces_normals_idx = _format_faces_indices( faces_normals_idx = _format_faces_indices(faces_normals_idx, normals.shape[0])
faces_normals_idx, normals.shape[0]
)
if len(faces_textures_idx) > 0: if len(faces_textures_idx) > 0:
faces_textures_idx = _format_faces_indices( faces_textures_idx = _format_faces_indices(
faces_textures_idx, verts_uvs.shape[0] faces_textures_idx, verts_uvs.shape[0]
) )
if len(faces_materials_idx) > 0: if len(faces_materials_idx) > 0:
faces_materials_idx = torch.tensor( faces_materials_idx = torch.tensor(faces_materials_idx, dtype=torch.int64)
faces_materials_idx, dtype=torch.int64
)
# Load materials # Load materials
material_colors, texture_images = None, None material_colors, texture_images = None, None
......
...@@ -4,15 +4,17 @@ ...@@ -4,15 +4,17 @@
"""This module implements utility functions for loading and saving meshes.""" """This module implements utility functions for loading and saving meshes."""
import numpy as np
import pathlib import pathlib
import struct import struct
import sys import sys
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from typing import Optional, Tuple from typing import Optional, Tuple
import numpy as np
import torch import torch
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type") _PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")
_PLY_TYPES = { _PLY_TYPES = {
...@@ -257,11 +259,7 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType): ...@@ -257,11 +259,7 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
"ignore", message=".* Empty input file.*", category=UserWarning "ignore", message=".* Empty input file.*", category=UserWarning
) )
data = np.loadtxt( data = np.loadtxt(
f, f, dtype=np_type, comments=None, ndmin=2, max_rows=definition.count
dtype=np_type,
comments=None,
ndmin=2,
max_rows=definition.count,
) )
except ValueError: except ValueError:
f.seek(start_point) f.seek(start_point)
...@@ -301,9 +299,7 @@ def _parse_heterogenous_property_ascii(datum, line_iter, property: _Property): ...@@ -301,9 +299,7 @@ def _parse_heterogenous_property_ascii(datum, line_iter, property: _Property):
length = int(value) length = int(value)
except ValueError: except ValueError:
raise ValueError("A list length was not a number.") raise ValueError("A list length was not a number.")
list_value = np.zeros( list_value = np.zeros(length, dtype=_PLY_TYPES[property.data_type].np_type)
length, dtype=_PLY_TYPES[property.data_type].np_type
)
for i in range(length): for i in range(length):
inner_value = next(line_iter, None) inner_value = next(line_iter, None)
if inner_value is None: if inner_value is None:
...@@ -404,8 +400,7 @@ def _read_ply_element_struct(f, definition: _PlyElementType, endian_str: str): ...@@ -404,8 +400,7 @@ def _read_ply_element_struct(f, definition: _PlyElementType, endian_str: str):
values. There is one column for each property. values. There is one column for each property.
""" """
format = "".join( format = "".join(
_PLY_TYPES[property.data_type].struct_char _PLY_TYPES[property.data_type].struct_char for property in definition.properties
for property in definition.properties
) )
format = endian_str + format format = endian_str + format
pattern = struct.Struct(format) pattern = struct.Struct(format)
...@@ -414,10 +409,7 @@ def _read_ply_element_struct(f, definition: _PlyElementType, endian_str: str): ...@@ -414,10 +409,7 @@ def _read_ply_element_struct(f, definition: _PlyElementType, endian_str: str):
bytes_data = f.read(needed_bytes) bytes_data = f.read(needed_bytes)
if len(bytes_data) != needed_bytes: if len(bytes_data) != needed_bytes:
raise ValueError("Not enough data for %s." % definition.name) raise ValueError("Not enough data for %s." % definition.name)
data = [ data = [pattern.unpack_from(bytes_data, i * size) for i in range(definition.count)]
pattern.unpack_from(bytes_data, i * size)
for i in range(definition.count)
]
return data return data
...@@ -475,9 +467,7 @@ def _try_read_ply_constant_list_binary( ...@@ -475,9 +467,7 @@ def _try_read_ply_constant_list_binary(
return output return output
def _read_ply_element_binary( def _read_ply_element_binary(f, definition: _PlyElementType, big_endian: bool) -> list:
f, definition: _PlyElementType, big_endian: bool
) -> list:
""" """
Decode all instances of a single element from a binary .ply file. Decode all instances of a single element from a binary .ply file.
...@@ -515,9 +505,7 @@ def _read_ply_element_binary( ...@@ -515,9 +505,7 @@ def _read_ply_element_binary(
data = [] data = []
for _i in range(definition.count): for _i in range(definition.count):
datum = [] datum = []
for property, property_struct in zip( for property, property_struct in zip(definition.properties, property_structs):
definition.properties, property_structs
):
size = property_struct.size size = property_struct.size
initial_data = f.read(size) initial_data = f.read(size)
if len(initial_data) != size: if len(initial_data) != size:
...@@ -656,28 +644,19 @@ def load_ply(f): ...@@ -656,28 +644,19 @@ def load_ply(f):
if face is None: if face is None:
raise ValueError("The ply file has no face element.") raise ValueError("The ply file has no face element.")
if ( if not isinstance(vertex, np.ndarray) or vertex.ndim != 2 or vertex.shape[1] != 3:
not isinstance(vertex, np.ndarray)
or vertex.ndim != 2
or vertex.shape[1] != 3
):
raise ValueError("Invalid vertices in file.") raise ValueError("Invalid vertices in file.")
verts = torch.tensor(vertex, dtype=torch.float32) verts = torch.tensor(vertex, dtype=torch.float32)
face_head = next(head for head in header.elements if head.name == "face") face_head = next(head for head in header.elements if head.name == "face")
if ( if len(face_head.properties) != 1 or face_head.properties[0].list_size_type is None:
len(face_head.properties) != 1
or face_head.properties[0].list_size_type is None
):
raise ValueError("Unexpected form of faces data.") raise ValueError("Unexpected form of faces data.")
# face_head.properties[0].name is usually "vertex_index" or "vertex_indices" # face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
# but we don't need to enforce this. # but we don't need to enforce this.
if isinstance(face, np.ndarray) and face.ndim == 2: if isinstance(face, np.ndarray) and face.ndim == 2:
if face.shape[1] < 3: if face.shape[1] < 3:
raise ValueError("Faces must have at least 3 vertices.") raise ValueError("Faces must have at least 3 vertices.")
face_arrays = [ face_arrays = [face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)]
face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)
]
faces = torch.tensor(np.vstack(face_arrays), dtype=torch.int64) faces = torch.tensor(np.vstack(face_arrays), dtype=torch.int64)
else: else:
face_list = [] face_list = []
...@@ -687,9 +666,7 @@ def load_ply(f): ...@@ -687,9 +666,7 @@ def load_ply(f):
if face_item.shape[0] < 3: if face_item.shape[0] < 3:
raise ValueError("Faces must have at least 3 vertices.") raise ValueError("Faces must have at least 3 vertices.")
for i in range(face_item.shape[0] - 2): for i in range(face_item.shape[0] - 2):
face_list.append( face_list.append([face_item[0], face_item[i + 1], face_item[i + 2]])
[face_item[0], face_item[i + 1], face_item[i + 2]]
)
faces = torch.tensor(face_list, dtype=torch.int64) faces = torch.tensor(face_list, dtype=torch.int64)
return verts, faces return verts, faces
......
...@@ -6,4 +6,5 @@ from .mesh_edge_loss import mesh_edge_loss ...@@ -6,4 +6,5 @@ from .mesh_edge_loss import mesh_edge_loss
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
from .mesh_normal_consistency import mesh_normal_consistency from .mesh_normal_consistency import mesh_normal_consistency
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
...@@ -2,13 +2,10 @@ ...@@ -2,13 +2,10 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
def _validate_chamfer_reduction_inputs( def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
batch_reduction: str, point_reduction: str
):
"""Check the requested reductions are valid. """Check the requested reductions are valid.
Args: Args:
...@@ -18,17 +15,11 @@ def _validate_chamfer_reduction_inputs( ...@@ -18,17 +15,11 @@ def _validate_chamfer_reduction_inputs(
points, can be one of ["none", "mean", "sum"]. points, can be one of ["none", "mean", "sum"].
""" """
if batch_reduction not in ["none", "mean", "sum"]: if batch_reduction not in ["none", "mean", "sum"]:
raise ValueError( raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
'batch_reduction must be one of ["none", "mean", "sum"]'
)
if point_reduction not in ["none", "mean", "sum"]: if point_reduction not in ["none", "mean", "sum"]:
raise ValueError( raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
'point_reduction must be one of ["none", "mean", "sum"]'
)
if batch_reduction == "none" and point_reduction == "none": if batch_reduction == "none" and point_reduction == "none":
raise ValueError( raise ValueError('batch_reduction and point_reduction cannot both be "none".')
'batch_reduction and point_reduction cannot both be "none".'
)
def chamfer_distance( def chamfer_distance(
...@@ -87,10 +78,7 @@ def chamfer_distance( ...@@ -87,10 +78,7 @@ def chamfer_distance(
(x.sum((1, 2)) * weights).sum() * 0.0, (x.sum((1, 2)) * weights).sum() * 0.0,
(x.sum((1, 2)) * weights).sum() * 0.0, (x.sum((1, 2)) * weights).sum() * 0.0,
) )
return ( return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
(x.sum((1, 2)) * weights) * 0.0,
(x.sum((1, 2)) * weights) * 0.0,
)
return_normals = x_normals is not None and y_normals is not None return_normals = x_normals is not None and y_normals is not None
cham_norm_x = x.new_zeros(()) cham_norm_x = x.new_zeros(())
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from itertools import islice from itertools import islice
import torch import torch
...@@ -76,10 +77,7 @@ def mesh_normal_consistency(meshes): ...@@ -76,10 +77,7 @@ def mesh_normal_consistency(meshes):
with torch.no_grad(): with torch.no_grad():
edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges
vert_idx = ( vert_idx = (
faces_packed.view(1, F, 3) faces_packed.view(1, F, 3).expand(3, F, 3).transpose(0, 1).reshape(3 * F, 3)
.expand(3, F, 3)
.transpose(0, 1)
.reshape(3 * F, 3)
) )
edge_idx, edge_sort_idx = edge_idx.sort() edge_idx, edge_sort_idx = edge_idx.sort()
vert_idx = vert_idx[edge_sort_idx] vert_idx = vert_idx[edge_sort_idx]
...@@ -132,9 +130,7 @@ def mesh_normal_consistency(meshes): ...@@ -132,9 +130,7 @@ def mesh_normal_consistency(meshes):
loss = 1 - torch.cosine_similarity(n0, n1, dim=1) loss = 1 - torch.cosine_similarity(n0, n1, dim=1)
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]] verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]]
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[ verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_edge_pair_idx[:, 0]]
vert_edge_pair_idx[:, 0]
]
num_normals = verts_packed_to_mesh_idx.bincount(minlength=N) num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)
weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float() weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()
......
...@@ -10,4 +10,5 @@ from .sample_points_from_meshes import sample_points_from_meshes ...@@ -10,4 +10,5 @@ from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes from .subdivide_meshes import SubdivideMeshes
from .vert_align import vert_align from .vert_align import vert_align
__all__ = [k for k in globals().keys() if not k.startswith("_")] __all__ = [k for k in globals().keys() if not k.startswith("_")]
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from pytorch3d.structures import Meshes from pytorch3d.structures import Meshes
...@@ -200,8 +199,6 @@ def cubify(voxels, thresh, device=None) -> Meshes: ...@@ -200,8 +199,6 @@ def cubify(voxels, thresh, device=None) -> Meshes:
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0]) grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
for n in range(N) for n in range(N)
] ]
faces_list = [ faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
nface - idlenum[n][nface] for n, nface in enumerate(faces_list)
]
return Meshes(verts=verts_list, faces=faces_list) return Meshes(verts=verts_list, faces=faces_list)
...@@ -3,11 +3,10 @@ ...@@ -3,11 +3,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from pytorch3d import _C
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from pytorch3d import _C
class GraphConv(nn.Module): class GraphConv(nn.Module):
"""A single graph convolution layer.""" """A single graph convolution layer."""
...@@ -60,9 +59,7 @@ class GraphConv(nn.Module): ...@@ -60,9 +59,7 @@ class GraphConv(nn.Module):
number of output features per vertex. number of output features per vertex.
""" """
if verts.is_cuda != edges.is_cuda: if verts.is_cuda != edges.is_cuda:
raise ValueError( raise ValueError("verts and edges tensors must be on the same device.")
"verts and edges tensors must be on the same device."
)
if verts.shape[0] == 0: if verts.shape[0] == 0:
# empty graph. # empty graph.
return verts.new_zeros((0, self.output_dim)) * verts.sum() return verts.new_zeros((0, self.output_dim)) * verts.sum()
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch import torch
from pytorch3d import _C from pytorch3d import _C
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch import torch
from pytorch3d import _C
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from pytorch3d import _C
class _MeshFaceAreasNormals(Function): class _MeshFaceAreasNormals(Function):
""" """
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import torch import torch
from pytorch3d import _C from pytorch3d import _C
...@@ -31,9 +30,7 @@ def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor: ...@@ -31,9 +30,7 @@ def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor:
""" """
N, P1, D = p1.shape N, P1, D = p1.shape
with torch.no_grad(): with torch.no_grad():
p1_nn_idx = _C.nn_points_idx( p1_nn_idx = _C.nn_points_idx(p1.contiguous(), p2.contiguous()) # (N, P1)
p1.contiguous(), p2.contiguous()
) # (N, P1)
p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D) p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
p1_nn_points = p2.gather(1, p1_nn_idx_expanded) p1_nn_points = p2.gather(1, p1_nn_idx_expanded)
if p2_normals is None: if p2_normals is None:
......
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch import torch
from pytorch3d import _C
from torch.autograd import Function from torch.autograd import Function
from torch.autograd.function import once_differentiable from torch.autograd.function import once_differentiable
from pytorch3d import _C
class _PackedToPadded(Function): class _PackedToPadded(Function):
""" """
......
...@@ -7,8 +7,8 @@ batches of meshes. ...@@ -7,8 +7,8 @@ batches of meshes.
""" """
import sys import sys
from typing import Tuple, Union from typing import Tuple, Union
import torch
import torch
from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals from pytorch3d.ops.mesh_face_areas_normals import mesh_face_areas_normals
from pytorch3d.ops.packed_to_padded import packed_to_padded from pytorch3d.ops.packed_to_padded import packed_to_padded
...@@ -53,9 +53,7 @@ def sample_points_from_meshes( ...@@ -53,9 +53,7 @@ def sample_points_from_meshes(
# Only compute samples for non empty meshes # Only compute samples for non empty meshes
with torch.no_grad(): with torch.no_grad():
areas, _ = mesh_face_areas_normals( areas, _ = mesh_face_areas_normals(verts, faces) # Face areas can be zero.
verts, faces
) # Face areas can be zero.
max_faces = meshes.num_faces_per_mesh().max().item() max_faces = meshes.num_faces_per_mesh().max().item()
areas_padded = packed_to_padded( areas_padded = packed_to_padded(
areas, mesh_to_face[meshes.valid], max_faces areas, mesh_to_face[meshes.valid], max_faces
...@@ -80,21 +78,17 @@ def sample_points_from_meshes( ...@@ -80,21 +78,17 @@ def sample_points_from_meshes(
a = v0[sample_face_idxs] # (N, num_samples, 3) a = v0[sample_face_idxs] # (N, num_samples, 3)
b = v1[sample_face_idxs] b = v1[sample_face_idxs]
c = v2[sample_face_idxs] c = v2[sample_face_idxs]
samples[meshes.valid] = ( samples[meshes.valid] = w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
)
if return_normals: if return_normals:
# Intialize normals tensor with fill value 0 for empty meshes. # Intialize normals tensor with fill value 0 for empty meshes.
# Normals for the sampled points are face normals computed from # Normals for the sampled points are face normals computed from
# the vertices of the face in which the sampled point lies. # the vertices of the face in which the sampled point lies.
normals = torch.zeros( normals = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
(num_meshes, num_samples, 3), device=meshes.device
)
vert_normals = (v1 - v0).cross(v2 - v1, dim=1) vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
vert_normals = vert_normals / vert_normals.norm( vert_normals = vert_normals / vert_normals.norm(dim=1, p=2, keepdim=True).clamp(
dim=1, p=2, keepdim=True min=sys.float_info.epsilon
).clamp(min=sys.float_info.epsilon) )
vert_normals = vert_normals[sample_face_idxs] vert_normals = vert_normals[sample_face_idxs]
normals[meshes.valid] = vert_normals normals[meshes.valid] = vert_normals
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from pytorch3d.structures import Meshes from pytorch3d.structures import Meshes
...@@ -193,16 +192,12 @@ class SubdivideMeshes(nn.Module): ...@@ -193,16 +192,12 @@ class SubdivideMeshes(nn.Module):
edges = meshes[0].edges_packed() edges = meshes[0].edges_packed()
# The set of faces is the same across the different meshes. # The set of faces is the same across the different meshes.
new_faces = self._subdivided_faces.view(1, -1, 3).expand( new_faces = self._subdivided_faces.view(1, -1, 3).expand(self._N, -1, -1)
self._N, -1, -1
)
# Add one new vertex at the midpoint of each edge by taking the average # Add one new vertex at the midpoint of each edge by taking the average
# of the vertices that form each edge. # of the vertices that form each edge.
new_verts = verts[:, edges].mean(dim=2) new_verts = verts[:, edges].mean(dim=2)
new_verts = torch.cat( new_verts = torch.cat([verts, new_verts], dim=1) # (sum(V_n)+sum(E_n), 3)
[verts, new_verts], dim=1
) # (sum(V_n)+sum(E_n), 3)
new_feats = None new_feats = None
# Calculate features for new vertices. # Calculate features for new vertices.
...@@ -212,15 +207,11 @@ class SubdivideMeshes(nn.Module): ...@@ -212,15 +207,11 @@ class SubdivideMeshes(nn.Module):
# padded, i.e. (N*V, D) to (N, V, D). # padded, i.e. (N*V, D) to (N, V, D).
feats = feats.view(verts.size(0), verts.size(1), feats.size(1)) feats = feats.view(verts.size(0), verts.size(1), feats.size(1))
if feats.dim() != 3: if feats.dim() != 3:
raise ValueError( raise ValueError("features need to be of shape (N, V, D) or (N*V, D)")
"features need to be of shape (N, V, D) or (N*V, D)"
)
# Take average of the features at the vertices that form each edge. # Take average of the features at the vertices that form each edge.
new_feats = feats[:, edges].mean(dim=2) new_feats = feats[:, edges].mean(dim=2)
new_feats = torch.cat( new_feats = torch.cat([feats, new_feats], dim=1) # (sum(V_n)+sum(E_n), 3)
[feats, new_feats], dim=1
) # (sum(V_n)+sum(E_n), 3)
new_meshes = Meshes(verts=new_verts, faces=new_faces) new_meshes = Meshes(verts=new_verts, faces=new_faces)
...@@ -270,9 +261,7 @@ class SubdivideMeshes(nn.Module): ...@@ -270,9 +261,7 @@ class SubdivideMeshes(nn.Module):
) # (sum(V_n)+sum(E_n),) ) # (sum(V_n)+sum(E_n),)
verts_ordered_idx_init = torch.zeros( verts_ordered_idx_init = torch.zeros(
new_verts_per_mesh.sum(), new_verts_per_mesh.sum(), dtype=torch.int64, device=meshes.device
dtype=torch.int64,
device=meshes.device,
) # (sum(V_n)+sum(E_n),) ) # (sum(V_n)+sum(E_n),)
# Reassign vertex indices so that existing and new vertices for each # Reassign vertex indices so that existing and new vertices for each
...@@ -288,9 +277,7 @@ class SubdivideMeshes(nn.Module): ...@@ -288,9 +277,7 @@ class SubdivideMeshes(nn.Module):
# Calculate the indices needed to group the existing and new faces # Calculate the indices needed to group the existing and new faces
# for each mesh. # for each mesh.
face_sort_idx = create_faces_index( face_sort_idx = create_faces_index(num_faces_per_mesh, device=meshes.device)
num_faces_per_mesh, device=meshes.device
)
# Reorder the faces to sequentially group existing and new faces # Reorder the faces to sequentially group existing and new faces
# for each mesh. # for each mesh.
...@@ -361,9 +348,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None): ...@@ -361,9 +348,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
E = edges_per_mesh.sum() # e.g. 21 E = edges_per_mesh.sum() # e.g. 21
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15) verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
edges_per_mesh_cumsum = edges_per_mesh.cumsum( edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0) # (N,) e.g. (5, 12, 21)
dim=0
) # (N,) e.g. (5, 12, 21)
v_to_e_idx = verts_per_mesh_cumsum.clone() v_to_e_idx = verts_per_mesh_cumsum.clone()
...@@ -373,9 +358,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None): ...@@ -373,9 +358,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27) ] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
# vertex to edge offset. # vertex to edge offset.
v_to_e_offset = ( v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
V - verts_per_mesh_cumsum
) # e.g. 15 - (4, 9, 15) = (11, 6, 0)
v_to_e_offset[1:] += edges_per_mesh_cumsum[ v_to_e_offset[1:] += edges_per_mesh_cumsum[
:-1 :-1
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12) ] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment