Commit dbf06b50 authored by facebook-github-bot

Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch
def mesh_laplacian_smoothing(meshes, method: str = "uniform"):
r"""
Computes the laplacian smoothing objective for a batch of meshes.
This function supports three variants of Laplacian smoothing,
namely with uniform weights ("uniform"), with cotangent weights ("cot"),
and with cotangent curvature ("cotcurv"). For more details read [1, 2].
Args:
meshes: Meshes object with a batch of meshes.
method: str specifying the method for the laplacian.
Returns:
loss: Average laplacian smoothing loss across the batch.
Returns 0 if meshes contains no meshes or all empty meshes.
Consider a mesh M = (V, F), with verts of shape Nx3 and faces of shape Mx3.
The Laplacian matrix L is a NxN tensor such that LV gives a tensor of vectors:
for a uniform Laplacian, LuV[i] points to the centroid of the neighboring
vertices; a cotangent Laplacian LcV[i] is known to be an approximation of
the surface normal; and the curvature variant LckV[i] scales the normals
by the discrete mean curvature. For vertex i, let S[i] be the set of its
neighboring vertices, and let a_ij and b_ij be the "outside" angles in the
two triangles connecting vertex v_i and its neighboring vertex v_j
for j in S[i], as seen in the diagram below.
.. code-block:: python
a_ij
/\
/ \
/ \
/ \
v_i /________\ v_j
\ /
\ /
\ /
\ /
\/
b_ij
The definition of the Laplacian is LV[i] = sum_j w_ij (v_j - v_i)
For the uniform variant, w_ij = 1 / |S[i]|
For the cotangent variant,
w_ij = (cot a_ij + cot b_ij) / (sum_k cot a_ik + cot b_ik)
For the cotangent curvature, w_ij = (cot a_ij + cot b_ij) / (4 A[i])
where A[i] is the sum of the areas of all triangles containing vertex v_i.
There is a nice trigonometry identity to compute cotangents. Consider a triangle
with side lengths A, B, C and angles a, b, c.
.. code-block:: python
c
/|\
/ | \
/ | \
B / H| \ A
/ | \
/ | \
/a_____|_____b\
C
Then cot a = (B^2 + C^2 - A^2) / (4 * area)
We know that area = CH/2, and by the law of cosines we have
A^2 = B^2 + C^2 - 2BC cos a => B^2 + C^2 - A^2 = 2BC cos a
Putting these together, we get:
(B^2 + C^2 - A^2) / (4 * area) = (2BC cos a) / (2CH) = (B/H) cos a
= cos a / sin a = cot a
[1] Desbrun et al, "Implicit fairing of irregular meshes using diffusion
and curvature flow", SIGGRAPH 1999.
[2] Nealen et al, "Laplacian Mesh Optimization", Graphite 2006.
"""
if meshes.isempty():
return torch.tensor(
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
)
N = len(meshes)
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
num_verts_per_mesh = meshes.num_verts_per_mesh() # (N,)
verts_packed_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
weights = num_verts_per_mesh.gather(0, verts_packed_idx) # (sum(V_n),)
weights = 1.0 / weights.float()
# We don't want to backprop through the computation of the Laplacian;
# just treat it as a magic constant matrix that is used to transform
# verts into normals
with torch.no_grad():
if method == "uniform":
L = meshes.laplacian_packed()
elif method in ["cot", "cotcurv"]:
L, inv_areas = laplacian_cot(meshes)
if method == "cot":
norm_w = torch.sparse.sum(L, dim=1).to_dense().view(-1, 1)
idx = norm_w > 0
norm_w[idx] = 1.0 / norm_w[idx]
else:
norm_w = 0.25 * inv_areas
else:
raise ValueError("Method should be one of {uniform, cot, cotcurv}")
if method == "uniform":
loss = L.mm(verts_packed)
elif method == "cot":
loss = L.mm(verts_packed) * norm_w - verts_packed
elif method == "cotcurv":
loss = (L.mm(verts_packed) - verts_packed) * norm_w
loss = loss.norm(dim=1)
loss = loss * weights
return loss.sum() / N
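# Usage sketch (illustrative only, not part of the module API): builds a single
# tetrahedron and evaluates the uniform Laplacian smoothing loss for it.
# Assumes `pytorch3d.structures.Meshes` is importable, as elsewhere in this
# commit; the helper name `_demo_laplacian_smoothing` is hypothetical.
def _demo_laplacian_smoothing():
    from pytorch3d.structures import Meshes
    verts = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    )
    faces = torch.tensor([[0, 1, 2], [0, 3, 1], [0, 2, 3], [1, 3, 2]])
    meshes = Meshes(verts=[verts], faces=[faces])
    # Scalar loss averaged over the batch (here a batch of one mesh).
    return mesh_laplacian_smoothing(meshes, method="uniform")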
def laplacian_cot(meshes):
"""
Returns the Laplacian matrix with cotangent weights and the inverse of the
face areas.
Args:
meshes: Meshes object with a batch of meshes.
Returns:
2-element tuple containing
- **L**: FloatTensor of shape (V,V) for the Laplacian matrix (V = sum(V_n))
Here, L[i, j] = cot a_ij + cot b_ij iff (i, j) is an edge in meshes.
See the description above for more clarity.
- **inv_areas**: FloatTensor of shape (V,) containing the inverse of sum of
face areas containing each vertex
"""
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
# V = sum(V_n), F = sum(F_n)
V, F = verts_packed.shape[0], faces_packed.shape[0]
face_verts = verts_packed[faces_packed]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
# Side lengths of each triangle, of shape (sum(F_n),)
# A is the side opposite v0, B is opposite v1, and C is opposite v2
A = (v1 - v2).norm(dim=1)
B = (v0 - v2).norm(dim=1)
C = (v0 - v1).norm(dim=1)
# Area of each triangle (with Heron's formula); shape is (sum(F_n),)
s = 0.5 * (A + B + C)
# note that the product under the sqrt can be slightly negative for degenerate
# triangles, causing nans after sqrt(); we clamp it to a small positive value
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
# Compute cotangents of angles, of shape (sum(F_n), 3)
A2, B2, C2 = A * A, B * B, C * C
cota = (B2 + C2 - A2) / area
cotb = (A2 + C2 - B2) / area
cotc = (A2 + B2 - C2) / area
cot = torch.stack([cota, cotb, cotc], dim=1)
cot /= 4.0
# Construct a sparse matrix by basically doing:
# L[v1, v2] = cota
# L[v2, v0] = cotb
# L[v0, v1] = cotc
ii = faces_packed[:, [1, 2, 0]]
jj = faces_packed[:, [2, 0, 1]]
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
# Make it symmetric; this means we are also setting
# L[v2, v1] = cota
# L[v0, v2] = cotb
# L[v1, v0] = cotc
L += L.t()
# For each vertex, compute the sum of areas for triangles containing it.
idx = faces_packed.view(-1)
inv_areas = torch.zeros(V, dtype=torch.float32, device=meshes.device)
val = torch.stack([area] * 3, dim=1).view(-1)
inv_areas.scatter_add_(0, idx, val)
idx = inv_areas > 0
inv_areas[idx] = 1.0 / inv_areas[idx]
inv_areas = inv_areas.view(-1, 1)
return L, inv_areas
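# Sanity-check sketch (illustrative only) for the cotangent identity used above,
# on a 3-4-5 right triangle (area = 6): the angle opposite the hypotenuse is
# 90 degrees, so its cotangent should be 0. The helper name is hypothetical.
def _check_cot_identity():
    import math
    A, B, C = 3.0, 4.0, 5.0  # side lengths; the angle opposite C is 90 degrees
    s = 0.5 * (A + B + C)
    area = math.sqrt(s * (s - A) * (s - B) * (s - C))  # Heron's formula -> 6.0
    cot_c = (A * A + B * B - C * C) / (4.0 * area)  # -> 0.0 = cot(90 deg)
    cot_a = (B * B + C * C - A * A) / (4.0 * area)  # -> 4/3 = cot(atan(3/4))
    return cot_c, cot_a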
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from itertools import islice
import torch
def mesh_normal_consistency(meshes):
r"""
Computes the normal consistency of each mesh in meshes.
We compute the normal consistency for each pair of neighboring faces.
If e = (v0, v1) is the connecting edge of two neighboring faces f0 and f1,
then the normal consistency between f0 and f1
.. code-block:: python
a
/\
/ \
/ f0 \
/ \
v0 /____e___\ v1
\ /
\ /
\ f1 /
\ /
\/
b
The normal consistency is
.. code-block:: python
nc(f0, f1) = 1 - cos(n0, n1)
where cos(n0, n1) = n0 . n1 / (||n0|| ||n1||) is the cosine of the angle
between the normals n0 and n1, and
n0 = (v1 - v0) x (a - v0)
n1 = - (v1 - v0) x (b - v0) = (b - v0) x (v1 - v0)
This means that if nc(f0, f1) = 0 then n0 and n1 point in the same
direction, while if nc(f0, f1) = 2 then n0 and n1 point in opposite directions.
.. note::
For well-constructed meshes, the assumption that only two faces share an
edge holds, and relying on it would make the implementation simpler and faster.
This implementation does not rely on that assumption: all the faces sharing e,
however many they are, are discovered.
Args:
meshes: Meshes object with a batch of meshes.
Returns:
loss: Average normal consistency across the batch.
Returns 0 if meshes contains no meshes or all empty meshes.
"""
if meshes.isempty():
return torch.tensor(
[0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
)
N = len(meshes)
verts_packed = meshes.verts_packed() # (sum(V_n), 3)
faces_packed = meshes.faces_packed() # (sum(F_n), 3)
edges_packed = meshes.edges_packed() # (sum(E_n), 2)
verts_packed_to_mesh_idx = meshes.verts_packed_to_mesh_idx() # (sum(V_n),)
face_to_edge = meshes.faces_packed_to_edges_packed() # (sum(F_n), 3)
E = edges_packed.shape[0] # sum(E_n)
F = faces_packed.shape[0] # sum(F_n)
# We don't want gradients for the following operation. The goal is to
# find for each edge e all the vertices associated with e. In the example above,
# the vertices associated with e are (v0, v1, a, b), i.e. points on e (=v0, v1)
# and points connected on faces to e (=a, b).
with torch.no_grad():
edge_idx = face_to_edge.reshape(F * 3) # (3 * F,) indexes into edges
vert_idx = (
faces_packed.view(1, F, 3)
.expand(3, F, 3)
.transpose(0, 1)
.reshape(3 * F, 3)
)
edge_idx, edge_sort_idx = edge_idx.sort()
vert_idx = vert_idx[edge_sort_idx]
# In well constructed meshes each edge is shared by precisely 2 faces
# However, in many meshes, this assumption is not always satisfied.
# We want to find all faces that share an edge, a number which can
# vary and which depends on the topology.
# In particular, we find the vertices not on the edge on the shared faces.
# In the example above, we want to associate edge e with vertices a and b.
# This operation is done more efficiently on the CPU with lists.
# TODO(gkioxari) find a better way to do this.
# edge_idx represents the index of the edge for each vertex. We can count
# the number of vertices which are associated with each edge.
# There can be a different number for each edge.
edge_num = edge_idx.bincount(minlength=E)
# Create pairs of vertices associated to e. We generate a list of lists:
# each list has the indices of the vertices which are opposite to one edge.
# The length of the list for each edge will vary.
vert_edge_pair_idx = split_list(
list(range(edge_idx.shape[0])), edge_num.tolist()
)
# For each list find all combinations of pairs in the list. This represents
# all pairs of vertices which are opposite to the same edge.
vert_edge_pair_idx = [
[e[i], e[j]]
for e in vert_edge_pair_idx
for i in range(len(e) - 1)
for j in range(1, len(e))
if i != j
]
vert_edge_pair_idx = torch.tensor(
vert_edge_pair_idx, device=meshes.device, dtype=torch.int64
)
v0_idx = edges_packed[edge_idx, 0]
v0 = verts_packed[v0_idx]
v1_idx = edges_packed[edge_idx, 1]
v1 = verts_packed[v1_idx]
# two of the following cross products are zero, as each is a cross product
# with either (v1 - v0) x (v1 - v0) or (v1 - v0) x (v0 - v0)
n_temp0 = (v1 - v0).cross(verts_packed[vert_idx[:, 0]] - v0, dim=1)
n_temp1 = (v1 - v0).cross(verts_packed[vert_idx[:, 1]] - v0, dim=1)
n_temp2 = (v1 - v0).cross(verts_packed[vert_idx[:, 2]] - v0, dim=1)
n = n_temp0 + n_temp1 + n_temp2
n0 = n[vert_edge_pair_idx[:, 0]]
n1 = -n[vert_edge_pair_idx[:, 1]]
loss = 1 - torch.cosine_similarity(n0, n1, dim=1)
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]]
verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[
vert_edge_pair_idx[:, 0]
]
num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)
weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()
loss = loss * weights
return loss.sum() / N
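# Usage sketch (illustrative only): evaluates the normal consistency loss for
# a tiny mesh of two triangles sharing the edge (0, 1). Assumes
# `pytorch3d.structures.Meshes` is importable, as elsewhere in this commit;
# the helper name is hypothetical.
def _demo_normal_consistency():
    from pytorch3d.structures import Meshes
    verts = torch.tensor(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.5, -1.0, 0.5]]
    )
    faces = torch.tensor([[0, 1, 2], [1, 0, 3]])  # both faces contain edge (0, 1)
    meshes = Meshes(verts=[verts], faces=[faces])
    return mesh_normal_consistency(meshes)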
def split_list(input, length_to_split):
inputt = iter(input)
return [list(islice(inputt, elem)) for elem in length_to_split]
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .cubify import cubify
from .graph_conv import GraphConv
from .nearest_neighbor_points import nn_points_idx
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .vert_align import vert_align
__all__ = [k for k in globals().keys() if not k.startswith("_")]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.structures import Meshes
def unravel_index(idx, dims) -> torch.Tensor:
r"""
Equivalent to np.unravel_index
Args:
idx: A LongTensor whose elements are indices into the
flattened version of an array of dimensions dims.
dims: The shape of the array to be indexed.
Implemented only for dims=(N, H, W, D)
"""
if len(dims) != 4:
raise ValueError("Expects a 4-element list.")
N, H, W, D = dims
n = torch.div(idx, H * W * D)
h = torch.div(idx - n * H * W * D, W * D)
w = torch.div(idx - n * H * W * D - h * W * D, D)
d = idx - n * H * W * D - h * W * D - w * D
return torch.stack((n, h, w, d), dim=1)
def ravel_index(idx, dims) -> torch.Tensor:
"""
Computes the linear index in an array of shape dims.
It performs the reverse functionality of unravel_index
Args:
idx: A LongTensor of shape (N, 3). Each row corresponds to indices into an
array of dimensions dims.
dims: The shape of the array to be indexed.
Implemented only for dims=(H, W, D)
"""
if len(dims) != 3:
raise ValueError("Expects a 3-element list")
if idx.shape[1] != 3:
raise ValueError("Expects an index tensor of shape Nx3")
H, W, D = dims
linind = idx[:, 0] * W * D + idx[:, 1] * D + idx[:, 2]
return linind
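# Worked example (illustrative only): for dims = (H, W, D) = (3, 4, 5), the
# multi-index (1, 2, 3) maps to the linear index 1*4*5 + 2*5 + 3 = 33;
# unravel_index performs the inverse mapping for 4D (N, H, W, D) shapes.
# The helper name is hypothetical.
def _demo_ravel_index():
    idx = torch.tensor([[1, 2, 3]], dtype=torch.int64)
    return ravel_index(idx, (3, 4, 5))  # tensor([33])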
@torch.no_grad()
def cubify(voxels, thresh, device=None) -> Meshes:
r"""
Converts a voxel occupancy grid to a mesh by replacing each occupied voxel with a cube
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
internal faces are removed.
Args:
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
thresh: A scalar threshold. If a voxel occupancy is larger than
thresh, the voxel is considered occupied.
device: The device of the output meshes; defaults to the device of voxels.
Returns:
meshes: A Meshes object of the corresponding meshes.
"""
if device is None:
device = voxels.device
if len(voxels) == 0:
return Meshes(verts=[], faces=[])
N, D, H, W = voxels.size()
# vertices corresponding to a unit cube: 8x3
cube_verts = torch.tensor(
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
],
dtype=torch.int64,
device=device,
)
# faces corresponding to a unit cube: 12x3
cube_faces = torch.tensor(
[
[0, 1, 2],
[1, 3, 2], # left face: 0, 1
[2, 3, 6],
[3, 7, 6], # bottom face: 2, 3
[0, 2, 6],
[0, 6, 4], # front face: 4, 5
[0, 5, 1],
[0, 4, 5], # up face: 6, 7
[6, 7, 5],
[6, 5, 4], # right face: 8, 9
[1, 7, 3],
[1, 5, 7], # back face: 10, 11
],
dtype=torch.int64,
device=device,
)
wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)
voxelt = voxels.ge(thresh).float()
# N x 1 x D x H x W
voxelt = voxelt.view(N, 1, D, H, W)
# N x 1 x (D-1) x (H-1) x (W-1)
voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()
# 12 x N x 1 x D x H x W
faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)
# add left face
faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
# add bottom face
faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
# add front face
faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
# add up face
faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
# add right face
faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
# add back face
faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z
faces_idx *= voxelt
# N x H x W x D x 12
faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
# (NHWD) x 12
faces_idx = faces_idx.contiguous()
faces_idx = faces_idx.view(-1, cube_faces.size(0))
# boolean to linear index
# NF x 2
linind = torch.nonzero(faces_idx)
# NF x 4
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
# NF x 3: faces
faces = torch.index_select(cube_faces, 0, linind[:, 1])
grid_faces = []
for d in range(cube_faces.size(1)):
# NF x 3
xyz = torch.index_select(cube_verts, 0, faces[:, d])
permute_idx = torch.tensor([1, 0, 2], device=device)
yxz = torch.index_select(xyz, 1, permute_idx)
yxz += nyxz[:, 1:]
# NF x 1
temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
grid_faces.append(temp)
# NF x 3
grid_faces = torch.stack(grid_faces, dim=1)
y, x, z = torch.meshgrid(
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
)
y = y.to(device=device, dtype=torch.float32)
y = y * 2.0 / (H - 1.0) - 1.0
x = x.to(device=device, dtype=torch.float32)
x = x * 2.0 / (W - 1.0) - 1.0
z = z.to(device=device, dtype=torch.float32)
z = z * 2.0 / (D - 1.0) - 1.0
# ((H+1)(W+1)(D+1)) x 3
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
if len(nyxz) == 0:
verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
return Meshes(verts=verts_list, faces=faces_list)
num_verts = grid_verts.size(0)
grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)
idleverts.scatter_(0, grid_faces.flatten(), 0)
grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
split_size = torch.bincount(nyxz[:, 0], minlength=N)
faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))
idleverts = idleverts.view(N, num_verts)
idlenum = idleverts.cumsum(1)
verts_list = [
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
for n in range(N)
]
faces_list = [
nface - idlenum[n][nface] for n, nface in enumerate(faces_list)
]
return Meshes(verts=verts_list, faces=faces_list)
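# Usage sketch (illustrative only): cubify a 2x2x2 occupancy grid with a single
# occupied voxel. With thresh=0.5 the occupied voxel becomes one isolated cube,
# so the resulting mesh is expected to have 8 vertices and 12 faces. The helper
# name is hypothetical.
def _demo_cubify():
    voxels = torch.zeros(1, 2, 2, 2)
    voxels[0, 0, 0, 0] = 1.0
    meshes = cubify(voxels, 0.5)
    return meshes.verts_packed().shape, meshes.faces_packed().shape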
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from pytorch3d import _C
class GraphConv(nn.Module):
"""A single graph convolution layer."""
def __init__(
self,
input_dim: int,
output_dim: int,
init: str = "normal",
directed: bool = False,
):
"""
Args:
input_dim: Number of input features per vertex.
output_dim: Number of output features per vertex.
init: Weight initialization method. Can be one of ['zero', 'normal'].
directed: Bool indicating if edges in the graph are directed.
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.directed = directed
self.w0 = nn.Linear(input_dim, output_dim)
self.w1 = nn.Linear(input_dim, output_dim)
if init == "normal":
nn.init.normal_(self.w0.weight, mean=0, std=0.01)
nn.init.normal_(self.w1.weight, mean=0, std=0.01)
self.w0.bias.data.zero_()
self.w1.bias.data.zero_()
elif init == "zero":
self.w0.weight.data.zero_()
self.w1.weight.data.zero_()
else:
raise ValueError('Invalid GraphConv initialization "%s"' % init)
def forward(self, verts, edges):
"""
Args:
verts: FloatTensor of shape (V, input_dim) where V is the number of
vertices and input_dim is the number of input features
per vertex. input_dim has to match the input_dim specified
in __init__.
edges: LongTensor of shape (E, 2) where E is the number of edges
where each edge has the indices of the two vertices which
form the edge.
Returns:
out: FloatTensor of shape (V, output_dim) where output_dim is the
number of output features per vertex.
"""
if verts.is_cuda != edges.is_cuda:
raise ValueError(
"verts and edges tensors must be on the same device."
)
if verts.shape[0] == 0:
# empty graph.
return verts.sum() * 0.0
verts_w0 = self.w0(verts) # (V, output_dim)
verts_w1 = self.w1(verts) # (V, output_dim)
if torch.cuda.is_available() and verts.is_cuda and edges.is_cuda:
neighbor_sums = gather_scatter(verts_w1, edges, self.directed)
else:
neighbor_sums = gather_scatter_python(
verts_w1, edges, self.directed
) # (V, output_dim)
# Add neighbor features to each vertex's features.
out = verts_w0 + neighbor_sums
return out
def __repr__(self):
Din, Dout, directed = self.input_dim, self.output_dim, self.directed
return "GraphConv(%d -> %d, directed=%r)" % (Din, Dout, directed)
def gather_scatter_python(input, edges, directed: bool = False):
"""
Python implementation of gather_scatter for aggregating features of
neighbor nodes in a graph.
Given a directed graph: v0 -> v1 -> v2 the updated feature for v1 depends
on v2 in order to be consistent with Morris et al. AAAI 2019
(https://arxiv.org/abs/1810.02244). This only affects
directed graphs; for undirected graphs v1 will depend on both v0 and v2,
no matter which way the edges are physically stored.
Args:
input: Tensor of shape (num_vertices, input_dim).
edges: Tensor of edge indices of shape (num_edges, 2).
directed: bool indicating if edges are directed.
Returns:
output: Tensor of same shape as input.
"""
if not (input.dim() == 2):
raise ValueError("input can only have 2 dimensions.")
if not (edges.dim() == 2):
raise ValueError("edges can only have 2 dimensions.")
if not (edges.shape[1] == 2):
raise ValueError("edges must be of shape (num_edges, 2).")
num_vertices, input_feature_dim = input.shape
num_edges = edges.shape[0]
output = torch.zeros_like(input)
idx0 = edges[:, 0].view(num_edges, 1).expand(num_edges, input_feature_dim)
idx1 = edges[:, 1].view(num_edges, 1).expand(num_edges, input_feature_dim)
output = output.scatter_add(0, idx0, input.gather(0, idx1))
if not directed:
output = output.scatter_add(0, idx1, input.gather(0, idx0))
return output
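# Worked example (illustrative only): with identity features and the edges
# (0 -> 1) and (1 -> 2), each vertex accumulates the features of the vertex its
# edge points to; with directed=False the reverse direction is added as well,
# so vertex 1 also receives the features of vertex 0. The helper name is
# hypothetical.
def _demo_gather_scatter_python():
    feats = torch.eye(3)
    edges = torch.tensor([[0, 1], [1, 2]])
    out_directed = gather_scatter_python(feats, edges, directed=True)
    # rows: v0 <- v1, v1 <- v2, v2 <- nothing
    out_undirected = gather_scatter_python(feats, edges, directed=False)
    # rows: v0 <- v1, v1 <- v0 + v2, v2 <- v1
    return out_directed, out_undirected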
class GatherScatter(Function):
"""
Torch autograd Function wrapper for gather_scatter C++/CUDA implementations.
"""
@staticmethod
def forward(ctx, input, edges, directed=False):
"""
Args:
ctx: Context object used to calculate gradients.
input: Tensor of shape (num_vertices, input_dim)
edges: Tensor of edge indices of shape (num_edges, 2)
directed: Bool indicating if edges are directed.
Returns:
output: Tensor of same shape as input.
"""
if not (input.dim() == 2):
raise ValueError("input can only have 2 dimensions.")
if not (edges.dim() == 2):
raise ValueError("edges can only have 2 dimensions.")
if not (edges.shape[1] == 2):
raise ValueError("edges must be of shape (num_edges, 2).")
if not (input.dtype == torch.float32):
raise ValueError("input has to be of type torch.float32.")
ctx.directed = directed
input, edges = input.contiguous(), edges.contiguous()
ctx.save_for_backward(edges)
backward = False
output = _C.gather_scatter(input, edges, directed, backward)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
edges = ctx.saved_tensors[0]
directed = ctx.directed
backward = True
grad_input = _C.gather_scatter(grad_output, edges, directed, backward)
grad_edges = None
grad_directed = None
return grad_input, grad_edges, grad_directed
gather_scatter = GatherScatter.apply
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from pytorch3d import _C
def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor:
"""
Compute the coordinates of nearest neighbors in pointcloud p2 to points in p1.
Args:
p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
containing P1 points of dimension D.
p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
containing P2 points of dimension D.
p2_normals: [optional] FloatTensor of shape (N, P2, D) giving
normals for p2. Default: None.
Returns:
3-element tuple containing
- **p1_nn_points**: FloatTensor of shape (N, P1, D) where
p1_nn_points[n, i] is the point in p2[n] which is
the nearest neighbor to p1[n, i].
- **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of
the neighbors.
- **p1_nn_normals**: Normal vectors for each point in p1_nn_points;
the empty list [] if p2_normals is not passed.
"""
N, P1, D = p1.shape
with torch.no_grad():
p1_nn_idx = _C.nn_points_idx(
p1.contiguous(), p2.contiguous()
) # (N, P1)
p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
p1_nn_points = p2.gather(1, p1_nn_idx_expanded)
if p2_normals is None:
p1_nn_normals = []
else:
if p2_normals.shape != p2.shape:
raise ValueError("p2_normals has incorrect shape.")
p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded)
return p1_nn_points, p1_nn_idx, p1_nn_normals
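# Reference sketch (illustrative only, not the library kernel): a brute-force
# nearest-neighbor index lookup with plain torch ops, useful for checking
# nn_points_idx on small inputs when the compiled _C extension is available.
# The helper name is hypothetical.
def _nn_points_idx_bruteforce(p1, p2):
    # (N, P1, P2) pairwise squared distances.
    d2 = ((p1[:, :, None, :] - p2[:, None, :, :]) ** 2).sum(dim=-1)
    idx = d2.argmin(dim=2)  # (N, P1)
    return idx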
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This module implements utility functions for sampling points from
batches of meshes.
"""
import sys
from typing import Tuple, Union
import torch
from pytorch3d import _C
def sample_points_from_meshes(
meshes, num_samples: int = 10000, return_normals: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Convert a batch of meshes to a pointcloud by uniformly sampling points on
the surface of the mesh with probability proportional to the face area.
Args:
meshes: A Meshes object with a batch of N meshes.
num_samples: Integer giving the number of point samples per mesh.
return_normals: If True, return normals for the sampled points.
Returns:
2-element tuple containing
- **samples**: FloatTensor of shape (N, num_samples, 3) giving the
coordinates of sampled points for each mesh in the batch. For empty
meshes the corresponding row in the samples array will be filled with 0.
- **normals**: FloatTensor of shape (N, num_samples, 3) giving a normal vector
to each sampled point. Only returned if return_normals is True.
For empty meshes the corresponding row in the normals array will
be filled with 0.
"""
if meshes.isempty():
raise ValueError("Meshes are empty.")
verts = meshes.verts_packed()
faces = meshes.faces_packed()
mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
num_meshes = len(meshes)
num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
# Initialize samples tensor with fill value 0 for empty meshes.
samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
# Only compute samples for non empty meshes
with torch.no_grad():
areas, _ = _C.face_areas_normals(
verts, faces
) # Face areas can be zero.
max_faces = meshes.num_faces_per_mesh().max().item()
areas_padded = _C.packed_to_padded_tensor(
areas, mesh_to_face[meshes.valid], max_faces
) # (N, F)
# TODO (gkioxari) Confirm multinomial bug is not present with real data.
sample_face_idxs = areas_padded.multinomial(
num_samples, replacement=True
) # (N, num_samples)
sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
# Get the vertex coordinates of the sampled faces.
face_verts = verts[faces.long()]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
# Randomly generate barycentric coords.
w0, w1, w2 = _rand_barycentric_coords(
num_valid_meshes, num_samples, verts.dtype, verts.device
)
# Use the barycentric coords to get a point on each sampled face.
a = v0[sample_face_idxs] # (N, num_samples, 3)
b = v1[sample_face_idxs]
c = v2[sample_face_idxs]
samples[meshes.valid] = (
w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
)
if return_normals:
# Initialize normals tensor with fill value 0 for empty meshes.
# Normals for the sampled points are face normals computed from
# the vertices of the face in which the sampled point lies.
normals = torch.zeros(
(num_meshes, num_samples, 3), device=meshes.device
)
vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
vert_normals = vert_normals / vert_normals.norm(
dim=1, p=2, keepdim=True
).clamp(min=sys.float_info.epsilon)
vert_normals = vert_normals[sample_face_idxs]
normals[meshes.valid] = vert_normals
return samples, normals
else:
return samples
def _rand_barycentric_coords(
size1, size2, dtype, device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Helper function to generate random barycentric coordinates which are uniformly
distributed over a triangle.
Args:
size1, size2: The number of coordinates generated will be size1*size2.
Output tensors will each be of shape (size1, size2).
dtype: Datatype to generate.
device: A torch.device object on which the outputs will be allocated.
Returns:
w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric
coordinates
"""
uv = torch.rand(2, size1, size2, dtype=dtype, device=device)
u, v = uv[0], uv[1]
u_sqrt = u.sqrt()
w0 = 1.0 - u_sqrt
w1 = u_sqrt * (1.0 - v)
w2 = u_sqrt * v
return w0, w1, w2
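# Sanity-check sketch (illustrative only): the square-root parametrization above
# always yields convex weights, i.e. w0 + w1 + w2 == 1 with each weight in
# [0, 1], so w0*v0 + w1*v1 + w2*v2 lies inside the triangle (v0, v1, v2).
# The helper name is hypothetical.
def _check_barycentric_coords():
    w0, w1, w2 = _rand_barycentric_coords(2, 5, torch.float32, torch.device("cpu"))
    return torch.allclose(w0 + w1 + w2, torch.ones_like(w0))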
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from pytorch3d.structures import Meshes
class SubdivideMeshes(nn.Module):
"""
Subdivide a triangle mesh by adding a new vertex at the center of each edge
and dividing each face into four new faces. Vectors of vertex
attributes can also be subdivided by averaging the values of the attributes
at the two vertices which form each edge. This implementation
preserves face orientation - if the vertices of a face are all ordered
counter-clockwise, then the faces in the subdivided meshes will also have
their vertices ordered counter-clockwise.
If meshes is provided as an input, the initializer performs the relatively
expensive computation of determining the new face indices. This one-time
computation can be reused for all meshes with the same face topology
but different vertex positions.
"""
def __init__(self, meshes=None):
"""
Args:
meshes: Meshes object or None. If a meshes object is provided,
the first mesh is used to compute the new faces of the
subdivided topology which can be reused for meshes with
the same input topology.
"""
super(SubdivideMeshes, self).__init__()
self.precomputed = False
self._N = -1
if meshes is not None:
# This computation is on indices, so gradients do not need to be
# tracked.
mesh = meshes[0]
with torch.no_grad():
subdivided_faces = self.subdivide_faces(mesh)
if subdivided_faces.shape[1] != 3:
raise ValueError("faces can only have three vertices")
self.register_buffer("_subdivided_faces", subdivided_faces)
self.precomputed = True
def subdivide_faces(self, meshes):
r"""
Args:
meshes: a Meshes object.
Returns:
subdivided_faces_packed: (4*sum(F_n), 3) shape LongTensor of
original and new faces.
Refer to pytorch3d.structures.meshes.py for more details on packed
representations of faces.
Each face is split into 4 faces e.g. Input face
::
v0
/\
/ \
/ \
e1 / \ e0
/ \
/ \
/ \
/______________\
v2 e2 v1
faces_packed = [[0, 1, 2]]
faces_packed_to_edges_packed = [[2, 1, 0]]
`faces_packed_to_edges_packed` is used to represent all the new
vertex indices corresponding to the mid-points of edges in the mesh.
The actual vertex coordinates will be computed in the forward function.
To get the indices of the new vertices, offset
`faces_packed_to_edges_packed` by the total number of vertices.
::
faces_packed_to_edges_packed = [[2, 1, 0]] + 3 = [[5, 4, 3]]
e.g. subdivided face
::
v0
/\
/ \
/ f0 \
v4 /______\ v3
/\ /\
/ \ f3 / \
/ f2 \ / f1 \
/______\/______\
v2 v5 v1
f0 = [0, 3, 4]
f1 = [1, 5, 3]
f2 = [2, 4, 5]
f3 = [5, 4, 3]
"""
verts_packed = meshes.verts_packed()
with torch.no_grad():
faces_packed = meshes.faces_packed()
faces_packed_to_edges_packed = meshes.faces_packed_to_edges_packed()
faces_packed_to_edges_packed += verts_packed.shape[0]
f0 = torch.stack(
[
faces_packed[:, 0],
faces_packed_to_edges_packed[:, 2],
faces_packed_to_edges_packed[:, 1],
],
dim=1,
)
f1 = torch.stack(
[
faces_packed[:, 1],
faces_packed_to_edges_packed[:, 0],
faces_packed_to_edges_packed[:, 2],
],
dim=1,
)
f2 = torch.stack(
[
faces_packed[:, 2],
faces_packed_to_edges_packed[:, 1],
faces_packed_to_edges_packed[:, 0],
],
dim=1,
)
f3 = faces_packed_to_edges_packed
subdivided_faces_packed = torch.cat(
[f0, f1, f2, f3], dim=0
) # (4*sum(F_n), 3)
return subdivided_faces_packed
def forward(self, meshes, feats=None):
"""
Subdivide a batch of meshes by adding a new vertex on each edge, and
dividing each face into four new faces. New meshes contains two types
of vertices:
1) Vertices that appear in the input meshes.
Data for these vertices are copied from the input meshes.
2) New vertices at the midpoint of each edge.
Data for these vertices is the average of the data for the two
vertices that make up the edge.
Args:
meshes: Meshes object representing a batch of meshes.
feats: Per-vertex features to be subdivided along with the verts.
Should be parallel to the packed vert representation of the
input meshes; so it should have shape (V, D) where V is the
total number of verts in the input meshes. Default: None.
Returns:
2-element tuple containing
- **new_meshes**: Meshes object of a batch of subdivided meshes.
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
(packed) vertices of the subdivided meshes. Only returned
if feats is not None.
"""
self._N = len(meshes)
if self.precomputed:
return self.subdivide_homogeneous(meshes, feats)
else:
return self.subdivide_heterogenerous(meshes, feats)
def subdivide_homogeneous(self, meshes, feats=None):
"""
Subdivide verts (and optionally features) of a batch of meshes
where each mesh has the same topology of faces. The subdivided faces
are precomputed in the initializer.
Args:
meshes: Meshes object representing a batch of meshes.
feats: Per-vertex features to be subdivided along with the verts.
Returns:
2-element tuple containing
- **new_meshes**: Meshes object of a batch of subdivided meshes.
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
(packed) vertices of the subdivided meshes. Only returned
if feats is not None.
"""
verts = meshes.verts_padded() # (N, V, D)
edges = meshes[0].edges_packed()
# The set of faces is the same across the different meshes.
new_faces = self._subdivided_faces.view(1, -1, 3).expand(
self._N, -1, -1
)
# Add one new vertex at the midpoint of each edge by taking the average
# of the vertices that form each edge.
new_verts = verts[:, edges].mean(dim=2)
new_verts = torch.cat(
[verts, new_verts], dim=1
) # (N, V+E, 3)
new_feats = None
# Calculate features for new vertices.
if feats is not None:
if feats.dim() == 2:
# feats is in packed format, transform it from packed to
# padded, i.e. (N*V, D) to (N, V, D).
feats = feats.view(verts.size(0), verts.size(1), feats.size(1))
if feats.dim() != 3:
raise ValueError(
"features need to be of shape (N, V, D) or (N*V, D)"
)
# Take average of the features at the vertices that form each edge.
new_feats = feats[:, edges].mean(dim=2)
new_feats = torch.cat(
[feats, new_feats], dim=1
) # (N, V+E, D)
new_meshes = Meshes(verts=new_verts, faces=new_faces)
if feats is None:
return new_meshes
else:
return new_meshes, new_feats
def subdivide_heterogenerous(self, meshes, feats=None):
"""
Subdivide faces, verts (and optionally features) of a batch of meshes
where each mesh can have different face topologies.
Args:
meshes: Meshes object representing a batch of meshes.
feats: Per-vertex features to be subdivided along with the verts.
Returns:
2-element tuple containing
- **new_meshes**: Meshes object of a batch of subdivided meshes.
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
(packed) vertices of the subdivided meshes. Only returned
if feats is not None.
"""
# The computation of new faces is on face indices, so gradients do not
# need to be tracked.
verts = meshes.verts_packed()
with torch.no_grad():
new_faces = self.subdivide_faces(meshes)
edges = meshes.edges_packed()
face_to_mesh_idx = meshes.faces_packed_to_mesh_idx()
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()
num_edges_per_mesh = edge_to_mesh_idx.bincount(minlength=self._N)
num_verts_per_mesh = meshes.num_verts_per_mesh()
num_faces_per_mesh = meshes.num_faces_per_mesh()
# Add one new vertex at the midpoint of each edge.
new_verts_per_mesh = num_verts_per_mesh + num_edges_per_mesh # (N,)
new_face_to_mesh_idx = torch.cat([face_to_mesh_idx] * 4, dim=0)
# Calculate the indices needed to group the new and existing verts
# for each mesh.
verts_sort_idx = create_verts_index(
num_verts_per_mesh, num_edges_per_mesh, meshes.device
) # (sum(V_n)+sum(E_n),)
verts_ordered_idx_init = torch.zeros(
new_verts_per_mesh.sum(),
dtype=torch.int64,
device=meshes.device,
) # (sum(V_n)+sum(E_n),)
# Reassign vertex indices so that existing and new vertices for each
# mesh are sequential.
verts_ordered_idx = verts_ordered_idx_init.scatter_add(
0,
verts_sort_idx,
torch.arange(new_verts_per_mesh.sum(), device=meshes.device),
)
# Retrieve vertex indices for each face.
new_faces = verts_ordered_idx[new_faces]
# Calculate the indices needed to group the existing and new faces
# for each mesh.
face_sort_idx = create_faces_index(
num_faces_per_mesh, device=meshes.device
)
# Reorder the faces to sequentially group existing and new faces
# for each mesh.
new_faces = new_faces[face_sort_idx]
new_face_to_mesh_idx = new_face_to_mesh_idx[face_sort_idx]
new_faces_per_mesh = new_face_to_mesh_idx.bincount(
minlength=self._N
) # (sum(F_n)*4)
# Add one new vertex at the midpoint of each edge by taking the average
# of the verts that form each edge.
new_verts = verts[edges].mean(dim=1)
new_verts = torch.cat([verts, new_verts], dim=0)
# Reorder the verts to sequentially group existing and new verts for
# each mesh.
new_verts = new_verts[verts_sort_idx]
if feats is not None:
new_feats = feats[edges].mean(dim=1)
new_feats = torch.cat([feats, new_feats], dim=0)
new_feats = new_feats[verts_sort_idx]
verts_list = list(new_verts.split(new_verts_per_mesh.tolist(), 0))
faces_list = list(new_faces.split(new_faces_per_mesh.tolist(), 0))
new_verts_per_mesh_cumsum = torch.cat(
[
new_verts_per_mesh.new_full(size=(1,), fill_value=0.0),
new_verts_per_mesh.cumsum(0)[:-1],
],
dim=0,
)
faces_list = [
faces_list[n] - new_verts_per_mesh_cumsum[n] for n in range(self._N)
]
if feats is not None:
feats_list = new_feats.split(new_verts_per_mesh.tolist(), 0)
new_meshes = Meshes(verts=verts_list, faces=faces_list)
if feats is None:
return new_meshes
else:
new_feats = torch.cat(feats_list, dim=0)
return new_meshes, new_feats
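# Usage sketch (illustrative only): subdividing a single triangle adds one new
# vertex per edge and splits the face into four, so the result is expected to
# have 6 vertices and 4 faces. Uses the Meshes import at the top of this file;
# the helper name is hypothetical.
def _demo_subdivide():
    verts = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    faces = torch.tensor([[0, 1, 2]])
    subdivide = SubdivideMeshes()
    new_meshes = subdivide(Meshes(verts=[verts], faces=[faces]))
    return new_meshes.verts_packed().shape, new_meshes.faces_packed().shape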
def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
"""
Helper function to group the vertex indices for each mesh. New vertices are
stacked at the end of the original verts tensor, so in order to have
sequential packing, the verts tensor needs to be reordered so that the
vertices corresponding to each mesh are grouped together.
Args:
verts_per_mesh: Tensor of shape (N,) giving the number of vertices
in each mesh in the batch where N is the batch size.
edges_per_mesh: Tensor of shape (N,) giving the number of edges
in each mesh in the batch
Returns:
verts_idx: A tensor with vert indices for each mesh ordered sequentially
by mesh index.
"""
# e.g. verts_per_mesh = (4, 5, 6)
# e.g. edges_per_mesh = (5, 7, 9)
V = verts_per_mesh.sum() # e.g. 15
E = edges_per_mesh.sum() # e.g. 21
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
edges_per_mesh_cumsum = edges_per_mesh.cumsum(
dim=0
) # (N,) e.g. (5, 12, 21)
v_to_e_idx = verts_per_mesh_cumsum.clone()
# vertex to edge index.
v_to_e_idx[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
# vertex to edge offset.
v_to_e_offset = (
V - verts_per_mesh_cumsum
) # e.g. 15 - (4, 9, 15) = (11, 6, 0)
v_to_e_offset[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
e_to_v_idx = (
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
) # (4, 9) + (5, 12) = (9, 21)
e_to_v_offset = (
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
) # (4, 9) - (5, 12) - 15 = (-16, -18)
# Add one new vertex per edge.
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
idx_diffs[v_to_e_idx] += v_to_e_offset
idx_diffs[e_to_v_idx] += e_to_v_offset
# e.g.
# [
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
# ]
verts_idx = idx_diffs.cumsum(dim=0) - 1
# e.g.
# [
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
# 4, 5, 6, 7, 8, 20, 21, 22, 23, 24, 25, 26, --> mesh 1
# 9, 10, 11, 12, 13, 14, 27, 28, 29, 30, 31, 32, 33, 34, 35 --> mesh 2
# ]
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
return verts_idx
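# Check sketch (illustrative only): running create_verts_index on the example
# from the comments above (verts_per_mesh = (4, 5, 6), edges_per_mesh =
# (5, 7, 9)) should reproduce the per-mesh grouping shown there, e.g. mesh 0
# gets [0, 1, 2, 3] for its existing verts followed by [15, 16, 17, 18, 19]
# for its new verts. The helper name is hypothetical.
def _demo_create_verts_index():
    verts_per_mesh = torch.tensor([4, 5, 6])
    edges_per_mesh = torch.tensor([5, 7, 9])
    return create_verts_index(verts_per_mesh, edges_per_mesh, device=None)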
def create_faces_index(faces_per_mesh, device=None):
"""
Helper function to group the faces indices for each mesh. New faces are
stacked at the end of the original faces tensor, so in order to have
sequential packing, the faces tensor needs to be reordered so that faces
corresponding to each mesh are grouped together.
Args:
faces_per_mesh: Tensor of shape (N,) giving the number of faces
in each mesh in the batch where N is the batch size.
Returns:
faces_idx: A tensor with face indices for each mesh ordered sequentially
by mesh index.
"""
# e.g. faces_per_mesh = [2, 5, 3]
F = faces_per_mesh.sum() # e.g. 10
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
switch1_idx = faces_per_mesh_cumsum.clone()
switch1_idx[1:] += (
3 * faces_per_mesh_cumsum[:-1]
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
switch2_idx[1:] += (
2 * faces_per_mesh_cumsum[:-1]
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
switch3_idx[1:] += faces_per_mesh_cumsum[
:-1
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
idx_diffs[switch1_idx] += switch123_offset
idx_diffs[switch2_idx] += switch123_offset
idx_diffs[switch3_idx] += switch123_offset
idx_diffs[switch4_idx] -= 3 * F
# e.g
# [
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
# ]
faces_idx = idx_diffs.cumsum(dim=0) - 1
# e.g.
# [
# 0, 1, 10, 11, 20, 21, 30, 31,
# 2, 3, 4, 5, 6, 12, 13, 14, 15, 16, 22, 23, 24, 25, 26, 32, 33, 34, 35, 36,
# 7, 8, 9, 17, 18, 19, 27, 28, 29, 37, 38, 39
# ]
# where for mesh 0, [0, 1] are the indices of the existing faces, and
# [10, 11, 20, 21, 30, 31] are the indices of the new faces after subdivision.
return faces_idx
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
def vert_align(
feats,
verts,
return_packed: bool = False,
interp_mode: str = "bilinear",
padding_mode: str = "zeros",
align_corners: bool = True,
) -> torch.Tensor:
"""
Sample vertex features from a feature map. This operation is called
"perceptual feaure pooling" in [1] or "vert align" in [2].
[1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single
RGB Images", ECCV 2018.
[2] Gkioxari et al, "Mesh R-CNN", ICCV 2019
Args:
feats: FloatTensor of shape (N, C, H, W) representing image features
from which to sample or a list of features each with potentially
different C, H or W dimensions.
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with
'verts_padded' as an attribute giving the (x, y, z) vertex positions
for which to sample. (x, y) verts should be normalized such that
(-1, -1) corresponds to top-left and (+1, +1) to bottom-right
location in the input feature map.
return_packed: (bool) Indicates whether to return packed features
interp_mode: (str) Specifies how to interpolate features.
('bilinear' or 'nearest')
padding_mode: (str) Specifies how to handle vertices outside of the
[-1, 1] range. ('zeros', 'reflection', or 'border')
align_corners (bool): Geometrically, we consider the pixels of the
input as squares rather than points.
If set to ``True``, the extrema (``-1`` and ``1``) are considered as
referring to the center points of the input's corner pixels. If set
to ``False``, they are instead considered as referring to the corner
points of the input's corner pixels, making the sampling more
resolution agnostic. Default: ``True``
Returns:
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for
each vertex. If feats is a list, we return concatenated
features in axis=2 of shape (N, V, sum(C_n)) where
C_n = feats[n].shape[1]. If return_packed = True, the
features are transformed to a packed representation
of shape (sum(V), C)
"""
if torch.is_tensor(verts):
if verts.dim() != 3:
raise ValueError("verts tensor should be 3 dimensional")
grid = verts
elif hasattr(verts, "verts_padded"):
grid = verts.verts_padded()
else:
raise ValueError(
"verts must be a tensor or have a `verts_padded` attribute"
)
grid = grid[:, None, :, :2] # (N, 1, V, 2)
if torch.is_tensor(feats):
feats = [feats]
for feat in feats:
if feat.dim() != 4:
raise ValueError("feats must have shape (N, C, H, W)")
if grid.shape[0] != feat.shape[0]:
raise ValueError("inconsistent batch dimension")
feats_sampled = []
for feat in feats:
feat_sampled = F.grid_sample(
feat,
grid,
mode=interp_mode,
padding_mode=padding_mode,
align_corners=align_corners,
) # (N, C, 1, V)
feat_sampled = feat_sampled.squeeze(dim=2).transpose(1, 2) # (N, V, C)
feats_sampled.append(feat_sampled)
feats_sampled = torch.cat(feats_sampled, dim=2) # (N, V, sum(C))
if return_packed:
# flatten the first two dimensions: (N*V, C)
feats_sampled = feats_sampled.view(-1, feats_sampled.shape[-1])
if hasattr(verts, "verts_padded_to_packed_idx"):
idx = (
verts.verts_padded_to_packed_idx()
.view(-1, 1)
.expand(-1, feats_sampled.shape[-1])
)
feats_sampled = feats_sampled.gather(0, idx) # (sum(V), C)
return feats_sampled
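# Usage sketch (illustrative only): samples per-vertex features from a small
# feature map using normalized (x, y) vertex positions; the z coordinate is
# ignored by vert_align. The output has shape (N, V, C); the helper name is
# hypothetical.
def _demo_vert_align():
    feats = torch.rand(2, 16, 8, 8)  # (N, C, H, W)
    verts = torch.rand(2, 10, 3) * 2.0 - 1.0  # (N, V, 3) in [-1, 1]
    return vert_align(feats, verts)  # (2, 10, 16)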
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from .cameras import (
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
camera_position_from_spherical_angles,
get_world_to_view_transform,
look_at_rotation,
look_at_view_transform,
)
from .lighting import DirectionalLights, PointLights, diffuse, specular
from .materials import Materials
from .mesh import (
GouradShader,
MeshRasterizer,
MeshRenderer,
PhongShader,
RasterizationSettings,
SilhouetteShader,
TexturedPhongShader,
gourad_shading,
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
phong_shading,
rasterize_meshes,
)
from .utils import TensorProperties, convert_to_tensors_and_broadcast
__all__ = [k for k in globals().keys() if not k.startswith("_")]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
from typing import NamedTuple
import torch
# Example functions for blending the top K colors per pixel using the outputs
# from rasterization.
# NOTE: All blending functions should return an RGBA image per batch element
# Data class to store blending params with defaults
class BlendParams(NamedTuple):
sigma: float = 1e-4
gamma: float = 1e-4
background_color = (1.0, 1.0, 1.0)
def hard_rgb_blend(colors, fragments) -> torch.Tensor:
"""
Naive blending of top K faces to return an RGBA image
- **RGB** - choose color of the closest point i.e. K=0
- **A** - 1.0
Args:
colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
fragments: the outputs of rasterization. From this we use
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image. This is used to
determine the output shape.
Returns:
RGBA pixel_colors: (N, H, W, 4)
"""
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
pixel_colors[..., :3] = colors[..., 0, :]
return torch.flip(pixel_colors, [1])
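# Usage sketch (illustrative only): hard_rgb_blend only reads pix_to_face from
# the rasterizer outputs, so a minimal stand-in namedtuple is enough to exercise
# it here. `SimpleFragments` and the helper name are hypothetical, not the
# renderer's fragments class.
def _demo_hard_rgb_blend():
    from collections import namedtuple
    SimpleFragments = namedtuple("SimpleFragments", ["pix_to_face"])
    N, H, W, K = 1, 4, 4, 2
    fragments = SimpleFragments(pix_to_face=torch.zeros(N, H, W, K, dtype=torch.int64))
    colors = torch.rand(N, H, W, K, 3)
    return hard_rgb_blend(colors, fragments)  # (N, H, W, 4) RGBA image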
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
"""
Silhouette blending to return an RGBA image
- **RGB** - choose color of the closest point.
- **A** - blend based on the 2D distance based probability map [0].
Args:
colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
fragments: the outputs of rasterization. From this we use
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- dists: FloatTensor of shape (N, H, W, K) specifying
the 2D euclidean distance from the center of each pixel
to each of the top K overlapping faces.
Returns:
RGBA pixel_colors: (N, H, W, 4)
[0] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
3D Reasoning', ICCV 2019
"""
N, H, W, K = fragments.pix_to_face.shape
pixel_colors = torch.ones(
(N, H, W, 4), dtype=colors.dtype, device=colors.device
)
mask = fragments.pix_to_face >= 0
# The distance is negative if a pixel is inside a face and positive outside
# the face. Therefore use -1.0 * fragments.dists to get the correct sign.
prob = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
# The cumulative product ensures that alpha will be 1 if at least 1 face
# fully covers the pixel as for that face prob will be 1.0
# TODO: investigate why torch.cumprod backwards is very slow for large
# values of K.
# Temporarily replace this with exp(sum(log)) using the fact that
# a*b = exp(log(a*b)) = exp(log(a) + log(b))
# alpha = 1.0 - torch.cumprod((1.0 - prob), dim=-1)[..., -1]
alpha = 1.0 - torch.exp(torch.log((1.0 - prob)).sum(dim=-1))
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
pixel_colors[..., 3] = alpha
pixel_colors = torch.clamp(pixel_colors, min=0, max=1.0)
return torch.flip(pixel_colors, [1])
def softmax_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
"""
RGB and alpha channel blending to return an RGBA image based on the method
proposed in [0]
- **RGB** - blend the colors based on the 2D distance based probability map and
relative z distances.
- **A** - blend based on the 2D distance based probability map.
Args:
colors: (N, H, W, K, 3) RGB color for each of the top K faces per pixel.
fragments: namedtuple with outputs of rasterization. We use properties
- pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
of the faces (in the packed representation) which
overlap each pixel in the image.
- dists: FloatTensor of shape (N, H, W, K) specifying
the 2D euclidean distance from the center of each pixel
to each of the top K overlapping faces.
- zbuf: FloatTensor of shape (N, H, W, K) specifying
the interpolated depth from each pixel to each of the
top K overlapping faces.
blend_params: instance of BlendParams dataclass containing properties
- sigma: float, parameter which controls the width of the sigmoid
function used to calculate the 2D distance based probability.
Sigma controls the sharpness of the edges of the shape.
- gamma: float, parameter which controls the scaling of the
exponential function used to control the opacity of the color.
- background_color: (3) element list/tuple/torch.Tensor specifying
the RGB values for the background color.
Returns:
RGBA pixel_colors: (N, H, W, 4)
[0] Shichen Liu et al, 'Soft Rasterizer: A Differentiable Renderer for
Image-based 3D Reasoning'
"""
N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device
pix_colors = torch.ones(
(N, H, W, 4), dtype=colors.dtype, device=colors.device
)
background = blend_params.background_color
if not torch.is_tensor(background):
background = torch.tensor(
background, dtype=torch.float32, device=device
)
# Background color
delta = np.exp(1e-10 / blend_params.gamma) * 1e-10
delta = torch.tensor(delta, device=device)
# Near and far clipping planes.
# TODO: add zfar/znear as input params.
zfar = 100.0
znear = 1.0
# Mask for padded pixels.
mask = fragments.pix_to_face >= 0
# Sigmoid probability map based on the distance of the pixel to the face.
prob_map = torch.sigmoid(-fragments.dists / blend_params.sigma) * mask
# The cumulative product ensures that alpha will be 1 if at least 1 face
# fully covers the pixel as for that face prob will be 1.0
# TODO: investigate why torch.cumprod backwards is very slow for large
# values of K.
# Temporarily replace this with exp(sum(log)) using the fact that
# a*b = exp(log(a*b)) = exp(log(a) + log(b))
# alpha = 1.0 - torch.cumprod((1.0 - prob), dim=-1)[..., -1]
alpha = 1.0 - torch.exp(torch.log((1.0 - prob_map)).sum(dim=-1))
# Weights for each face. Adjust the exponential by the max z to prevent
# overflow. zbuf shape (N, H, W, K), find max over K.
# TODO: there may still be some instability in the exponent calculation.
z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
z_inv_max = torch.max(z_inv, dim=-1).values[..., None]
weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
# Normalize weights.
# weights_num shape: (N, H, W, K). Sum over K and divide through by the sum.
denom = weights_num.sum(dim=-1)[..., None] + delta
weights = weights_num / denom
# Sum: weights * textures + background color
weighted_colors = (weights[..., None] * colors).sum(dim=-2)
weighted_background = (delta / denom) * background
pix_colors[..., :3] = weighted_colors + weighted_background
pix_colors[..., 3] = alpha
# Clamp colors to the range 0-1 and flip y axis.
pix_colors = torch.clamp(pix_colors, min=0, max=1.0)
return torch.flip(pix_colors, [1])
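# Check sketch (illustrative only) for the exp(sum(log)) rewrite used in the
# alpha computation above: for probabilities strictly below 1, both forms agree
# up to floating point error. The helper name is hypothetical.
def _check_alpha_rewrite():
    prob = torch.rand(2, 4, 4, 3) * 0.9
    alpha_cumprod = 1.0 - torch.prod(1.0 - prob, dim=-1)
    alpha_explog = 1.0 - torch.exp(torch.log(1.0 - prob).sum(dim=-1))
    return torch.allclose(alpha_cumprod, alpha_explog)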
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math
import numpy as np
from typing import Tuple
import torch
import torch.nn.functional as F
from pytorch3d.transforms import Rotate, Transform3d, Translate
from .utils import TensorProperties, convert_to_tensors_and_broadcast
# Default values for rotation and translation matrices.
r = np.expand_dims(np.eye(3), axis=0) # (1, 3, 3)
t = np.expand_dims(np.zeros(3), axis=0) # (1, 3)
class OpenGLPerspectiveCameras(TensorProperties):
"""
A class which stores a batch of parameters to generate a batch of
projection matrices using the OpenGL convention for a perspective camera.
The extrinsics of the camera (R and T matrices) can also be set in the
initializer or passed in to `get_full_projection_transform` to get
the full transformation from world -> screen.
The `transform_points` method calculates the full world -> screen transform
and then applies it to the input points.
The transforms can also be returned separately as Transform3d objects.
"""
def __init__(
self,
znear=1.0,
zfar=100.0,
aspect_ratio=1.0,
fov=60.0,
degrees: bool = True,
R=r,
T=t,
device="cpu",
):
"""
__init__(self, znear, zfar, aspect_ratio, fov, degrees, R, T, device) -> None # noqa
Args:
znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum.
aspect_ratio: ratio of screen_width/screen_height.
fov: field of view angle of the camera.
degrees: bool, set to True if fov is specified in degrees.
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
device: torch.device or string
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
super().__init__(
device=device,
znear=znear,
zfar=zfar,
aspect_ratio=aspect_ratio,
fov=fov,
R=R,
T=T,
)
# No need to convert to tensor or broadcast.
self.degrees = degrees
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the OpenGL perspective projection matrix with a symmetric
viewing frustum. Use column major order.
Args:
**kwargs: parameters for the projection can be passed in as keyword
arguments to override the default values set in `__init__`.
Return:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
f1 = -(far + near)/(far - near)
f2 = -2*far*near/(far-near)
h1 = (top + bottom)/(top - bottom)
w1 = (right + left)/(right - left)
tanhalffov = tan((fov/2))
s1 = 1/tanhalffov
s2 = 1/(tanhalffov * (aspect_ratio))
P = [
[s1, 0, w1, 0],
[0, s2, h1, 0],
[0, 0, f1, f2],
[0, 0, -1, 0],
]
"""
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
fov = kwargs.get("fov", self.fov) # pyre-ignore[16]
aspect_ratio = kwargs.get(
"aspect_ratio", self.aspect_ratio
) # pyre-ignore[16]
degrees = kwargs.get("degrees", self.degrees)
P = torch.zeros(
(self._N, 4, 4), device=self.device, dtype=torch.float32
)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
if degrees:
fov = (np.pi / 180) * fov
if not torch.is_tensor(fov):
fov = torch.tensor(fov, device=self.device)
tanHalfFov = torch.tan((fov / 2))
top = tanHalfFov * znear
bottom = -top
right = top * aspect_ratio
left = -right
# NOTE: In OpenGL the projection matrix changes the handedness of the
# coordinate frame. i.e. the NDC space positive z direction is the
# camera space negative z direction. This is because the sign of the z
# in the projection matrix is set to -1.0.
# In pytorch3d we maintain a right handed coordinate system throughout,
# so the z sign is 1.0.
z_sign = 1.0
P[:, 0, 0] = 2.0 * znear / (right - left)
P[:, 1, 1] = 2.0 * znear / (top - bottom)
P[:, 0, 2] = (right + left) / (right - left)
P[:, 1, 2] = (top + bottom) / (top - bottom)
P[:, 3, 2] = z_sign * ones
# NOTE: This part of the matrix is for z renormalization in OpenGL
# which maps the z to [-1, 1]. This won't work yet as the torch3d
# rasterizer ignores faces which have z < 0.
# P[:, 2, 2] = z_sign * (far + near) / (far - near)
# P[:, 2, 3] = -2.0 * far * near / (far - near)
# P[:, 3, 2] = z_sign * torch.ones((N))
# NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
# is at the near clipping plane and z = 1 when the point is at the far
# clipping plane. This replaces the OpenGL z normalization to [-1, 1]
# until rasterization is changed to clip at z = -1.
P[:, 2, 2] = z_sign * zfar / (zfar - znear)
P[:, 2, 3] = -(zfar * znear) / (zfar - znear)
# OpenGL uses column vectors so need to transpose the projection matrix
# as torch3d uses row vectors.
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
def clone(self):
other = OpenGLPerspectiveCameras(device=self.device)
return super().clone(other)
def get_camera_center(self, **kwargs):
"""
Return the 3D location of the camera optical center
in the world coordinates.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting T here will update the values set in init as this
value may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
C: a batch of 3D locations of shape (N, 3) denoting
the locations of the center of each camera in the batch.
"""
w2v_trans = self.get_world_to_view_transform(**kwargs)
P = w2v_trans.inverse().get_matrix()
# the camera center is the translation component (the first 3 elements
# of the last row) of the inverted world-to-view
# transform (4x4 RT matrix)
C = P[:, 3, :3]
return C
def get_world_to_view_transform(self, **kwargs) -> Transform3d:
"""
Return the world-to-view transform.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform(
R=self.R, T=self.T
)
return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-screen transform composing the
world-to-view and view-to-screen transforms.
Args:
**kwargs: parameters for the projection transforms can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform)
def transform_points(self, points, **kwargs) -> torch.Tensor:
"""
Transform input points from world to screen space.
Args:
points: torch tensor of shape (..., 3).
Returns:
new_points: transformed points with the same shape as the input.
"""
world_to_screen_transform = self.get_full_projection_transform(**kwargs)
return world_to_screen_transform.transform_points(points)
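# Example (illustrative sketch, hypothetical helper): typical use of the
# perspective camera above with the default R, T. The projection matrix can
# also be inspected on its own as a Transform3d.
def _demo_opengl_perspective_camera():
    cameras = OpenGLPerspectiveCameras(znear=0.1, zfar=10.0, fov=60.0)
    points = torch.tensor([[[0.0, 0.0, 2.0], [0.5, 0.5, 5.0]]])  # (N=1, P=2, 3)
    projected = cameras.transform_points(points)  # full world -> screen transform
    P = cameras.get_projection_transform().get_matrix()  # (1, 4, 4)
    return projected, P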
class OpenGLOrthographicCameras(TensorProperties):
"""
A class which stores a batch of parameters to generate a batch of
transformation matrices using the OpenGL convention for orthographic camera.
"""
def __init__(
self,
znear=1.0,
zfar=100.0,
top=1.0,
bottom=-1.0,
left=-1.0,
right=1.0,
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
R=r,
T=t,
device="cpu",
):
"""
__init__(self, znear, zfar, top, bottom, left, right, scale_xyz, R, T, device) -> None # noqa
Args:
znear: near clipping plane of the view frustum.
zfar: far clipping plane of the view frustum.
top: position of the top of the screen.
bottom: position of the bottom of the screen.
left: position of the left of the screen.
right: position of the right of the screen.
scale_xyz: scale factors for each axis of shape (N, 3).
R: Rotation matrix of shape (N, 3, 3).
T: Translation of shape (N, 3).
device: torch.device or string.
Only need to set left, right, top, bottom for viewing frustums
which are non-symmetric about the origin.
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
super().__init__(
device=device,
znear=znear,
zfar=zfar,
top=top,
bottom=bottom,
left=left,
right=right,
scale_xyz=scale_xyz,
R=R,
T=T,
)
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the OpenGL orthographic projection matrix.
Use column major order.
Args:
**kwargs: parameters for the projection can be passed in to
override the default values set in __init__.
Return:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
scale_x = 2/(right - left)
scale_y = 2/(top - bottom)
scale_z = 2/(far-near)
mid_x = (right + left)/(right - left)
mid_y = (top + bottom)/(top - bottom)
mid_z = (far + near)/(far - near)
P = [
[scale_x, 0, 0, -mid_x],
[0, scale_y, 0, -mid_y],
[0, 0, -scale_z, -mid_z],
[0, 0, 0, 1],
]
"""
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
left = kwargs.get("left", self.left) # pyre-ignore[16]
right = kwargs.get("right", self.right) # pyre-ignore[16]
top = kwargs.get("top", self.top) # pyre-ignore[16]
bottom = kwargs.get("bottom", self.bottom) # pyre-ignore[16]
scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16]
P = torch.zeros(
(self._N, 4, 4), dtype=torch.float32, device=self.device
)
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
# NOTE: OpenGL flips handedness of coordinate system between camera
# space and NDC space so the z sign is negative. In PyTorch3d we maintain a
# right handed coordinate system throughout.
z_sign = +1.0
P[:, 0, 0] = (2.0 / (right - left)) * scale_xyz[:, 0]
P[:, 1, 1] = (2.0 / (top - bottom)) * scale_xyz[:, 1]
P[:, 0, 3] = -(right + left) / (right - left)
P[:, 1, 3] = -(top + bottom) / (top - bottom)
P[:, 3, 3] = ones
# NOTE: This maps the z coordinate to the range [0, 1] and replaces
# the OpenGL z normalization to [-1, 1].
P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
P[:, 2, 3] = -znear / (zfar - znear)
# NOTE: This part of the matrix is for z renormalization in OpenGL.
# The z is mapped to the range [-1, 1] but this won't work yet in
# pytorch3d as the rasterizer ignores faces which have z < 0.
# P[:, 2, 2] = z_sign * (2.0 / (far - near)) * scale[:, 2]
# P[:, 2, 3] = -(far + near) / (far - near)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
def clone(self):
other = OpenGLOrthographicCameras(device=self.device)
return super().clone(other)
def get_camera_center(self, **kwargs):
"""
Return the 3D location of the camera optical center
in the world coordinates.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting T here will update the values set in init as this
value may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
C: a batch of 3D locations of shape (N, 3) denoting
the locations of the center of each camera in the batch.
"""
w2v_trans = self.get_world_to_view_transform(**kwargs)
P = w2v_trans.inverse().get_matrix()
# The camera center is the translation component (the first 3 elements
# of the last row) of the inverted world-to-view
# transform (4x4 RT matrix).
C = P[:, 3, :3]
return C
def get_world_to_view_transform(self, **kwargs) -> Transform3d:
"""
Return the world-to-view transform.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform(
R=self.R, T=self.T
)
return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-screen transform composing the
world-to-view and view-to-screen transforms.
Args:
**kwargs: parameters for the projection transforms can be passed in
as keyword arguments to override the default values
set in `__init__`.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform)
def transform_points(self, points, **kwargs) -> torch.Tensor:
"""
Transform input points from world to screen space.
Args:
points: torch tensor of shape (..., 3).
Returns:
new_points: transformed points with the same shape as the input.
"""
world_to_screen_transform = self.get_full_projection_transform(**kwargs)
return world_to_screen_transform.transform_points(points)
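# Example (illustrative sketch, hypothetical helper): the orthographic camera
# maps x and y by pure scaling, so with the default unit frustum and identity
# R, T a point's x/y coordinates pass through unchanged while z is mapped to
# [0, 1].
def _demo_opengl_orthographic_camera():
    cameras = OpenGLOrthographicCameras(znear=0.1, zfar=10.0)
    points = torch.tensor([[[0.25, -0.5, 1.0]]])  # (N=1, P=1, 3)
    return cameras.transform_points(points)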
class SfMPerspectiveCameras(TensorProperties):
"""
A class which stores a batch of parameters to generate a batch of
transformation matrices using the multi-view geometry convention for
perspective camera.
"""
def __init__(
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
device="cpu",
):
"""
__init__(self, focal_length, principal_point, R, T, device) -> None
Args:
focal_length: Focal length of the camera in world units.
A tensor of shape (N, 1) or (N, 2) for
square and non-square pixels respectively.
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
A tensor of shape (N, 2).
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
device: torch.device or string
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
super().__init__(
device=device,
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
)
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the projection matrix using the
multi-view geometry convention.
Args:
**kwargs: parameters for the projection can be passed in as keyword
arguments to override the default values set in __init__.
Returns:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
fx = focal_length[:,0]
fy = focal_length[:,1]
px = principal_point[:,0]
py = principal_point[:,1]
P = [
[fx, 0, 0, px],
[0, fy, 0, py],
[0, 0, 0, 1],
[0, 0, 1, 0],
]
"""
principal_point = kwargs.get(
"principal_point", self.principal_point
) # pyre-ignore[16]
focal_length = kwargs.get(
"focal_length", self.focal_length
) # pyre-ignore[16]
P = _get_sfm_calibration_matrix(
self._N, self.device, focal_length, principal_point, False
)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
def clone(self):
other = SfMPerspectiveCameras(device=self.device)
return super().clone(other)
def get_camera_center(self, **kwargs):
"""
Return the 3D location of the camera optical center
in the world coordinates.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting T here will update the values set in init as this
value may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
C: a batch of 3D locations of shape (N, 3) denoting
the locations of the center of each camera in the batch.
"""
w2v_trans = self.get_world_to_view_transform(**kwargs)
P = w2v_trans.inverse().get_matrix()
# the camera center is the translation component (the first 3 elements
# of the last row) of the inverted world-to-view
# transform (4x4 RT matrix)
C = P[:, 3, :3]
return C
def get_world_to_view_transform(self, **kwargs) -> Transform3d:
"""
Return the world-to-view transform.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform(
R=self.R, T=self.T
)
return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-screen transform composing the
world-to-view and view-to-screen transforms.
Args:
**kwargs: parameters for the projection transforms can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform)
def transform_points(self, points, **kwargs) -> torch.Tensor:
"""
Transform input points from world to screen space.
Args:
points: torch tensor of shape (..., 3).
Returns:
new_points: transformed points with the same shape as the input.
"""
world_to_screen_transform = self.get_full_projection_transform(**kwargs)
return world_to_screen_transform.transform_points(points)
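# Example (illustrative sketch, hypothetical helper): an SfM-style perspective
# camera with a non-trivial focal length and principal point. The homogeneous
# divide by the view-space z happens inside transform_points.
def _demo_sfm_perspective_camera():
    cameras = SfMPerspectiveCameras(
        focal_length=((2.0, 2.0),), principal_point=((0.1, -0.1),)
    )
    points = torch.tensor([[[0.0, 0.0, 2.0]]])  # (N=1, P=1, 3)
    return cameras.transform_points(points)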
class SfMOrthographicCameras(TensorProperties):
"""
A class which stores a batch of parameters to generate a batch of
transformation matrices using the multi-view geometry convention for
orthographic camera.
"""
def __init__(
self,
focal_length=1.0,
principal_point=((0.0, 0.0),),
R=r,
T=t,
device="cpu",
):
"""
__init__(self, focal_length, principal_point, R, T, device) -> None
Args:
focal_length: Focal length of the camera in world units.
A tensor of shape (N, 1) or (N, 2) for
square and non-square pixels respectively.
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
A tensor of shape (N, 2).
R: Rotation matrix of shape (N, 3, 3)
T: Translation matrix of shape (N, 3)
device: torch.device or string
"""
# The initializer formats all inputs to torch tensors and broadcasts
# all the inputs to have the same batch dimension where necessary.
super().__init__(
device=device,
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
)
def get_projection_transform(self, **kwargs) -> Transform3d:
"""
Calculate the projection matrix using
the multi-view geometry convention.
Args:
**kwargs: parameters for the projection can be passed in as keyword
arguments to override the default values set in __init__.
Return:
P: a Transform3d object which represents a batch of projection
matrices of shape (N, 4, 4)
.. code-block:: python
fx = focal_length[:,0]
fy = focal_length[:,1]
px = principal_point[:,0]
py = principal_point[:,1]
P = [
[fx, 0, 0, px],
[0, fy, 0, py],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
"""
principal_point = kwargs.get(
"principal_point", self.principal_point
) # pyre-ignore[16]
focal_length = kwargs.get(
"focal_length", self.focal_length
) # pyre-ignore[16]
P = _get_sfm_calibration_matrix(
self._N, self.device, focal_length, principal_point, True
)
transform = Transform3d(device=self.device)
transform._matrix = P.transpose(1, 2).contiguous()
return transform
def clone(self):
other = SfMOrthographicCameras(device=self.device)
return super().clone(other)
def get_camera_center(self, **kwargs):
"""
Return the 3D location of the camera optical center
in the world coordinates.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting T here will update the values set in init as this
value may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
C: a batch of 3D locations of shape (N, 3) denoting
the locations of the center of each camera in the batch.
"""
w2v_trans = self.get_world_to_view_transform(**kwargs)
P = w2v_trans.inverse().get_matrix()
# the camera center is the translation component (the first 3 elements
# of the last row) of the inverted world-to-view
# transform (4x4 RT matrix)
C = P[:, 3, :3]
return C
def get_world_to_view_transform(self, **kwargs) -> Transform3d:
"""
Return the world-to-view transform.
Args:
**kwargs: parameters for the camera extrinsics can be passed in
as keyword arguments to override the default values
set in __init__.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = get_world_to_view_transform(
R=self.R, T=self.T
)
return world_to_view_transform
def get_full_projection_transform(self, **kwargs) -> Transform3d:
"""
Return the full world-to-screen transform composing the
world-to-view and view-to-screen transforms.
Args:
**kwargs: parameters for the projection transforms can be passed in
as keyword arguments to override the default values
set in `__init__`.
Setting R and T here will update the values set in init as these
values may be needed later on in the rendering pipeline e.g. for
lighting calculations.
Returns:
T: a Transform3d object which represents a batch of transforms
of shape (N, 4, 4)
"""
self.R = kwargs.get("R", self.R) # pyre-ignore[16]
self.T = kwargs.get("T", self.T) # pyre-ignore[16]
world_to_view_transform = self.get_world_to_view_transform(
R=self.R, T=self.T
)
view_to_screen_transform = self.get_projection_transform(**kwargs)
return world_to_view_transform.compose(view_to_screen_transform)
def transform_points(self, points, **kwargs) -> torch.Tensor:
"""
Transform input points from world to screen space.
Args:
points: torch tensor of shape (..., 3).
Returns:
new_points: transformed points with the same shape as the input.
"""
world_to_screen_transform = self.get_full_projection_transform(**kwargs)
return world_to_screen_transform.transform_points(points)
# SfMCameras helper
def _get_sfm_calibration_matrix(
N, device, focal_length, principal_point, orthographic: bool
) -> torch.Tensor:
"""
Returns a calibration matrix of a perspective/orthographic camera.
Args:
N: Number of cameras.
focal_length: Focal length of the camera in world units.
principal_point: xy coordinates of the center of
the principal point of the camera in pixels.
The calibration matrix `K` is set up as follows:
.. code-block:: python
fx = focal_length[:,0]
fy = focal_length[:,1]
px = principal_point[:,0]
py = principal_point[:,1]
for orthographic==True:
K = [
[fx, 0, 0, px],
[0, fy, 0, py],
[0, 0, 1, 0],
[0, 0, 0, 1],
]
else:
K = [
[fx, 0, 0, px],
[0, fy, 0, py],
[0, 0, 0, 1],
[0, 0, 1, 0],
]
Returns:
A calibration matrix `K` of the SfM-conventioned camera
of shape (N, 4, 4).
"""
if not torch.is_tensor(focal_length):
focal_length = torch.tensor(focal_length, device=device)
if len(focal_length.shape) in (0, 1) or focal_length.shape[1] == 1:
fx = fy = focal_length
else:
fx, fy = focal_length.unbind(1)
if not torch.is_tensor(principal_point):
principal_point = torch.tensor(principal_point, device=device)
px, py = principal_point.unbind(1)
K = fx.new_zeros(N, 4, 4)
K[:, 0, 0] = fx
K[:, 1, 1] = fy
K[:, 0, 3] = px
K[:, 1, 3] = py
if orthographic:
K[:, 2, 2] = 1.0
K[:, 3, 3] = 1.0
else:
K[:, 3, 2] = 1.0
K[:, 2, 3] = 1.0
return K
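# Example (illustrative sketch, hypothetical helper): build both calibration
# matrices for a single camera and inspect their layout, which matches the K
# matrices documented in the docstring above.
def _demo_sfm_calibration_matrix():
    focal_length = torch.tensor([[2.0, 3.0]])      # (N=1, 2)
    principal_point = torch.tensor([[0.1, -0.2]])  # (N=1, 2)
    K_perspective = _get_sfm_calibration_matrix(
        1, "cpu", focal_length, principal_point, orthographic=False
    )
    K_orthographic = _get_sfm_calibration_matrix(
        1, "cpu", focal_length, principal_point, orthographic=True
    )
    return K_perspective, K_orthographic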
################################################
# Helper functions for world to view transforms
################################################
def get_world_to_view_transform(R=r, T=t) -> Transform3d:
"""
This function returns a Transform3d representing the transformation
matrix to go from world space to view space by applying a rotation and
a translation.
Pytorch3d uses the same convention as Hartley & Zisserman.
I.e., for camera extrinsic parameters R (rotation) and T (translation),
we map a 3D point `X_world` in world coordinates to
a point `X_cam` in camera coordinates with:
`X_cam = X_world R + T`
Args:
R: (N, 3, 3) matrix representing the rotation.
T: (N, 3) matrix representing the translation.
Returns:
a Transform3d object which represents the composed RT transformation.
"""
# TODO: also support the case where RT is specified as one matrix
# of shape (N, 4, 4).
if T.shape[0] != R.shape[0]:
msg = "Expected R, T to have the same batch dimension; got %r, %r"
raise ValueError(msg % (R.shape[0], T.shape[0]))
if T.dim() != 2 or T.shape[1:] != (3,):
msg = "Expected T to have shape (N, 3); got %r"
raise ValueError(msg % repr(T.shape))
if R.dim() != 3 or R.shape[1:] != (3, 3):
msg = "Expected R to have shape (N, 3, 3); got %r"
raise ValueError(msg % R.shape)
# Create a Transform3d object
T = Translate(T, device=T.device)
R = Rotate(R, device=R.device)
return R.compose(T)
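# Example (illustrative sketch, hypothetical helper): check the row-vector
# convention X_cam = X_world R + T described above against the returned
# Transform3d, using a 90 degree rotation about the z axis.
def _demo_world_to_view_convention():
    R = torch.tensor([[[0.0, 1.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]])
    T = torch.tensor([[1.0, 2.0, 3.0]])
    X_world = torch.tensor([[[0.5, -0.5, 2.0]]])  # (N=1, P=1, 3)
    X_cam = get_world_to_view_transform(R=R, T=T).transform_points(X_world)
    X_expected = torch.bmm(X_world, R) + T[:, None, :]
    assert torch.allclose(X_cam, X_expected, atol=1e-6)
    return X_cam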
def camera_position_from_spherical_angles(
distance, elevation, azimuth, degrees: bool = True, device: str = "cpu"
) -> torch.Tensor:
"""
Calculate the location of the camera based on the distance away from
the target point, the elevation and azimuth angles.
Args:
distance: distance of the camera from the object.
elevation, azimuth: angles.
The inputs distance, elevation and azimuth can be one of the following
- Python scalar
- Torch scalar
- Torch tensor of shape (N) or (1)
degrees: bool, whether the angles are specified in degrees or radians.
device: str or torch.device, device for new tensors to be placed on.
The vectors are broadcast against each other so they all have shape (N, 1).
Returns:
camera_position: (N, 3) xyz location of the camera.
"""
broadcasted_args = convert_to_tensors_and_broadcast(
distance, elevation, azimuth, device=device
)
dist, elev, azim = broadcasted_args
if degrees:
elev = math.pi / 180.0 * elev
azim = math.pi / 180.0 * azim
x = dist * torch.cos(elev) * torch.sin(azim)
y = dist * torch.sin(elev)
z = -dist * torch.cos(elev) * torch.cos(azim)
camera_position = torch.stack([x, y, z], dim=1)
if camera_position.dim() == 0:
camera_position = camera_position.view(1, -1) # add batch dim.
return camera_position.view(-1, 3)
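# Example (illustrative sketch, hypothetical helper): with zero elevation and
# azimuth the camera sits on the negative z axis at the given distance,
# matching the x/y/z formulas above.
def _demo_camera_position_from_spherical_angles():
    position = camera_position_from_spherical_angles(2.7, 0.0, 0.0)
    assert torch.allclose(position, torch.tensor([[0.0, 0.0, -2.7]]))
    return position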
def look_at_rotation(
camera_position, at=((0, 0, 0),), up=((0, 1, 0),), device: str = "cpu"
) -> torch.Tensor:
"""
This function takes a vector 'camera_position' which specifies the location
of the camera in world coordinates and two vectors `at` and `up` which
indicate the position of the object and the up directions of the world
coordinate system respectively. The object is assumed to be centered at
the origin.
The output is a rotation matrix representing the transformation
from world coordinates -> view coordinates.
Args:
camera_position: position of the camera in world coordinates
at: position of the object in world coordinates
up: vector specifying the up direction in the world coordinate frame.
The inputs camera_position, at and up can each be a
- 3 element tuple/list
- torch tensor of shape (1, 3)
- torch tensor of shape (N, 3)
The vectors are broadcast against each other so they all have shape (N, 3).
Returns:
R: (N, 3, 3) batched rotation matrices
"""
# Format input and broadcast
broadcasted_args = convert_to_tensors_and_broadcast(
camera_position, at, up, device=device
)
camera_position, at, up = broadcasted_args
for t, n in zip([camera_position, at, up], ["camera_position", "at", "up"]):
if t.shape[-1] != 3:
msg = "Expected arg %s to have shape (N, 3); got %r"
raise ValueError(msg % (n, t.shape))
z_axis = F.normalize(at - camera_position, eps=1e-5)
x_axis = F.normalize(torch.cross(up, z_axis), eps=1e-5)
y_axis = F.normalize(torch.cross(z_axis, x_axis), eps=1e-5)
R = torch.cat(
(x_axis[:, None, :], y_axis[:, None, :], z_axis[:, None, :]), dim=1
)
return R.transpose(1, 2)
def look_at_view_transform(
dist,
elev,
azim,
degrees: bool = True,
at=((0, 0, 0),), # (1, 3)
up=((0, 1, 0),), # (1, 3)
device="cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This function returns a rotation and translation matrix
to apply the 'Look At' transformation from world -> view coordinates [0].
Args:
dist: distance of the camera from the object
elev: angle in degrees or radians. This is the angle between the
vector from the object to the camera and the horizontal plane.
azim: angle in degrees or radians. The vector from the object to
the camera is projected onto the horizontal plane y = 0.
azim is the angle between the projected vector and a
reference vector at (1, 0, 0) on the reference plane.
dist, elev and azim can be of shape (1) or (N).
degrees: boolean flag to indicate if the elevation and azimuth
angles are specified in degrees or radians.
up: the up direction in the world coordinate system.
at: the position of the object(s) in world coordinates.
up and at can be of shape (1, 3) or (N, 3).
Returns:
2-element tuple containing
- **R**: the rotation to apply to the points to align with the camera.
- **T**: the translation to apply to the points to align with the camera.
References:
[0] https://www.scratchapixel.com
"""
broadcasted_args = convert_to_tensors_and_broadcast(
dist, elev, azim, at, up, device=device
)
dist, elev, azim, at, up = broadcasted_args
C = camera_position_from_spherical_angles(dist, elev, azim, device=device)
R = look_at_rotation(C, at, up, device=device)
T = -torch.bmm(R.transpose(1, 2), C[:, :, None])[:, :, 0]
return R, T
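# Example (illustrative sketch, hypothetical helper): a complete "look at"
# setup. The returned R, T can be passed directly to any of the camera classes
# above, and the camera center recovered from the resulting world-to-view
# transform equals the original camera position.
def _demo_look_at_view_transform():
    R, T = look_at_view_transform(dist=2.7, elev=10.0, azim=20.0)
    cameras = OpenGLPerspectiveCameras(R=R, T=T)
    C = cameras.get_camera_center()
    expected = camera_position_from_spherical_angles(2.7, 10.0, 20.0)
    assert torch.allclose(C, expected, atol=1e-5)
    return R, T, C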
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from .utils import TensorProperties, convert_to_tensors_and_broadcast
def diffuse(normals, color, direction) -> torch.Tensor:
"""
Calculate the diffuse component of light reflection using Lambert's
cosine law.
Args:
normals: (N, ..., 3) xyz normal vectors. Normals and points are
expected to have the same shape.
color: (1, 3) or (N, 3) RGB color of the diffuse component of the light.
direction: (x,y,z) direction of the light
Returns:
colors: (N, ..., 3), same shape as the input points.
The normals and light direction should be in the same coordinate frame
i.e. if the points have been transformed from world -> view space then
the normals and direction should also be in view space.
NOTE: to use with the packed vertices (i.e. no batch dimension) reformat the
inputs in the following way.
.. code-block:: python
Args:
normals: (P, 3)
color: (N, 3)[batch_idx, :] -> (P, 3)
direction: (N, 3)[batch_idx, :] -> (P, 3)
Returns:
colors: (P, 3)
where batch_idx is of shape (P). For meshes, batch_idx can be:
meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx()
depending on whether points refers to the vertex coordinates or
average/interpolated face coordinates.
"""
# TODO: handle multiple directional lights per batch element.
# TODO: handle attenuation.
# Ensure color and location have same batch dimension as normals
normals, color, direction = convert_to_tensors_and_broadcast(
normals, color, direction, device=normals.device
)
# Reshape direction and color so they have all the arbitrary intermediate
# dimensions as normals. Assume first dim = batch dim and last dim = 3.
points_dims = normals.shape[1:-1]
expand_dims = (-1,) + (1,) * len(points_dims) + (3,)
if direction.shape != normals.shape:
direction = direction.view(expand_dims)
if color.shape != normals.shape:
color = color.view(expand_dims)
# Renormalize the normals in case they have been interpolated.
normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
angle = F.relu(torch.sum(normals * direction, dim=-1))
return color * angle[..., None]
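# Example (illustrative sketch, hypothetical helper): Lambert's cosine law for
# a single normal. A light whose direction makes a 60 degree angle with the
# normal contributes cos(60) = 0.5 of its diffuse color; a light pointing from
# behind the surface contributes nothing.
def _demo_diffuse():
    normals = torch.tensor([[[0.0, 0.0, 1.0]]])  # (N=1, ..., 3)
    color = torch.tensor([[1.0, 1.0, 1.0]])      # (1, 3)
    direction_60_deg = torch.tensor([[0.0, 0.8660254, 0.5]])
    direction_behind = torch.tensor([[0.0, 0.0, -1.0]])
    lit = diffuse(normals, color, direction_60_deg)    # ~0.5 everywhere
    unlit = diffuse(normals, color, direction_behind)  # all zeros
    return lit, unlit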
def specular(
points, normals, direction, color, camera_position, shininess
) -> torch.Tensor:
"""
Calculate the specular component of light reflection.
Args:
points: (N, ..., 3) xyz coordinates of the points.
normals: (N, ..., 3) xyz normal vectors for each point.
color: (N, 3) RGB color of the specular component of the light.
direction: (N, 3) vector direction of the light.
camera_position: (N, 3) The xyz position of the camera.
shininess: (N) The specular exponent of the material.
Returns:
colors: (N, ..., 3), same shape as the input points.
The points, normals, camera_position, and direction should be in the same
coordinate frame i.e. if the points have been transformed from
world -> view space then the normals, camera_position, and light direction
should also be in view space.
To use with a batch of packed points reindex in the following way.
.. code-block:: python
Args:
points: (P, 3)
normals: (P, 3)
color: (N, 3)[batch_idx] -> (P, 3)
direction: (N, 3)[batch_idx] -> (P, 3)
camera_position: (N, 3)[batch_idx] -> (P, 3)
shininess: (N)[batch_idx] -> (P)
Returns:
colors: (P, 3)
where batch_idx is of shape (P). For meshes batch_idx can be:
meshes.verts_packed_to_mesh_idx() or meshes.faces_packed_to_mesh_idx().
"""
# TODO: handle multiple directional lights
# TODO: attenuate based on inverse squared distance to the light source
if points.shape != normals.shape:
msg = "Expected points and normals to have the same shape: got %r, %r"
raise ValueError(msg % (points.shape, normals.shape))
# Ensure all inputs have same batch dimension as points
matched_tensors = convert_to_tensors_and_broadcast(
points,
color,
direction,
camera_position,
shininess,
device=points.device,
)
_, color, direction, camera_position, shininess = matched_tensors
# Reshape direction and color so they have all the arbitrary intermediate
# dimensions as points. Assume first dim = batch dim and last dim = 3.
points_dims = points.shape[1:-1]
expand_dims = (-1,) + (1,) * len(points_dims)
if direction.shape != normals.shape:
direction = direction.view(expand_dims + (3,))
if color.shape != normals.shape:
color = color.view(expand_dims + (3,))
if camera_position.shape != normals.shape:
camera_position = camera_position.view(expand_dims + (3,))
if shininess.shape != normals.shape:
shininess = shininess.view(expand_dims)
# Renormalize the normals in case they have been interpolated.
normals = F.normalize(normals, p=2, dim=-1, eps=1e-6)
direction = F.normalize(direction, p=2, dim=-1, eps=1e-6)
cos_angle = torch.sum(normals * direction, dim=-1)
# No specular highlights if angle is less than 0.
mask = (cos_angle > 0).to(torch.float32)
# Calculate the specular reflection.
view_direction = camera_position - points
view_direction = F.normalize(view_direction, p=2, dim=-1, eps=1e-6)
reflect_direction = -direction + 2 * (cos_angle[..., None] * normals)
# Cosine of the angle between the reflected light ray and the viewer
alpha = F.relu(torch.sum(view_direction * reflect_direction, dim=-1)) * mask
return color * torch.pow(alpha, shininess)[..., None]
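# Example (illustrative sketch, hypothetical helper): with the light direction,
# the surface normal and the viewer all aligned along +z, the reflected ray
# points straight back at the camera and the specular term reaches its maximum
# (the full specular color, since 1 ** shininess == 1).
def _demo_specular():
    points = torch.tensor([[[0.0, 0.0, 0.0]]])         # (N=1, ..., 3)
    normals = torch.tensor([[[0.0, 0.0, 1.0]]])
    color = torch.tensor([[1.0, 1.0, 1.0]])
    direction = torch.tensor([[0.0, 0.0, 1.0]])        # towards the light
    camera_position = torch.tensor([[0.0, 0.0, 3.0]])  # camera on the +z axis
    shininess = torch.tensor([10.0])
    return specular(
        points=points,
        normals=normals,
        direction=direction,
        color=color,
        camera_position=camera_position,
        shininess=shininess,
    )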
class DirectionalLights(TensorProperties):
def __init__(
self,
ambient_color=((0.5, 0.5, 0.5),),
diffuse_color=((0.3, 0.3, 0.3),),
specular_color=((0.2, 0.2, 0.2),),
direction=((0, 1, 0),),
device: str = "cpu",
):
"""
Args:
ambient_color: RGB color of the ambient component.
diffuse_color: RGB color of the diffuse component.
specular_color: RGB color of the specular component.
direction: (x, y, z) direction vector of the light.
device: torch.device on which the tensors should be located
The inputs can each be
- 3 element tuple/list or list of lists
- torch tensor of shape (1, 3)
- torch tensor of shape (N, 3)
The inputs are broadcast against each other so they all have batch
dimension N.
"""
super().__init__(
device=device,
ambient_color=ambient_color,
diffuse_color=diffuse_color,
specular_color=specular_color,
direction=direction,
)
_validate_light_properties(self)
if self.direction.shape[-1] != 3:
msg = "Expected direction to have shape (N, 3); got %r"
raise ValueError(msg % repr(self.direction.shape))
def clone(self):
other = DirectionalLights(device=self.device)
return super().clone(other)
def diffuse(self, normals, points=None) -> torch.Tensor:
# NOTE: Points is not used but is kept in the args so that the API is
# the same for directional and point lights. The call sites should not
# need to know the light type.
return diffuse(
normals=normals, color=self.diffuse_color, direction=self.direction
)
def specular(
self, normals, points, camera_position, shininess
) -> torch.Tensor:
return specular(
points=points,
normals=normals,
color=self.specular_color,
direction=self.direction,
camera_position=camera_position,
shininess=shininess,
)
class PointLights(TensorProperties):
def __init__(
self,
ambient_color=((0.5, 0.5, 0.5),),
diffuse_color=((0.3, 0.3, 0.3),),
specular_color=((0.2, 0.2, 0.2),),
location=((0, 1, 0),),
device: str = "cpu",
):
"""
Args:
ambient_color: RGB color of the ambient component
diffuse_color: RGB color of the diffuse component
specular_color: RGB color of the specular component
location: xyz position of the light.
device: torch.device on which the tensors should be located
The inputs can each be
- 3 element tuple/list or list of lists
- torch tensor of shape (1, 3)
- torch tensor of shape (N, 3)
The inputs are broadcast against each other so they all have batch
dimension N.
"""
super().__init__(
device=device,
ambient_color=ambient_color,
diffuse_color=diffuse_color,
specular_color=specular_color,
location=location,
)
_validate_light_properties(self)
if self.location.shape[-1] != 3:
msg = "Expected location to have shape (N, 3); got %r"
raise ValueError(msg % repr(self.location.shape))
def clone(self):
other = PointLights(device=self.device)
return super().clone(other)
def diffuse(self, normals, points) -> torch.Tensor:
direction = self.location - points
return diffuse(
normals=normals, color=self.diffuse_color, direction=direction
)
def specular(
self, normals, points, camera_position, shininess
) -> torch.Tensor:
direction = self.location - points
return specular(
points=points,
normals=normals,
color=self.specular_color,
direction=direction,
camera_position=camera_position,
shininess=shininess,
)
def _validate_light_properties(obj):
props = ("ambient_color", "diffuse_color", "specular_color")
for n in props:
t = getattr(obj, n)
if t.shape[-1] != 3:
msg = "Expected %s to have shape (N, 3); got %r"
raise ValueError(msg % (n, t.shape))
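# Example (illustrative sketch, hypothetical helper): DirectionalLights and
# PointLights expose the same diffuse/specular interface, so shading code does
# not need to know which light type it was given.
def _demo_lights():
    points = torch.tensor([[[0.0, 0.0, 0.0]]])  # (N=1, ..., 3)
    normals = torch.tensor([[[0.0, 1.0, 0.0]]])
    camera_position = torch.tensor([[0.0, 2.0, 2.0]])
    shininess = torch.tensor([64.0])
    outputs = []
    for lights in (
        DirectionalLights(direction=((0.0, 1.0, 0.0),)),
        PointLights(location=((0.0, 2.0, 0.0),)),
    ):
        d = lights.diffuse(normals=normals, points=points)
        s = lights.specular(
            normals=normals,
            points=points,
            camera_position=camera_position,
            shininess=shininess,
        )
        outputs.append((d, s))
    return outputs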
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from .utils import TensorProperties
class Materials(TensorProperties):
"""
A class for storing a batch of material properties. Currently only one
material per batch element is supported.
"""
def __init__(
self,
ambient_color=((1, 1, 1),),
diffuse_color=((1, 1, 1),),
specular_color=((1, 1, 1),),
shininess=64,
device="cpu",
):
"""
Args:
ambient_color: RGB ambient reflectivity of the material
diffuse_color: RGB diffuse reflectivity of the material
specular_color: RGB specular reflectivity of the material
shininess: The specular exponent for the material. This defines
the focus of the specular highlight with a high value
resulting in a concentrated highlight. Shininess values
can range from 0-1000.
device: torch.device or string
ambient_color, diffuse_color and specular_color can be of shape
(1, 3) or (N, 3). shininess can be of shape (1) or (N).
The colors and shininess are broadcast against each other so need to
have either the same batch dimension or batch dimension = 1.
"""
super().__init__(
device=device,
diffuse_color=diffuse_color,
ambient_color=ambient_color,
specular_color=specular_color,
shininess=shininess,
)
for n in ["ambient_color", "diffuse_color", "specular_color"]:
t = getattr(self, n)
if t.shape[-1] != 3:
msg = "Expected %s to have shape (N, 3); got %r"
raise ValueError(msg % (n, t.shape))
if self.shininess.shape != torch.Size([self._N]):
msg = "shininess should have shape (N); got %r"
raise ValueError(msg % repr(self.shininess.shape))
def clone(self):
other = Materials(device=self.device)
return super().clone(other)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer
from .shader import (
GouradShader,
PhongShader,
SilhouetteShader,
TexturedPhongShader,
)
from .shading import gourad_shading, phong_shading
from .texturing import ( # isort: skip
interpolate_face_attributes,
interpolate_texture_map,
interpolate_vertex_colors,
)
__all__ = [k for k in globals().keys() if not k.startswith("_")]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
from typing import Optional
import torch
from pytorch3d import _C
# TODO make the epsilon user configurable
kEpsilon = 1e-30
def rasterize_meshes(
meshes,
image_size: int = 256,
blur_radius: float = 0.0,
faces_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_faces_per_bin: Optional[int] = None,
perspective_correct: bool = False,
):
"""
Rasterize a batch of meshes given the shape of the desired output image.
Each mesh is rasterized onto a separate image of shape
(image_size, image_size).
Args:
meshes: A Meshes object representing a batch of meshes, batch size N.
image_size: Size in pixels of the output raster image for each mesh
in the batch. Assumes square images.
blur_radius: Float distance in the range [0, 2] used to expand the face
bounding boxes for rasterization. Setting blur radius
results in blurred edges around the shape instead of a
hard boundary. Set to 0 for no blur.
faces_per_pixel (Optional): Number of faces to save per pixel, returning
the nearest faces_per_pixel faces along the z-axis.
bin_size: Size of bins to use for coarse-to-fine rasterization. Setting
bin_size=0 uses naive rasterization; setting bin_size=None attempts to
set it heuristically based on the shape of the input. This should not
affect the output, but can affect the speed of the forward pass.
max_faces_per_bin: Only applicable when using coarse-to-fine rasterization
(bin_size > 0); this is the maximum number of faces allowed within each
bin. If more than this many faces actually fall into a bin, an error
will be raised. This should not affect the output values, but can affect
the memory usage in the forward pass.
perspective_correct: Whether to apply perspective correction when computing
barycentric coordinates for pixels.
Returns:
4-element tuple containing
- **pix_to_face**: LongTensor of shape
(N, image_size, image_size, faces_per_pixel)
giving the indices of the nearest faces at each pixel,
sorted in ascending z-order.
Concretely ``pix_to_face[n, y, x, k] = f`` means that
``faces_verts[f]`` is the kth closest face (in the z-direction)
to pixel (y, x). Pixels that are hit by fewer than
faces_per_pixel are padded with -1.
- **zbuf**: FloatTensor of shape (N, image_size, image_size, faces_per_pixel)
giving the NDC z-coordinates of the nearest faces at each pixel,
sorted in ascending z-order.
Concretely, if ``pix_to_face[n, y, x, k] = f`` then
``zbuf[n, y, x, k] = face_verts[f, 2]``. Pixels hit by fewer than
faces_per_pixel are padded with -1.
- **barycentric**: FloatTensor of shape
(N, image_size, image_size, faces_per_pixel, 3)
giving the barycentric coordinates in NDC units of the
nearest faces at each pixel, sorted in ascending z-order.
Concretely, if ``pix_to_face[n, y, x, k] = f`` then
``[w0, w1, w2] = barycentric[n, y, x, k]`` gives
the barycentric coords for pixel (y, x) relative to the face
defined by ``face_verts[f]``. Pixels hit by fewer than
faces_per_pixel are padded with -1.
- **pix_dists**: FloatTensor of shape
(N, image_size, image_size, faces_per_pixel)
giving the signed squared Euclidean distance (in NDC units) in the
x/y plane between each pixel and its nearest faces. Concretely, if
``pix_to_face[n, y, x, k] = f`` then ``pix_dists[n, y, x, k]`` is the
signed squared distance between the pixel (y, x) and the face given
by vertices ``face_verts[f]``. Pixels hit by fewer than
``faces_per_pixel`` faces are padded with -1.
"""
verts_packed = meshes.verts_packed()
faces_packed = meshes.faces_packed()
face_verts = verts_packed[faces_packed]
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
num_faces_per_mesh = meshes.num_faces_per_mesh()
# TODO: Choose naive vs coarse-to-fine based on mesh size and image size.
if bin_size is None:
if not verts_packed.is_cuda:
# Binned CPU rasterization is not supported.
bin_size = 0
else:
# TODO better heuristics for bin size.
if image_size <= 64:
bin_size = 8
elif image_size <= 256:
bin_size = 16
elif image_size <= 512:
bin_size = 32
elif image_size <= 1024:
bin_size = 64
if max_faces_per_bin is None:
max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5))
return _RasterizeFaceVerts.apply(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
image_size,
blur_radius,
faces_per_pixel,
bin_size,
max_faces_per_bin,
perspective_correct,
)
class _RasterizeFaceVerts(torch.autograd.Function):
"""
Torch autograd wrapper for forward and backward pass of rasterize_meshes
implemented in C++/CUDA.
Args:
face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions
for faces in all the meshes in the batch. Concretely,
face_verts[f, i] = [x, y, z] gives the coordinates for the
ith vertex of the fth face. These vertices are expected to
be in NDC coordinates in the range [-1, 1].
mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
faces_verts of the first face in each mesh in
the batch.
num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
for each mesh in the batch.
image_size, blur_radius, faces_per_pixel: same as rasterize_meshes.
perspective_correct: same as rasterize_meshes.
Returns:
same as rasterize_meshes function.
"""
@staticmethod
def forward(
ctx,
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
image_size: int = 256,
blur_radius: float = 0.01,
faces_per_pixel: int = 0,
bin_size: int = 0,
max_faces_per_bin: int = 0,
perspective_correct: bool = False,
):
pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes(
face_verts,
mesh_to_face_first_idx,
num_faces_per_mesh,
image_size,
blur_radius,
faces_per_pixel,
bin_size,
max_faces_per_bin,
perspective_correct,
)
ctx.save_for_backward(face_verts, pix_to_face)
ctx.perspective_correct = perspective_correct
return pix_to_face, zbuf, barycentric_coords, dists
@staticmethod
def backward(
ctx, grad_pix_to_face, grad_zbuf, grad_barycentric_coords, grad_dists
):
grad_face_verts = None
grad_mesh_to_face_first_idx = None
grad_num_faces_per_mesh = None
grad_image_size = None
grad_radius = None
grad_faces_per_pixel = None
grad_bin_size = None
grad_max_faces_per_bin = None
grad_perspective_correct = None
face_verts, pix_to_face = ctx.saved_tensors
grad_face_verts = _C.rasterize_meshes_backward(
face_verts,
pix_to_face,
grad_zbuf,
grad_barycentric_coords,
grad_dists,
ctx.perspective_correct,
)
grads = (
grad_face_verts,
grad_mesh_to_face_first_idx,
grad_num_faces_per_mesh,
grad_image_size,
grad_radius,
grad_faces_per_pixel,
grad_bin_size,
grad_max_faces_per_bin,
grad_perspective_correct,
)
return grads
def rasterize_meshes_python(
meshes,
image_size: int = 256,
blur_radius: float = 0.0,
faces_per_pixel: int = 8,
perspective_correct: bool = False,
):
"""
Naive PyTorch implementation of mesh rasterization with the same inputs and
outputs as the rasterize_meshes function.
This function is not optimized and is implemented as a comparison for the
C++/CUDA implementations.
"""
N = len(meshes)
# Assume only square images.
# TODO(T52813608) extend support for non-square images.
H, W = image_size, image_size
K = faces_per_pixel
device = meshes.device
verts_packed = meshes.verts_packed()
faces_packed = meshes.faces_packed()
faces_verts = verts_packed[faces_packed]
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
num_faces_per_mesh = meshes.num_faces_per_mesh()
# Initialize output tensors.
face_idxs = torch.full(
(N, H, W, K), fill_value=-1, dtype=torch.int64, device=device
)
zbuf = torch.full(
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
)
bary_coords = torch.full(
(N, H, W, K, 3), fill_value=-1, dtype=torch.float32, device=device
)
pix_dists = torch.full(
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
)
# The NDC range is [-1, 1]. Get the pixel size using the specified image size.
pixel_width = 2.0 / W
pixel_height = 2.0 / H
# Calculate all face bounding boxes.
x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
y_mins = torch.min(faces_verts[:, :, 1], dim=1, keepdim=True).values
y_maxs = torch.max(faces_verts[:, :, 1], dim=1, keepdim=True).values
# Expand by blur radius.
x_mins = x_mins - np.sqrt(blur_radius) - kEpsilon
x_maxs = x_maxs + np.sqrt(blur_radius) + kEpsilon
y_mins = y_mins - np.sqrt(blur_radius) - kEpsilon
y_maxs = y_maxs + np.sqrt(blur_radius) + kEpsilon
# Loop through meshes in the batch.
for n in range(N):
face_start_idx = mesh_to_face_first_idx[n]
face_stop_idx = face_start_idx + num_faces_per_mesh[n]
# Y coordinate of the top of the image.
yf = -1.0 + 0.5 * pixel_height
# Iterate through the horizontal lines of the image from top to bottom.
for yi in range(H):
# X coordinate of the left of the image.
xf = -1.0 + 0.5 * pixel_width
# Iterate through pixels on this horizontal line, left to right.
for xi in range(W):
top_k_points = []
# Check whether each face in the mesh affects this pixel.
for f in range(face_start_idx, face_stop_idx):
face = faces_verts[f].squeeze()
v0, v1, v2 = face.unbind(0)
face_area = edge_function(v2, v0, v1)
# Ignore faces which have zero area.
if face_area == 0.0:
continue
outside_bbox = (
xf < x_mins[f]
or xf > x_maxs[f]
or yf < y_mins[f]
or yf > y_maxs[f]
)
# Check if pixel is outside of face bbox.
if outside_bbox:
continue
# Compute barycentric coordinates and pixel z distance.
pxy = torch.tensor(
[xf, yf], dtype=torch.float32, device=device
)
bary = barycentric_coordinates(pxy, v0[:2], v1[:2], v2[:2])
if perspective_correct:
z0, z1, z2 = v0[2], v1[2], v2[2]
l0, l1, l2 = bary[0], bary[1], bary[2]
top0 = l0 * z1 * z2
top1 = z0 * l1 * z2
top2 = z0 * z1 * l2
bot = top0 + top1 + top2
bary = torch.stack([top0 / bot, top1 / bot, top2 / bot])
pz = bary[0] * v0[2] + bary[1] * v1[2] + bary[2] * v2[2]
# Check if point is behind the image.
if pz < 0:
continue
# Calculate signed 2D distance from point to face.
# Points inside the triangle have negative distance.
dist = point_triangle_distance(pxy, v0[:2], v1[:2], v2[:2])
inside = all(x > 0.0 for x in bary)
signed_dist = dist * -1.0 if inside else dist
# Add an epsilon to prevent errors when comparing distance
# to blur radius.
if not inside and dist >= blur_radius:
continue
top_k_points.append((pz, f, bary, signed_dist))
top_k_points.sort()
if len(top_k_points) > K:
top_k_points = top_k_points[:K]
# Save to output tensors.
for k, (pz, f, bary, dist) in enumerate(top_k_points):
zbuf[n, yi, xi, k] = pz
face_idxs[n, yi, xi, k] = f
bary_coords[n, yi, xi, k, 0] = bary[0]
bary_coords[n, yi, xi, k, 1] = bary[1]
bary_coords[n, yi, xi, k, 2] = bary[2]
pix_dists[n, yi, xi, k] = dist
# Move to the next horizontal pixel
xf += pixel_width
# Move to the next vertical pixel
yf += pixel_height
return face_idxs, zbuf, bary_coords, pix_dists
def edge_function(p, v0, v1):
r"""
Determines whether a point p is on the right side of a 2D line segment
given by the end points v0, v1.
Args:
p: (x, y) Coordinates of a point.
v0, v1: (x, y) Coordinates of the end points of the edge.
Returns:
area: The signed area of the parallelogram given by the vectors
.. code-block:: python
A = p - v0
B = v1 - v0
v1 ________
/\ /
A / \ /
/ \ /
v0 /______\/
B p
The area can also be interpreted as the cross product A x B.
If the sign of the area is positive, the point p is on the
right side of the edge. Negative area indicates the point is on
the left side of the edge. i.e. for an edge v1 - v0
.. code-block:: python
v1
/
/
- / +
/
/
v0
"""
return (p[0] - v0[0]) * (v1[1] - v0[1]) - (p[1] - v0[1]) * (v1[0] - v0[0])
def barycentric_coordinates(p, v0, v1, v2):
"""
Compute the barycentric coordinates of a point relative to a triangle.
Args:
p: Coordinates of a point.
v0, v1, v2: Coordinates of the triangle vertices.
Returns:
bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
"""
area = edge_function(v2, v0, v1) + kEpsilon # 2 x face area.
w0 = edge_function(p, v1, v2) / area
w1 = edge_function(p, v2, v0) / area
w2 = edge_function(p, v0, v1) / area
return (w0, w1, w2)
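# Example (illustrative sketch, hypothetical helper): for the unit right
# triangle below, edge_function(v2, v0, v1) equals 2 x its area (= 1.0 for
# this winding), and the barycentric coordinates of the centroid are
# (1/3, 1/3, 1/3).
def _demo_barycentric_coordinates():
    v0 = torch.tensor([0.0, 0.0])
    v1 = torch.tensor([0.0, 1.0])
    v2 = torch.tensor([1.0, 0.0])
    double_area = edge_function(v2, v0, v1)  # tensor(1.0)
    centroid = (v0 + v1 + v2) / 3.0
    w0, w1, w2 = barycentric_coordinates(centroid, v0, v1, v2)
    return double_area, (w0, w1, w2)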
def point_line_distance(p, v0, v1):
"""
Return the minimum squared distance between the line segment v0 -> v1 and point p.
Args:
p: Coordinates of a point.
v0, v1: Coordinates of the end points of the line segment.
Returns:
squared distance from the point to the line segment.
Consider the line extending the segment - this can be parameterized as
``v0 + t (v1 - v0)``.
First find the projection of point p onto the line. It falls where
``t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2``
where . is the dot product.
The parameter t is clamped from [0, 1] to handle points outside the
segment (v1 - v0).
Once the projection of the point on the segment is known, the distance from
p to the projection gives the minimum distance to the segment.
"""
if p.shape != v0.shape != v1.shape:
raise ValueError("All points must have the same number of coordinates")
v1v0 = v1 - v0
l2 = v1v0.dot(v1v0) # |v1 - v0|^2
if l2 == 0.0:
return (p - v1).dot(p - v1) # v0 == v1: squared distance to the single point
t = (v1v0).dot(p - v0) / l2
t = torch.clamp(t, min=0.0, max=1.0)
p_proj = v0 + t * v1v0
delta_p = p_proj - p
return delta_p.dot(delta_p)
def point_triangle_distance(p, v0, v1, v2):
"""
Return the shortest squared distance between a point and the edges of a triangle.
Args:
p: Coordinates of a point.
v0, v1, v2: Coordinates of the three triangle vertices.
Returns:
shortest squared distance from the point to the edges of the triangle.
"""
if p.shape != v0.shape != v1.shape != v2.shape:
raise ValueError("All points must have the same number of coordinates")
e01_dist = point_line_distance(p, v0, v1)
e02_dist = point_line_distance(p, v0, v2)
e12_dist = point_line_distance(p, v1, v2)
edge_dists_min = torch.min(torch.min(e01_dist, e02_dist), e12_dist)
return edge_dists_min
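# Example (illustrative sketch, hypothetical helper): both distance helpers
# above return squared distances. For a point one unit to the left of the
# triangle below, the closest edge is v0 -> v1 and the squared distance is 1.0.
def _demo_point_triangle_distance():
    v0 = torch.tensor([0.0, 0.0])
    v1 = torch.tensor([0.0, 1.0])
    v2 = torch.tensor([1.0, 0.0])
    p = torch.tensor([-1.0, 0.5])
    return point_triangle_distance(p, v0, v1, v2)  # tensor(1.0)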
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from dataclasses import dataclass
from typing import NamedTuple, Optional
import torch
import torch.nn as nn
from ..cameras import get_world_to_view_transform
from .rasterize_meshes import rasterize_meshes
# Class to store the outputs of mesh rasterization
class Fragments(NamedTuple):
pix_to_face: torch.Tensor
zbuf: torch.Tensor
bary_coords: torch.Tensor
dists: torch.Tensor
# Class to store the mesh rasterization params with defaults
@dataclass
class RasterizationSettings:
image_size: int = 256
blur_radius: float = 0.0
faces_per_pixel: int = 1
bin_size: Optional[int] = None
max_faces_per_bin: Optional[int] = None
perspective_correct: bool = False
class MeshRasterizer(nn.Module):
"""
This class implements methods for rasterizing a batch of heterogeneous
Meshes.
"""
def __init__(self, cameras, raster_settings=None):
"""
Args:
cameras: A cameras object which has a `transform_points` method
which returns the transformed points after applying the
world-to-view and view-to-screen
transformations.
raster_settings: the parameters for rasterization. This should be a
named tuple.
All these initial settings can be overridden by passing keyword
arguments to the forward function.
"""
super().__init__()
if raster_settings is None:
raster_settings = RasterizationSettings()
self.cameras = cameras
self.raster_settings = raster_settings
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
"""
Args:
meshes_world: a Meshes object representing a batch of meshes with
vertex coordinates in world space.
Returns:
meshes_screen: a Meshes object with the vertex positions in screen
space
NOTE: keeping this as a separate function for readability but it could
be moved into forward.
"""
cameras = kwargs.get("cameras", self.cameras)
verts_world = meshes_world.verts_padded()
verts_world_packed = meshes_world.verts_packed()
verts_screen = cameras.transform_points(verts_world, **kwargs)
# NOTE: Retaining view space z coordinate for now.
# TODO: Revisit whether or not to transform z coordinate to [-1, 1] or
# [0, 1] range.
view_transform = get_world_to_view_transform(R=cameras.R, T=cameras.T)
verts_view = view_transform.transform_points(verts_world)
verts_screen[..., 2] = verts_view[..., 2]
# Offset verts of input mesh to reuse cached padded/packed calculations.
pad_to_packed_idx = meshes_world.verts_padded_to_packed_idx()
verts_screen_packed = verts_screen.view(-1, 3)[pad_to_packed_idx, :]
verts_packed_offset = verts_screen_packed - verts_world_packed
return meshes_world.offset_verts(verts_packed_offset)
def forward(self, meshes_world, **kwargs) -> Fragments:
"""
Args:
meshes_world: a Meshes object representing a batch of meshes with
coordinates in world space.
Returns:
Fragments: Rasterization outputs as a named tuple.
"""
meshes_screen = self.transform(meshes_world, **kwargs)
raster_settings = kwargs.get("raster_settings", self.raster_settings)
# TODO(jcjohns): Should we try to set perspective_correct automatically
# based on the type of the camera?
pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
meshes_screen,
image_size=raster_settings.image_size,
blur_radius=raster_settings.blur_radius,
faces_per_pixel=raster_settings.faces_per_pixel,
bin_size=raster_settings.bin_size,
max_faces_per_bin=raster_settings.max_faces_per_bin,
perspective_correct=raster_settings.perspective_correct,
)
return Fragments(
pix_to_face=pix_to_face,
zbuf=zbuf,
bary_coords=bary_coords,
dists=dists,
)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
# A renderer class should be initialized with a
# function for rasterization and a function for shading.
# The rasterizer should:
# - transform inputs from world -> screen space
# - rasterize inputs
# - return fragments
# The shader can take fragments as input along with any other properties of
# the scene and generate images.
# E.g. rasterize inputs and then shade
#
# fragments = self.rasterize(meshes)
# images = self.shader(fragments, meshes)
# return images
class MeshRenderer(nn.Module):
"""
A class for rendering a batch of heterogeneous meshes. The class should
be initialized with a rasterizer and shader class which each have a forward
function.
"""
def __init__(self, rasterizer, shader):
super().__init__()
self.rasterizer = rasterizer
self.shader = shader
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
fragments = self.rasterizer(meshes_world, **kwargs)
images = self.shader(fragments, meshes_world, **kwargs)
return images
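# Example (illustrative sketch, hypothetical helper): wiring a rasterizer and
# a shader into a MeshRenderer, following the pipeline sketched in the
# comments above. The relative import paths and the assumption that `meshes`
# carries per-vertex colors are the editor's; adjust if the package layout
# differs.
def _demo_mesh_renderer(meshes, device="cpu"):
    from ..cameras import OpenGLPerspectiveCameras, look_at_view_transform
    from .rasterizer import MeshRasterizer, RasterizationSettings
    from .shader import PhongShader

    R, T = look_at_view_transform(dist=2.7, elev=10.0, azim=20.0)
    cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
    rasterizer = MeshRasterizer(
        cameras=cameras, raster_settings=RasterizationSettings(image_size=256)
    )
    renderer = MeshRenderer(
        rasterizer=rasterizer, shader=PhongShader(device=device, cameras=cameras)
    )
    return renderer(meshes)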
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from ..blending import (
BlendParams,
hard_rgb_blend,
sigmoid_alpha_blend,
softmax_rgb_blend,
)
from ..cameras import OpenGLPerspectiveCameras
from ..lighting import PointLights
from ..materials import Materials
from .shading import gourad_shading, phong_shading
from .texturing import interpolate_texture_map, interpolate_vertex_colors
# A Shader should take as input fragments from the output of rasterization
# along with scene params and output images. A shader could perform operations
# such as:
# - interpolate vertex attributes for all the fragments
# - sample colors from a texture map
# - apply per pixel lighting
# - blend colors across top K faces per pixel.
class PhongShader(nn.Module):
"""
Per pixel lighting. Apply the lighting model using the interpolated coords
and normals for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = PhongShader(device=torch.device("cuda:0"))
"""
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
super().__init__()
self.lights = (
lights if lights is not None else PointLights(device=device)
)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_vertex_colors(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
images = hard_rgb_blend(colors, fragments)
return images
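# Editor's note (sketch): forward() looks up lights / cameras / materials in
# **kwargs before falling back to the values stored at construction time, so
# scene properties can be overridden per call without rebuilding the shader.
# In the illustration below, `shader`, `fragments` and `meshes` are assumed to
# come from the usual pipeline, and `my_lights` stands for any Lights
# instance; the same pattern applies to the other shaders in this file.
#
#     my_lights = PointLights(device=device, location=((0.0, 2.0, 0.0),))
#     images = shader(fragments, meshes, lights=my_lights)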
class GouradShader(nn.Module):
"""
Per vertex lighting. Apply the lighting model to the vertex colors and then
interpolate using the barycentric coordinates to get colors for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
    .. code-block:: python
shader = GouradShader(device=torch.device("cuda:0"))
"""
def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
super().__init__()
self.lights = (
lights if lights is not None else PointLights(device=device)
)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
pixel_colors = gourad_shading(
meshes=meshes,
fragments=fragments,
lights=lights,
cameras=cameras,
materials=materials,
)
images = hard_rgb_blend(pixel_colors, fragments)
return images
class TexturedPhongShader(nn.Module):
"""
Per pixel lighting applied to a texture map. First interpolate the vertex
uv coordinates and sample from a texture map. Then apply the lighting model
using the interpolated coords and normals for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
    .. code-block:: python
shader = TexturedPhongShader(device=torch.device("cuda:0"))
"""
def __init__(
self,
device="cpu",
cameras=None,
lights=None,
materials=None,
blend_params=None,
):
super().__init__()
self.lights = (
lights if lights is not None else PointLights(device=device)
)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = (
cameras
if cameras is not None
else OpenGLPerspectiveCameras(device=device)
)
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
texels = interpolate_texture_map(fragments, meshes)
cameras = kwargs.get("cameras", self.cameras)
lights = kwargs.get("lights", self.lights)
materials = kwargs.get("materials", self.materials)
colors = phong_shading(
meshes=meshes,
fragments=fragments,
texels=texels,
lights=lights,
cameras=cameras,
materials=materials,
)
images = softmax_rgb_blend(colors, fragments, self.blend_params)
return images
class SilhouetteShader(nn.Module):
"""
Calculate the silhouette by blending the top K faces for each pixel based
    on the 2D Euclidean distance from the center of the pixel to the mesh face.
Use this shader for generating silhouettes similar to SoftRasterizer [0].
.. note::
To be consistent with SoftRasterizer, initialize the
RasterizationSettings for the rasterizer with
`blur_radius = np.log(1. / 1e-4 - 1.) * blend_params.sigma`
[0] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
3D Reasoning', ICCV 2019
"""
def __init__(self, blend_params=None):
super().__init__()
self.blend_params = (
blend_params if blend_params is not None else BlendParams()
)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
""""
Only want to render the silhouette so RGB values can be ones.
There is no need for lighting or texturing
"""
colors = torch.ones_like(fragments.bary_coords)
blend_params = kwargs.get("blend_params", self.blend_params)
images = sigmoid_alpha_blend(colors, fragments, blend_params)
return images
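# -----------------------------------------------------------------------------
# Editor's sketch (not part of the library): configuring the rasterizer so that
# SilhouetteShader is consistent with SoftRasterizer, per the note in the class
# docstring. `MeshRasterizer` and `RasterizationSettings` are assumed to be
# importable from the neighboring rasterizer module; adjust to the actual
# layout. `cameras` is any cameras instance supplied by the caller.
def _example_silhouette_setup(cameras):
    import numpy as np
    from .rasterizer import MeshRasterizer, RasterizationSettings  # assumed

    blend_params = BlendParams(sigma=1e-4)
    raster_settings = RasterizationSettings(
        image_size=256,
        # Blur radius recommended in the SilhouetteShader docstring for
        # consistency with SoftRasterizer.
        blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
        faces_per_pixel=100,
    )
    rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
    shader = SilhouetteShader(blend_params=blend_params)
    return rasterizer, shader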
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple
import torch
from .texturing import interpolate_face_attributes
def _apply_lighting(
points, normals, lights, cameras, materials
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
points: torch tensor of shape (N, P, 3) or (P, 3).
normals: torch tensor of shape (N, P, 3) or (P, 3)
lights: instance of the Lights class.
cameras: instance of the Cameras class.
materials: instance of the Materials class.
Returns:
ambient_color: same shape as materials.ambient_color
diffuse_color: same shape as the input points
specular_color: same shape as the input points
"""
light_diffuse = lights.diffuse(normals=normals, points=points)
light_specular = lights.specular(
normals=normals,
points=points,
camera_position=cameras.get_camera_center(),
shininess=materials.shininess,
)
ambient_color = materials.ambient_color * lights.ambient_color
diffuse_color = materials.diffuse_color * light_diffuse
specular_color = materials.specular_color * light_specular
if normals.dim() == 2 and points.dim() == 2:
# If given packed inputs remove batch dim in output.
return (
ambient_color.squeeze(),
diffuse_color.squeeze(),
specular_color.squeeze(),
)
return ambient_color, diffuse_color, specular_color
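# -----------------------------------------------------------------------------
# Editor's sketch (not part of the library): exercising _apply_lighting with
# packed (P, 3) inputs to illustrate the shapes documented above. The light,
# camera and material classes are assumed to be importable via the relative
# paths below (matching the imports used in the shader module); adjust to the
# actual package layout.
def _example_apply_lighting():
    from ..cameras import OpenGLPerspectiveCameras  # assumed relative path
    from ..lighting import PointLights  # assumed relative path
    from ..materials import Materials  # assumed relative path

    device = torch.device("cpu")
    P = 8
    points = torch.rand(P, 3, device=device)
    # Unit-length normals so the diffuse/specular terms are well behaved.
    normals = torch.nn.functional.normalize(
        torch.rand(P, 3, device=device), dim=-1
    )

    ambient, diffuse, specular = _apply_lighting(
        points,
        normals,
        PointLights(device=device),
        OpenGLPerspectiveCameras(device=device),
        Materials(device=device),
    )
    # Per the docstring: diffuse and specular match the shape of `points`,
    # with the batch dim squeezed out for packed (P, 3) inputs.
    return ambient.shape, diffuse.shape, specular.shape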
def phong_shading(
meshes, fragments, lights, cameras, materials, texels
) -> torch.Tensor:
"""
Apply per pixel shading. First interpolate the vertex normals and
vertex coordinates using the barycentric coordinates to get the position
and normal at each pixel. Then compute the illumination for each pixel.
The pixel color is obtained by multiplying the pixel textures by the ambient
and diffuse illumination and adding the specular component.
Args:
meshes: Batch of meshes
fragments: Fragments named tuple with the outputs of rasterization
lights: Lights class containing a batch of lights
cameras: Cameras class containing a batch of cameras
materials: Materials class containing a batch of material properties
texels: texture per pixel of shape (N, H, W, K, 3)
Returns:
colors: (N, H, W, K, 3)
"""
verts = meshes.verts_packed() # (V, 3)
faces = meshes.faces_packed() # (F, 3)
vertex_normals = meshes.verts_normals_packed() # (V, 3)
faces_verts = verts[faces]
faces_normals = vertex_normals[faces]
pixel_coords = interpolate_face_attributes(fragments, faces_verts)
pixel_normals = interpolate_face_attributes(fragments, faces_normals)
ambient, diffuse, specular = _apply_lighting(
pixel_coords, pixel_normals, lights, cameras, materials
)
colors = (ambient + diffuse) * texels + specular
return colors
def gourad_shading(
meshes, fragments, lights, cameras, materials
) -> torch.Tensor:
"""
Apply per vertex shading. First compute the vertex illumination by applying
ambient, diffuse and specular lighting. If vertex color is available,
combine the ambient and diffuse vertex illumination with the vertex color
and add the specular component to determine the vertex shaded color.
Then interpolate the vertex shaded colors using the barycentric coordinates
to get a color per pixel.
Args:
meshes: Batch of meshes
fragments: Fragments named tuple with the outputs of rasterization
lights: Lights class containing a batch of lights parameters
cameras: Cameras class containing a batch of cameras parameters
materials: Materials class containing a batch of material properties
Returns:
colors: (N, H, W, K, 3)
"""
faces = meshes.faces_packed() # (F, 3)
verts = meshes.verts_packed()
vertex_normals = meshes.verts_normals_packed() # (V, 3)
vertex_colors = meshes.textures.verts_rgb_packed()
vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx()
# Format properties of lights and materials so they are compatible
# with the packed representation of the vertices. This transforms
# all tensor properties in the class from shape (N, ...) -> (V, ...) where
# V is the number of packed vertices. If the number of meshes in the
# batch is one then this is not necessary.
if len(meshes) > 1:
lights = lights.clone().gather_props(vert_to_mesh_idx)
cameras = cameras.clone().gather_props(vert_to_mesh_idx)
materials = materials.clone().gather_props(vert_to_mesh_idx)
# Calculate the illumination at each vertex
ambient, diffuse, specular = _apply_lighting(
verts, vertex_normals, lights, cameras, materials
)
verts_colors_shaded = vertex_colors * (ambient + diffuse) + specular
face_colors = verts_colors_shaded[faces]
colors = interpolate_face_attributes(fragments, face_colors)
return colors