Commit dbf06b50 authored by facebook-github-bot's avatar facebook-github-bot
Browse files

Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
parents
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
// Input validation macros for tensors passed in from Python.
// NOTE: the leading space in " must ..." matters — `#x` stringizes the macro
// argument and adjacent string literals are concatenated, so without the
// space the message would read e.g. "xmust be a CUDA tensor.".
#define CHECK_CUDA(x) \
  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous.")
#define CHECK_CONTIGUOUS_CUDA(x) \
  CHECK_CUDA(x);                 \
  CHECK_CONTIGUOUS(x)
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <thrust/tuple.h>
// Common functions and operators for float2.
// Component-wise difference: (a.x - b.x, a.y - b.y).
__device__ inline float2 operator-(const float2& a, const float2& b) {
return make_float2(a.x - b.x, a.y - b.y);
}
// Component-wise sum.
__device__ inline float2 operator+(const float2& a, const float2& b) {
return make_float2(a.x + b.x, a.y + b.y);
}
// Component-wise (Hadamard) division.
__device__ inline float2 operator/(const float2& a, const float2& b) {
return make_float2(a.x / b.x, a.y / b.y);
}
// Division of both components by a scalar.
__device__ inline float2 operator/(const float2& a, const float b) {
return make_float2(a.x / b, a.y / b);
}
// Component-wise (Hadamard) product.
__device__ inline float2 operator*(const float2& a, const float2& b) {
return make_float2(a.x * b.x, a.y * b.y);
}
// Scaling of both components by a scalar.
__device__ inline float2 operator*(const float a, const float2& b) {
return make_float2(a * b.x, a * b.y);
}
// 2D dot product a . b.
__device__ inline float dot(const float2& a, const float2& b) {
return a.x * b.x + a.y * b.y;
}
// Backward pass for the dot product.
// Args:
// a, b: Coordinates of two points.
// grad_dot: Upstream gradient for the output.
//
// Returns:
// tuple of gradients for each of the input points:
// (float2 grad_a, float2 grad_b)
//
__device__ inline thrust::tuple<float2, float2>
DotBackward(const float2& a, const float2& b, const float& grad_dot) {
// d(a.b)/da = b and d(a.b)/db = a, each scaled by the upstream gradient.
return thrust::make_tuple(grad_dot * b, grad_dot * a);
}
// Sum of the two components.
__device__ inline float sum(const float2& a) {
return a.x + a.y;
}
// Common functions and operators for float3.
// Component-wise difference: (a.x - b.x, a.y - b.y, a.z - b.z).
__device__ inline float3 operator-(const float3& a, const float3& b) {
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
}
// Component-wise sum.
__device__ inline float3 operator+(const float3& a, const float3& b) {
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
// Component-wise (Hadamard) division.
__device__ inline float3 operator/(const float3& a, const float3& b) {
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
}
// Division of all components by a scalar.
__device__ inline float3 operator/(const float3& a, const float b) {
return make_float3(a.x / b, a.y / b, a.z / b);
}
// Component-wise (Hadamard) product.
__device__ inline float3 operator*(const float3& a, const float3& b) {
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
}
// Scaling of all components by a scalar.
__device__ inline float3 operator*(const float a, const float3& b) {
return make_float3(a * b.x, a * b.y, a * b.z);
}
// 3D dot product a . b.
__device__ inline float dot(const float3& a, const float3& b) {
return a.x * b.x + a.y * b.y + a.z * b.z;
}
// Sum of the three components.
__device__ inline float sum(const float3& a) {
return a.x + a.y + a.z;
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <float.h>
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include "float_math.cuh"
// Set epsilon for preventing floating point errors and division by 0.
// NOTE(review): `auto` deduces double here (1e-30 is a double literal); it
// converts to a small positive float where used in float expressions.
const auto kEpsilon = 1e-30;
// Determines whether a point p is on the right side of a 2D line segment
// given by the end points v0, v1.
//
// Args:
// p: vec2 Coordinates of a point.
// v0, v1: vec2 Coordinates of the end points of the edge.
//
// Returns:
// area: The signed area of the parallelogram given by the vectors
// A = p - v0
// B = v1 - v0
//
__device__ inline float
EdgeFunctionForward(const float2& p, const float2& v0, const float2& v1) {
  // 2D cross product of A = p - v0 and B = v1 - v0 (the signed area).
  const float ax = p.x - v0.x;
  const float ay = p.y - v0.y;
  const float bx = v1.x - v0.x;
  const float by = v1.y - v0.y;
  return ax * by - ay * bx;
}
// Backward pass for the edge function returning partial derivatives for each
// of the input points.
//
// Args:
// p: vec2 Coordinates of a point.
// v0, v1: vec2 Coordinates of the end points of the edge.
// grad_edge: Upstream gradient for output from edge function.
//
// Returns:
// tuple of gradients for each of the input points:
// (float2 d_edge_dp, float2 d_edge_dv0, float2 d_edge_dv1)
//
__device__ inline thrust::tuple<float2, float2, float2> EdgeFunctionBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_edge) {
  // Partial derivatives of the signed-area edge function with respect to
  // each input point, each scaled by the upstream gradient.
  const float2 grad_p =
      make_float2(grad_edge * (v1.y - v0.y), grad_edge * (v0.x - v1.x));
  const float2 grad_v0 =
      make_float2(grad_edge * (p.y - v1.y), grad_edge * (v1.x - p.x));
  const float2 grad_v1 =
      make_float2(grad_edge * (v0.y - p.y), grad_edge * (p.x - v0.x));
  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// The forward pass for computing the barycentric coordinates of a point
// relative to a triangle.
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: Coordinates of the triangle vertices.
//
// Returns
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
//
__device__ inline float3 BarycentricCoordsForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  // Each coordinate is the sub-triangle area opposite a vertex, normalized
  // by the (epsilon-stabilized) full triangle area.
  const float denom = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  return make_float3(
      EdgeFunctionForward(p, v1, v2) / denom,
      EdgeFunctionForward(p, v2, v0) / denom,
      EdgeFunctionForward(p, v0, v1) / denom);
}
// The backward pass for computing the barycentric coordinates of a point
// relative to a triangle.
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: (x, y) coordinates of the triangle vertices.
// grad_bary_upstream: vec3<T> Upstream gradient for each of the
// barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
// tuple of gradients for each of the triangle vertices:
// (float2 grad_v0, float2 grad_v1, float2 grad_v2)
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
BarycentricCoordsBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float3& grad_bary_upstream) {
  const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // area * area instead of pow(area, 2.0): pow with a double-literal
  // exponent promotes to double-precision math in device code for no
  // benefit (the templated host version computes the square similarly).
  const float area2 = area * area;
  const float e0 = EdgeFunctionForward(p, v1, v2);
  const float e1 = EdgeFunctionForward(p, v2, v0);
  const float e2 = EdgeFunctionForward(p, v0, v1);
  const float grad_w0 = grad_bary_upstream.x;
  const float grad_w1 = grad_bary_upstream.y;
  const float grad_w2 = grad_bary_upstream.z;
  // Calculate component of the gradient from each of w0, w1 and w2.
  // e.g. for w0 (wi = ei / area):
  // dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
  //             + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
  const float dw0_darea = -e0 / (area2);
  const float dw0_e0 = 1 / area;
  const float dloss_d_w0area = grad_w0 * dw0_darea;
  const float dloss_e0 = grad_w0 * dw0_e0;
  auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
  auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
  const float2 dw0_p = thrust::get<0>(de0_dv);
  const float2 dw0_dv0 = thrust::get<1>(dw0area_dv);
  const float2 dw0_dv1 = thrust::get<1>(de0_dv) + thrust::get<2>(dw0area_dv);
  const float2 dw0_dv2 = thrust::get<2>(de0_dv) + thrust::get<0>(dw0area_dv);
  const float dw1_darea = -e1 / (area2);
  const float dw1_e1 = 1 / area;
  const float dloss_d_w1area = grad_w1 * dw1_darea;
  const float dloss_e1 = grad_w1 * dw1_e1;
  auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
  auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
  const float2 dw1_p = thrust::get<0>(de1_dv);
  const float2 dw1_dv0 = thrust::get<2>(de1_dv) + thrust::get<1>(dw1area_dv);
  const float2 dw1_dv1 = thrust::get<2>(dw1area_dv);
  const float2 dw1_dv2 = thrust::get<1>(de1_dv) + thrust::get<0>(dw1area_dv);
  const float dw2_darea = -e2 / (area2);
  const float dw2_e2 = 1 / area;
  const float dloss_d_w2area = grad_w2 * dw2_darea;
  const float dloss_e2 = grad_w2 * dw2_e2;
  auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
  auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
  const float2 dw2_p = thrust::get<0>(de2_dv);
  const float2 dw2_dv0 = thrust::get<1>(de2_dv) + thrust::get<1>(dw2area_dv);
  const float2 dw2_dv1 = thrust::get<2>(de2_dv) + thrust::get<2>(dw2area_dv);
  const float2 dw2_dv2 = thrust::get<0>(dw2area_dv);
  // Accumulate the three per-coordinate contributions for each input point.
  const float2 dbary_p = dw0_p + dw1_p + dw2_p;
  const float2 dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
  const float2 dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
  const float2 dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
  return thrust::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
}
// Forward pass for applying perspective correction to barycentric coordinates.
//
// Args:
// bary: Screen-space barycentric coordinates for a point
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//
// Returns
// World-space barycentric coordinates
//
__device__ inline float3 BarycentricPerspectiveCorrectionForward(
    const float3& bary,
    const float z0,
    const float z1,
    const float z2) {
  // Numerators of the corrected weights: each screen-space weight is
  // multiplied by the depths of the other two vertices.
  const float n0 = bary.x * z1 * z2;
  const float n1 = z0 * bary.y * z2;
  const float n2 = z0 * z1 * bary.z;
  const float norm = n0 + n1 + n2;
  return make_float3(n0 / norm, n1 / norm, n2 / norm);
}
// Backward pass for applying perspective correction to barycentric coordinates.
//
// Args:
// bary: Screen-space barycentric coordinates for a point
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
// grad_out: Upstream gradient of the loss with respect to the corrected
// barycentric coordinates.
//
// Returns a tuple of:
// grad_bary: Downstream gradient of the loss with respect to the the
// uncorrected barycentric coordinates.
// grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
// to the z-coordinates of the triangle verts
__device__ inline thrust::tuple<float3, float, float, float>
BarycentricPerspectiveCorrectionBackward(
const float3& bary,
const float z0,
const float z1,
const float z2,
const float3& grad_out) {
// Recompute forward pass
const float w0_top = bary.x * z1 * z2;
const float w1_top = z0 * bary.y * z2;
const float w2_top = z0 * z1 * bary.z;
const float denom = w0_top + w1_top + w2_top;
// Now do backward pass
// Gradient through the shared denominator: d(wi)/d(denom) = -wi_top/denom^2.
const float grad_denom_top =
-w0_top * grad_out.x - w1_top * grad_out.y - w2_top * grad_out.z;
const float grad_denom = grad_denom_top / (denom * denom);
// Each numerator wi_top contributes via its own weight and via denom.
const float grad_w0_top = grad_denom + grad_out.x / denom;
const float grad_w1_top = grad_denom + grad_out.y / denom;
const float grad_w2_top = grad_denom + grad_out.z / denom;
const float grad_bary_x = grad_w0_top * z1 * z2;
const float grad_bary_y = grad_w1_top * z0 * z2;
const float grad_bary_z = grad_w2_top * z0 * z1;
const float3 grad_bary = make_float3(grad_bary_x, grad_bary_y, grad_bary_z);
// Each zi appears in the numerators of the other two weights.
const float grad_z0 = grad_w1_top * bary.y * z2 + grad_w2_top * bary.z * z1;
const float grad_z1 = grad_w0_top * bary.x * z2 + grad_w2_top * bary.z * z0;
const float grad_z2 = grad_w0_top * bary.x * z1 + grad_w1_top * bary.y * z0;
return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Return minimum distance between line segment (v1 - v0) and point p.
//
// Args:
// p: Coordinates of a point.
// v0, v1: Coordinates of the end points of the line segment.
//
// Returns:
// squared (non-square-rooted) distance from the point to the line segment.
//
__device__ inline float
PointLineDistanceForward(const float2& p, const float2& a, const float2& b) {
  const float2 ba = b - a;
  const float l2 = dot(ba, ba);
  // Guard the degenerate segment (a == b) BEFORE dividing by l2; the
  // original computed t = .../l2 first, performing a (near) division by
  // zero whose result was then discarded.
  if (l2 <= kEpsilon) {
    return dot(p - b, p - b);
  }
  // Projection parameter of p onto the line through a and b.
  float t = dot(ba, p - a) / l2;
  t = __saturatef(t); // clamp to the interval [+0.0, 1.0]
  const float2 p_proj = a + t * ba;
  const float2 d = (p_proj - p);
  return dot(d, d); // squared distance
}
// Backward pass for point to line distance in 2D.
//
// Args:
// p: Coordinates of a point.
// v0, v1: Coordinates of the end points of the line segment.
// grad_dist: Upstream gradient for the distance.
//
// Returns:
// tuple of gradients for each of the input points:
// (float2 grad_p, float2 grad_v0, float2 grad_v1)
//
__device__ inline thrust::tuple<float2, float2, float2>
PointLineDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_dist) {
  // Redo some of the forward pass calculations.
  const float2 v1v0 = v1 - v0;
  const float2 pv0 = p - v0;
  const float t_bot = dot(v1v0, v1v0);
  const float t_top = dot(v1v0, pv0);
  float tt = t_top / t_bot;
  tt = __saturatef(tt); // clamp to [0, 1]
  const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
  // d(squared dist)/d(p_proj) = 2 * (p_proj - p); tt is treated as a
  // constant with respect to the inputs. (Removed an unused
  // `dist = sqrt(dot(d, d))` computation from the original.)
  const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
  const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
  const float2 grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);
  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// The forward pass for calculating the shortest distance between a point
// and a triangle.
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
// shortest absolute distance from a point to a triangle.
//
__device__ inline float PointTriangleDistanceForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  // The distance to the triangle boundary is the minimum of the distances
  // to its three edges.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);
  return fminf(fminf(d01, d02), d12);
}
// Backward pass for point triangle distance.
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: Coordinates of the three triangle vertices.
// grad_dist: Upstream gradient for the distance.
//
// Returns:
// tuple of gradients for each of the triangle vertices:
// (float2 grad_v0, float2 grad_v1, float2 grad_v2)
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
PointTriangleDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float& grad_dist) {
  // Distances from the point to each of the three triangle edges.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);
  // Gradients default to zero; only the closest edge contributes.
  float2 grad_p = make_float2(0.0f, 0.0f);
  float2 grad_v0 = make_float2(0.0f, 0.0f);
  float2 grad_v1 = make_float2(0.0f, 0.0f);
  float2 grad_v2 = make_float2(0.0f, 0.0f);
  if (d01 <= d02 && d01 <= d12) {
    // Edge (v0, v1) is closest: route the gradient through it.
    const auto g = PointLineDistanceBackward(p, v0, v1, grad_dist);
    grad_p = thrust::get<0>(g);
    grad_v0 = thrust::get<1>(g);
    grad_v1 = thrust::get<2>(g);
  } else if (d02 <= d01 && d02 <= d12) {
    // Edge (v0, v2) is closest.
    const auto g = PointLineDistanceBackward(p, v0, v2, grad_dist);
    grad_p = thrust::get<0>(g);
    grad_v0 = thrust::get<1>(g);
    grad_v2 = thrust::get<2>(g);
  } else if (d12 <= d01 && d12 <= d02) {
    // Edge (v1, v2) is closest.
    const auto g = PointLineDistanceBackward(p, v1, v2, grad_dist);
    grad_p = thrust::get<0>(g);
    grad_v1 = thrust::get<1>(g);
    grad_v2 = thrust::get<2>(g);
  }
  return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <algorithm>
#include <type_traits>
#include "vec2.h"
#include "vec3.h"
// Set epsilon for preventing floating point errors and division by 0.
// NOTE(review): `auto` deduces double here (1e-30 is a double literal).
const auto kEpsilon = 1e-30;
// Determines whether a point p is on the right side of a 2D line segment
// given by the end points v0, v1.
//
// Args:
// p: vec2 Coordinates of a point.
// v0, v1: vec2 Coordinates of the end points of the edge.
//
// Returns:
// area: The signed area of the parallelogram given by the vectors
// A = p - v0
// B = v1 - v0
//
// v1 ________
// /\ /
// A / \ /
// / \ /
// v0 /______\/
// B p
//
// The area can also be interpreted as the cross product A x B.
// If the sign of the area is positive, the point p is on the
// right side of the edge. Negative area indicates the point is on
// the left side of the edge. i.e. for an edge v1 - v0:
//
// v1
// /
// /
// - / +
// /
// /
// v0
//
template <typename T>
T EdgeFunctionForward(const vec2<T>& p, const vec2<T>& v0, const vec2<T>& v1) {
  // 2D cross product of A = p - v0 and B = v1 - v0 (the signed area; see
  // the diagrams above).
  const T ax = p.x - v0.x;
  const T ay = p.y - v0.y;
  const T bx = v1.x - v0.x;
  const T by = v1.y - v0.y;
  return ax * by - ay * bx;
}
// Backward pass for the edge function returning partial derivatives for each
// of the input points.
//
// Args:
// p: vec2 Coordinates of a point.
// v0, v1: vec2 Coordinates of the end points of the edge.
// grad_edge: Upstream gradient for output from edge function.
//
// Returns:
// tuple of gradients for each of the input points:
// (vec2<T> d_edge_dp, vec2<T> d_edge_dv0, vec2<T> d_edge_dv1)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> EdgeFunctionBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const T grad_edge) {
  // Partial derivatives of the signed-area edge function with respect to
  // each input point, each scaled by the upstream gradient.
  const vec2<T> grad_p(grad_edge * (v1.y - v0.y), grad_edge * (v0.x - v1.x));
  const vec2<T> grad_v0(grad_edge * (p.y - v1.y), grad_edge * (v1.x - p.x));
  const vec2<T> grad_v1(grad_edge * (v0.y - p.y), grad_edge * (p.x - v0.x));
  return std::make_tuple(grad_p, grad_v0, grad_v1);
}
// The forward pass for computing the barycentric coordinates of a point
// relative to a triangle.
// Ref:
// https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-rendering-a-triangle/barycentric-coordinates
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: Coordinates of the triangle vertices.
//
// Returns
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
//
template <typename T>
vec3<T> BarycentricCoordinatesForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2) {
  // Each coordinate is the sub-triangle area opposite a vertex, normalized
  // by the (epsilon-stabilized) full triangle area.
  const T denom = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  return vec3<T>(
      EdgeFunctionForward(p, v1, v2) / denom,
      EdgeFunctionForward(p, v2, v0) / denom,
      EdgeFunctionForward(p, v0, v1) / denom);
}
// The backward pass for computing the barycentric coordinates of a point
// relative to a triangle.
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: (x, y) coordinates of the triangle vertices.
// grad_bary_upstream: vec3<T> Upstream gradient for each of the
// barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
// tuple of gradients for each of the triangle vertices:
// (vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>> BarycentricCoordsBackward(
const vec2<T>& p,
const vec2<T>& v0,
const vec2<T>& v1,
const vec2<T>& v2,
const vec3<T>& grad_bary_upstream) {
// Recompute forward-pass quantities: the full area and the three
// sub-areas e0, e1, e2 (each wi = ei / area).
const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
const T area2 = pow(area, 2.0f);
const T area_inv = 1.0f / area;
const T e0 = EdgeFunctionForward(p, v1, v2);
const T e1 = EdgeFunctionForward(p, v2, v0);
const T e2 = EdgeFunctionForward(p, v0, v1);
const T grad_w0 = grad_bary_upstream.x;
const T grad_w1 = grad_bary_upstream.y;
const T grad_w2 = grad_bary_upstream.z;
// Calculate component of the gradient from each of w0, w1 and w2.
// e.g. for w0:
// dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
// + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
const T dw0_darea = -e0 / (area2);
const T dw0_e0 = area_inv;
const T dloss_d_w0area = grad_w0 * dw0_darea;
const T dloss_e0 = grad_w0 * dw0_e0;
// Push the numerator gradient through e0's edge and the denominator
// gradient through the area edge, then re-associate per input point.
auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
const vec2<T> dw0_p = std::get<0>(de0_dv);
const vec2<T> dw0_dv0 = std::get<1>(dw0area_dv);
const vec2<T> dw0_dv1 = std::get<1>(de0_dv) + std::get<2>(dw0area_dv);
const vec2<T> dw0_dv2 = std::get<2>(de0_dv) + std::get<0>(dw0area_dv);
const T dw1_darea = -e1 / (area2);
const T dw1_e1 = area_inv;
const T dloss_d_w1area = grad_w1 * dw1_darea;
const T dloss_e1 = grad_w1 * dw1_e1;
auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
const vec2<T> dw1_p = std::get<0>(de1_dv);
const vec2<T> dw1_dv0 = std::get<2>(de1_dv) + std::get<1>(dw1area_dv);
const vec2<T> dw1_dv1 = std::get<2>(dw1area_dv);
const vec2<T> dw1_dv2 = std::get<1>(de1_dv) + std::get<0>(dw1area_dv);
const T dw2_darea = -e2 / (area2);
const T dw2_e2 = area_inv;
const T dloss_d_w2area = grad_w2 * dw2_darea;
const T dloss_e2 = grad_w2 * dw2_e2;
auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
const vec2<T> dw2_p = std::get<0>(de2_dv);
const vec2<T> dw2_dv0 = std::get<1>(de2_dv) + std::get<1>(dw2area_dv);
const vec2<T> dw2_dv1 = std::get<2>(de2_dv) + std::get<2>(dw2area_dv);
const vec2<T> dw2_dv2 = std::get<0>(dw2area_dv);
// Accumulate the three per-coordinate contributions for each input point.
const vec2<T> dbary_p = dw0_p + dw1_p + dw2_p;
const vec2<T> dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
const vec2<T> dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
const vec2<T> dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
return std::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
}
// Forward pass for applying perspective correction to barycentric coordinates.
//
// Args:
// bary: Screen-space barycentric coordinates for a point
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//
// Returns
// World-space barycentric coordinates
//
template <typename T>
inline vec3<T> BarycentricPerspectiveCorrectionForward(
    const vec3<T>& bary,
    const T z0,
    const T z1,
    const T z2) {
  // Numerators of the corrected weights: each screen-space weight is
  // multiplied by the depths of the other two vertices.
  const T n0 = bary.x * z1 * z2;
  const T n1 = bary.y * z0 * z2;
  const T n2 = bary.z * z0 * z1;
  const T norm = n0 + n1 + n2;
  return vec3<T>(n0 / norm, n1 / norm, n2 / norm);
}
// Backward pass for applying perspective correction to barycentric coordinates.
//
// Args:
// bary: Screen-space barycentric coordinates for a point
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
// grad_out: Upstream gradient of the loss with respect to the corrected
// barycentric coordinates.
//
// Returns a tuple of:
// grad_bary: Downstream gradient of the loss with respect to the the
// uncorrected barycentric coordinates.
// grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
// to the z-coordinates of the triangle verts
template <typename T>
inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
const vec3<T>& bary,
const T z0,
const T z1,
const T z2,
const vec3<T>& grad_out) {
// Recompute forward pass
const T w0_top = bary.x * z1 * z2;
const T w1_top = bary.y * z0 * z2;
const T w2_top = bary.z * z0 * z1;
const T denom = w0_top + w1_top + w2_top;
// Now do backward pass
// Gradient through the shared denominator: d(wi)/d(denom) = -wi_top/denom^2.
const T grad_denom_top =
-w0_top * grad_out.x - w1_top * grad_out.y - w2_top * grad_out.z;
const T grad_denom = grad_denom_top / (denom * denom);
// Each numerator wi_top contributes via its own weight and via denom.
const T grad_w0_top = grad_denom + grad_out.x / denom;
const T grad_w1_top = grad_denom + grad_out.y / denom;
const T grad_w2_top = grad_denom + grad_out.z / denom;
const T grad_bary_x = grad_w0_top * z1 * z2;
const T grad_bary_y = grad_w1_top * z0 * z2;
const T grad_bary_z = grad_w2_top * z0 * z1;
const vec3<T> grad_bary(grad_bary_x, grad_bary_y, grad_bary_z);
// Each zi appears in the numerators of the other two weights.
const T grad_z0 = grad_w1_top * bary.y * z2 + grad_w2_top * bary.z * z1;
const T grad_z1 = grad_w0_top * bary.x * z2 + grad_w2_top * bary.z * z0;
const T grad_z2 = grad_w0_top * bary.x * z1 + grad_w1_top * bary.y * z0;
return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum distance between a line segment (v1 - v0) and point p.
//
// Args:
// p: Coordinates of a point.
// v0, v1: Coordinates of the end points of the line segment.
//
// Returns:
// non-square distance of the point to the line.
//
// Consider the line extending the segment - this can be parameterized as:
// v0 + t (v1 - v0).
//
// First find the projection of point p onto the line. It falls where:
// t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2
// where . is the dot product.
//
// The parameter t is clamped from [0, 1] to handle points outside the
// segment (v1 - v0).
//
// Once the projection of the point on the segment is known, the distance from
// p to the projection gives the minimum distance to the segment.
//
template <typename T>
T PointLineDistanceForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1) {
  const vec2<T> v1v0 = v1 - v0;
  const T l2 = dot(v1v0, v1v0);
  if (l2 <= kEpsilon) {
    // Degenerate segment (v0 == v1): squared distance to the endpoint.
    // The original returned sqrt(...) here, which was inconsistent with the
    // squared distance returned on the main path below (and with the CUDA
    // float2 version of this function).
    return dot(p - v1, p - v1);
  }
  // Projection parameter of p onto the line, clamped to the segment. Use
  // T-typed clamp bounds: float literals would break std::min/std::max
  // template deduction when T = double.
  const T t = dot(v1v0, p - v0) / l2;
  const T tt = std::min(std::max(t, T(0)), T(1));
  const vec2<T> p_proj = v0 + tt * v1v0;
  return dot(p - p_proj, p - p_proj); // squared distance
}
// Backward pass for point to line distance in 2D.
//
// Args:
// p: Coordinates of a point.
// v0, v1: Coordinates of the end points of the line segment.
// grad_dist: Upstream gradient for the distance.
//
// Returns:
// tuple of gradients for each of the input points:
// (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> PointLineDistanceBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const T& grad_dist) {
  // Redo some of the forward pass calculations.
  const vec2<T> v1v0 = v1 - v0;
  const vec2<T> pv0 = p - v0;
  const T t_bot = dot(v1v0, v1v0);
  const T t_top = dot(v1v0, pv0);
  // NOTE(review): unlike the forward pass, t_bot is not guarded against a
  // degenerate (zero-length) segment here — confirm callers never hit it.
  const T t = t_top / t_bot;
  // Clamp with T-typed bounds: float literals would break std::min/std::max
  // template deduction when T = double.
  const T tt = std::min(std::max(t, T(0)), T(1));
  const vec2<T> p_proj = (T(1) - tt) * v0 + tt * v1;
  // d(squared dist)/d(p_proj) = 2 * (p_proj - p), split across the segment
  // endpoints by the clamped projection parameter tt.
  const vec2<T> grad_v0 = grad_dist * (T(1) - tt) * T(2) * (p_proj - p);
  const vec2<T> grad_v1 = grad_dist * tt * T(2) * (p_proj - p);
  const vec2<T> grad_p = T(-1) * grad_dist * T(2) * (p_proj - p);
  return std::make_tuple(grad_p, grad_v0, grad_v1);
}
// The forward pass for calculating the shortest distance between a point
// and a triangle.
// Ref: https://www.randygaul.net/2014/07/23/distance-point-to-line-segment/
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
// shortest absolute distance from a point to a triangle.
//
//
template <typename T>
T PointTriangleDistanceForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2) {
  // The distance to the triangle boundary is the minimum of the distances
  // to its three edges.
  const T d01 = PointLineDistanceForward(p, v0, v1);
  const T d02 = PointLineDistanceForward(p, v0, v2);
  const T d12 = PointLineDistanceForward(p, v1, v2);
  return std::min(std::min(d01, d02), d12);
}
// Backward pass for point triangle distance.
//
// Args:
// p: Coordinates of a point.
// v0, v1, v2: Coordinates of the three triangle vertices.
// grad_dist: Upstream gradient for the distance.
//
// Returns:
// tuple of gradients for each of the triangle vertices:
// (vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>>
PointTriangleDistanceBackward(
const vec2<T>& p,
const vec2<T>& v0,
const vec2<T>& v1,
const vec2<T>& v2,
const T& grad_dist) {
// Compute distance to all 3 edges of the triangle.
const T e01_dist = PointLineDistanceForward(p, v0, v1);
const T e02_dist = PointLineDistanceForward(p, v0, v2);
const T e12_dist = PointLineDistanceForward(p, v1, v2);
// Initialize output tensors.
// Gradients default to zero; only the closest edge contributes, and the
// vertex not on that edge receives no gradient.
vec2<T> grad_v0(0.0f, 0.0f);
vec2<T> grad_v1(0.0f, 0.0f);
vec2<T> grad_v2(0.0f, 0.0f);
vec2<T> grad_p(0.0f, 0.0f);
// Find which edge is the closest and return PointLineDistanceBackward for
// that edge.
if (e01_dist <= e02_dist && e01_dist <= e12_dist) {
// Closest edge is v1 - v0.
auto grad_e01 = PointLineDistanceBackward(p, v0, v1, grad_dist);
grad_p = std::get<0>(grad_e01);
grad_v0 = std::get<1>(grad_e01);
grad_v1 = std::get<2>(grad_e01);
} else if (e02_dist <= e01_dist && e02_dist <= e12_dist) {
// Closest edge is v2 - v0.
auto grad_e02 = PointLineDistanceBackward(p, v0, v2, grad_dist);
grad_p = std::get<0>(grad_e02);
grad_v0 = std::get<1>(grad_e02);
grad_v2 = std::get<2>(grad_e02);
} else if (e12_dist <= e01_dist && e12_dist <= e02_dist) {
// Closest edge is v2 - v1.
auto grad_e12 = PointLineDistanceBackward(p, v1, v2, grad_dist);
grad_p = std::get<0>(grad_e12);
grad_v1 = std::get<1>(grad_e12);
grad_v2 = std::get<2>(grad_e12);
}
return std::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <float.h>
#include <math.h>
#include <thrust/tuple.h>
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
#include "float_math.cuh"
#include "geometry_utils.cuh"
#include "rasterize_points/bitmask.cuh"
#include "rasterize_points/rasterization_utils.cuh"
namespace {
// A structure for holding details about a pixel.
struct Pixel {
float z; // Interpolated depth of the face at this pixel (see pz below).
int64_t idx; // Index of the face.
float dist; // Signed squared distance of the pixel to the face boundary.
float3 bary; // Barycentric coordinates of the pixel w.r.t. the face.
};
// Order pixels by depth (used to maintain the per-pixel top-K face queue).
__device__ bool operator<(const Pixel& a, const Pixel& b) {
return a.z < b.z;
}
// Minimum of three floats.
__device__ float FloatMin3(const float p1, const float p2, const float p3) {
return fminf(p1, fminf(p2, p3));
}
// Maximum of three floats.
__device__ float FloatMax3(const float p1, const float p2, const float p3) {
return fmaxf(p1, fmaxf(p2, p3));
}
// Get the xyz coordinates of the three vertices for the face given by the
// index face_idx into face_verts.
__device__ thrust::tuple<float3, float3, float3> GetSingleFaceVerts(
    const float* face_verts,
    int face_idx) {
  // Each face stores 3 vertices x 3 coordinates = 9 contiguous floats.
  const float* v = face_verts + face_idx * 9;
  const float3 v0xyz = make_float3(v[0], v[1], v[2]);
  const float3 v1xyz = make_float3(v[3], v[4], v[5]);
  const float3 v2xyz = make_float3(v[6], v[7], v[8]);
  return thrust::make_tuple(v0xyz, v1xyz, v2xyz);
}
// Get the min/max x/y/z values for the face given by vertices v0, v1, v2.
__device__ thrust::tuple<float2, float2, float2>
GetFaceBoundingBox(float3 v0, float3 v1, float3 v2) {
  // Pack per-axis (min, max) pairs into float2s: x limits, y limits,
  // z limits.
  const float2 xlims =
      make_float2(FloatMin3(v0.x, v1.x, v2.x), FloatMax3(v0.x, v1.x, v2.x));
  const float2 ylims =
      make_float2(FloatMin3(v0.y, v1.y, v2.y), FloatMax3(v0.y, v1.y, v2.y));
  const float2 zlims =
      make_float2(FloatMin3(v0.z, v1.z, v2.z), FloatMax3(v0.z, v1.z, v2.z));
  return thrust::make_tuple(xlims, ylims, zlims);
}
// Check if the point (px, py) lies outside the face bounding box face_bbox.
// Return true if the point is outside.
__device__ bool CheckPointOutsideBoundingBox(
    float3 v0,
    float3 v1,
    float3 v2,
    float blur_radius,
    float2 pxy) {
  const auto bbox = GetFaceBoundingBox(v0, v1, v2);
  // Only the x/y limits matter for this 2D test; the z limits returned by
  // GetFaceBoundingBox are not needed (an unused `zlims` was removed).
  const float2 xlims = thrust::get<0>(bbox);
  const float2 ylims = thrust::get<1>(bbox);
  // Expand the box by the blur radius on all sides.
  const float x_min = xlims.x - blur_radius;
  const float y_min = ylims.x - blur_radius;
  const float x_max = xlims.y + blur_radius;
  const float y_max = ylims.y + blur_radius;
  // Check if the current point is outside the expanded triangle bounding box.
  return (pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min);
}
// This function checks if a pixel given by xy location pxy lies within the
// face with index face_idx in face_verts. One of the inputs is a list (q)
// which contains Pixel structs with the indices of the faces which intersect
// with this pixel sorted by closest z distance. If the point pxy lies in the
// face, the list (q) is updated and re-ordered in place. In addition
// the auxiliary variables q_size, q_max_z and q_max_idx are also modified.
// This code is shared between RasterizeMeshesNaiveCudaKernel and
// RasterizeMeshesFineCudaKernel.
template <typename FaceQ>
__device__ void CheckPixelInsideFace(
const float* face_verts, // (N, P, 3)
int face_idx,
int& q_size,
float& q_max_z,
int& q_max_idx,
FaceQ& q,
float blur_radius,
float2 pxy, // Coordinates of the pixel
int K,
bool perspective_correct) {
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
const float3 v0 = thrust::get<0>(v012);
const float3 v1 = thrust::get<1>(v012);
const float3 v2 = thrust::get<2>(v012);
// Only need xy for barycentric coordinates and distance calculations.
const float2 v0xy = make_float2(v0.x, v0.y);
const float2 v1xy = make_float2(v1.x, v1.y);
const float2 v2xy = make_float2(v2.x, v2.y);
// Perform checks and skip if:
// 1. the face is behind the camera
// 2. the face has very small face area
// 3. the pixel is outside the face bbox
const float zmax = FloatMax3(v0.z, v1.z, v2.z);
// blur_radius is in squared-distance units (it is compared against the
// squared `dist` below), hence the sqrt here for the bbox expansion.
const bool outside_bbox = CheckPointOutsideBoundingBox(
v0, v1, v2, sqrt(blur_radius), pxy); // use sqrt of blur for bbox
const float face_area = EdgeFunctionForward(v0xy, v1xy, v2xy);
const bool zero_face_area =
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
if (zmax < 0 || outside_bbox || zero_face_area) {
return;
}
// Calculate barycentric coords and euclidean dist to triangle.
const float3 p_bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
const float3 p_bary = !perspective_correct
? p_bary0
: BarycentricPerspectiveCorrectionForward(p_bary0, v0.z, v1.z, v2.z);
// Interpolated depth of the face at this pixel.
const float pz = p_bary.x * v0.z + p_bary.y * v1.z + p_bary.z * v2.z;
if (pz < 0) {
return; // Face is behind the image plane.
}
// Get abs squared distance
const float dist = PointTriangleDistanceForward(pxy, v0xy, v1xy, v2xy);
// Use the bary coordinates to determine if the point is inside the face.
const bool inside = p_bary.x > 0.0f && p_bary.y > 0.0f && p_bary.z > 0.0f;
const float signed_dist = inside ? -dist : dist;
// Check if pixel is outside blur region
if (!inside && dist >= blur_radius) {
return;
}
// Maintain the K closest (by depth) faces for this pixel. q_max_z and
// q_max_idx track the largest depth currently stored in q.
if (q_size < K) {
// Just insert it.
q[q_size] = {pz, face_idx, signed_dist, p_bary};
if (pz > q_max_z) {
q_max_z = pz;
q_max_idx = q_size;
}
q_size++;
} else if (pz < q_max_z) {
// Overwrite the old max, and find the new max.
q[q_max_idx] = {pz, face_idx, signed_dist, p_bary};
q_max_z = pz;
for (int i = 0; i < K; i++) {
if (q[i].z > q_max_z) {
q_max_z = q[i].z;
q_max_idx = i;
}
}
}
}
} // namespace
// ****************************************************************************
// * NAIVE RASTERIZATION *
// ****************************************************************************
// Naive rasterization kernel: one thread per output pixel, with a
// grid-stride loop over all N * H * W pixels. Each thread tests every face
// of its mesh against the pixel and keeps the K closest hits.
//
// Outputs (all shaped (N, H, W, K); bary has a trailing dim of 3) are
// written at the pixel's slot; slots beyond the number of hits keep their
// host-side initialization value (-1).
__global__ void RasterizeMeshesNaiveCudaKernel(
    const float* face_verts,
    const int64_t* mesh_to_face_first_idx,
    const int64_t* num_faces_per_mesh,
    float blur_radius,
    bool perspective_correct,
    int N,
    int H,
    int W,
    int K,
    int64_t* face_idxs,
    float* zbuf,
    float* pix_dists,
    float* bary) {
  // Simple version: One thread per output pixel
  int num_threads = gridDim.x * blockDim.x;
  int tid = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = tid; i < N * H * W; i += num_threads) {
    // Convert linear index to 3D index. For a row-major (H, W) image the
    // row is pix_idx / W and the column is pix_idx % W. (The previous code
    // divided by H, which was only correct for square images.)
    const int n = i / (H * W); // batch index.
    const int pix_idx = i % (H * W);
    const int yi = pix_idx / W;
    const int xi = pix_idx % W;
    // screen coordinates to ndc coordinates of pixel.
    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const float2 pxy = make_float2(xf, yf);
    // For keeping track of the K closest points we want a data structure
    // that (1) gives O(1) access to the closest point for easy comparisons,
    // and (2) allows insertion of new elements. In the CPU version we use
    // std::priority_queue; then (2) is O(log K). We can't use STL
    // containers in CUDA; we could roll our own max heap in an array, but
    // that would likely have a lot of warp divergence so we do something
    // simpler instead: keep the elements in an unsorted array, but keep
    // track of the max value and the index of the max value. Then (1) is
    // still O(1) time, while (2) is O(K) with a clean loop. Since K <= 8
    // this should be fast enough for our purposes.
    Pixel q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;
    // Using the batch index of the thread get the start and stop
    // indices for the faces.
    const int64_t face_start_idx = mesh_to_face_first_idx[n];
    const int64_t face_stop_idx = face_start_idx + num_faces_per_mesh[n];
    // Loop through the faces in the mesh.
    for (int f = face_start_idx; f < face_stop_idx; ++f) {
      // Check if the pixel pxy is inside the face bounding box and if it is,
      // update q, q_size, q_max_z and q_max_idx in place.
      CheckPixelInsideFace(
          face_verts,
          f,
          q_size,
          q_max_z,
          q_max_idx,
          q,
          blur_radius,
          pxy,
          K,
          perspective_correct);
    }
    // TODO: make sorting an option as only top k is needed, not sorted values.
    BubbleSort(q, q_size);
    // First slot for this pixel in the (N, H, W, K) outputs; the row stride
    // is W * K (was H * K, valid only for square images).
    int idx = n * H * W * K + yi * W * K + xi * K;
    for (int k = 0; k < q_size; ++k) {
      face_idxs[idx + k] = q[k].idx;
      zbuf[idx + k] = q[k].z;
      pix_dists[idx + k] = q[k].dist;
      bary[(idx + k) * 3 + 0] = q[k].bary.x;
      bary[(idx + k) * 3 + 1] = q[k].bary.y;
      bary[(idx + k) * 3 + 2] = q[k].bary.z;
    }
  }
}
// Host entry point for naive mesh rasterization on the GPU.
// Validates input shapes, allocates the (N, H, W, K) outputs padded with -1,
// and launches RasterizeMeshesNaiveCudaKernel.
//
// Returns:
//   (pix_to_face, zbuf, barycentric_coords, dists); see the header for the
//   full description of each output.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_faces_packed_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    const int image_size,
    const float blur_radius,
    const int num_closest,
    bool perspective_correct) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (num_faces_per_mesh.size(0) != mesh_to_faces_packed_first_idx.size(0)) {
    // Fixed error-message typo ("save size" -> "same size").
    AT_ERROR(
        "num_faces_per_mesh must have the same size first dimension as mesh_to_faces_packed_first_idx");
  }
  if (num_closest > kMaxPointsPerPixel) {
    std::stringstream ss;
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }
  const int N = num_faces_per_mesh.size(0); // batch size.
  const int H = image_size; // Assume square images.
  const int W = image_size;
  const int K = num_closest;
  auto long_opts = face_verts.options().dtype(torch::kInt64);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
  // -1 padding marks output slots hit by fewer than K faces.
  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizeMeshesNaiveCudaKernel<<<blocks, threads>>>(
      face_verts.contiguous().data<float>(),
      mesh_to_faces_packed_first_idx.contiguous().data<int64_t>(),
      num_faces_per_mesh.contiguous().data<int64_t>(),
      blur_radius,
      perspective_correct,
      N,
      H,
      W,
      K,
      face_idxs.contiguous().data<int64_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>(),
      bary.contiguous().data<float>());
  return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// Backward pass kernel: parallelizes over pixels (grid-stride loop over
// N * H * W). For each of the K faces recorded at a pixel, the upstream
// gradients of dists, zbuf and barycentric coords are chained back to the
// three face vertices and accumulated into grad_face_verts with atomicAdd
// (many pixels can touch the same face concurrently).
// TODO: benchmark parallelizing over faces_verts instead of over pixels.
__global__ void RasterizeMeshesBackwardCudaKernel(
    const float* face_verts, // (F, 3, 3)
    const int64_t* pix_to_face, // (N, H, W, K)
    bool perspective_correct,
    int N,
    int F,
    int H,
    int W,
    int K,
    const float* grad_zbuf, // (N, H, W, K)
    const float* grad_bary, // (N, H, W, K, 3)
    const float* grad_dists, // (N, H, W, K)
    float* grad_face_verts) { // (F, 3, 3)
  // Parallelize over each pixel in images of
  // size H * W, for each image in the batch of size N.
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
    // Convert linear index to 3D index. The row is pix_idx / W and the
    // column is pix_idx % W (was / H, correct only for square images).
    const int n = t_i / (H * W); // batch index.
    const int pix_idx = t_i % (H * W);
    const int yi = pix_idx / W;
    const int xi = pix_idx % W;
    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const float2 pxy = make_float2(xf, yf);
    // Loop over all the faces for this pixel.
    for (int k = 0; k < K; k++) {
      // Index into (N, H, W, K, :) grad tensors; the row stride is W * K
      // (was H * K, valid only for square images).
      const int i =
          n * H * W * K + yi * W * K + xi * K + k; // pixel index + face index
      const int f = pix_to_face[i];
      if (f < 0) {
        continue; // padded face.
      }
      // Get xyz coordinates of the three face vertices.
      const auto v012 = GetSingleFaceVerts(face_verts, f);
      const float3 v0 = thrust::get<0>(v012);
      const float3 v1 = thrust::get<1>(v012);
      const float3 v2 = thrust::get<2>(v012);
      // Only need xy for barycentric coordinate and distance calculations.
      const float2 v0xy = make_float2(v0.x, v0.y);
      const float2 v1xy = make_float2(v1.x, v1.y);
      const float2 v2xy = make_float2(v2.x, v2.y);
      // Get upstream gradients for the face.
      const float grad_dist_upstream = grad_dists[i];
      const float grad_zbuf_upstream = grad_zbuf[i];
      const float grad_bary_upstream_w0 = grad_bary[i * 3 + 0];
      const float grad_bary_upstream_w1 = grad_bary[i * 3 + 1];
      const float grad_bary_upstream_w2 = grad_bary[i * 3 + 2];
      const float3 grad_bary_upstream = make_float3(
          grad_bary_upstream_w0, grad_bary_upstream_w1, grad_bary_upstream_w2);
      const float3 bary0 = BarycentricCoordsForward(pxy, v0xy, v1xy, v2xy);
      const float3 bary = !perspective_correct
          ? bary0
          : BarycentricPerspectiveCorrectionForward(bary0, v0.z, v1.z, v2.z);
      // Distances are stored negative for pixels inside the face (see
      // CheckPixelInsideFace), so the dist gradient is sign-flipped there.
      const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
      const float sign = inside ? -1.0f : 1.0f;
      // TODO(T52813608) Add support for non-square images.
      auto grad_dist_f = PointTriangleDistanceBackward(
          pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
      const float2 ddist_d_v0 = thrust::get<1>(grad_dist_f);
      const float2 ddist_d_v1 = thrust::get<2>(grad_dist_f);
      const float2 ddist_d_v2 = thrust::get<3>(grad_dist_f);
      // Upstream gradient for barycentric coords from zbuf calculation:
      // zbuf = bary_w0 * z0 + bary_w1 * z1 + bary_w2 * z2
      // Therefore
      // d_zbuf/d_bary_w0 = z0
      // d_zbuf/d_bary_w1 = z1
      // d_zbuf/d_bary_w2 = z2
      const float3 d_zbuf_d_bary = make_float3(v0.z, v1.z, v2.z);
      // Total upstream barycentric gradients are the sum of
      // external upstream gradients and contribution from zbuf.
      const float3 grad_bary_f_sum =
          (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
      float3 grad_bary0 = grad_bary_f_sum;
      // Extra z-gradient terms produced by perspective correction.
      float dz0_persp = 0.0f, dz1_persp = 0.0f, dz2_persp = 0.0f;
      if (perspective_correct) {
        auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
            bary0, v0.z, v1.z, v2.z, grad_bary_f_sum);
        grad_bary0 = thrust::get<0>(perspective_grads);
        dz0_persp = thrust::get<1>(perspective_grads);
        dz1_persp = thrust::get<2>(perspective_grads);
        dz2_persp = thrust::get<3>(perspective_grads);
      }
      auto grad_bary_f =
          BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
      const float2 dbary_d_v0 = thrust::get<1>(grad_bary_f);
      const float2 dbary_d_v1 = thrust::get<2>(grad_bary_f);
      const float2 dbary_d_v2 = thrust::get<3>(grad_bary_f);
      // Accumulate into grad_face_verts (F, 3, 3): 9 floats per face, laid
      // out as (v0.x, v0.y, v0.z, v1.x, ..., v2.z).
      atomicAdd(grad_face_verts + f * 9 + 0, dbary_d_v0.x + ddist_d_v0.x);
      atomicAdd(grad_face_verts + f * 9 + 1, dbary_d_v0.y + ddist_d_v0.y);
      atomicAdd(
          grad_face_verts + f * 9 + 2, grad_zbuf_upstream * bary.x + dz0_persp);
      atomicAdd(grad_face_verts + f * 9 + 3, dbary_d_v1.x + ddist_d_v1.x);
      atomicAdd(grad_face_verts + f * 9 + 4, dbary_d_v1.y + ddist_d_v1.y);
      atomicAdd(
          grad_face_verts + f * 9 + 5, grad_zbuf_upstream * bary.y + dz1_persp);
      atomicAdd(grad_face_verts + f * 9 + 6, dbary_d_v2.x + ddist_d_v2.x);
      atomicAdd(grad_face_verts + f * 9 + 7, dbary_d_v2.y + ddist_d_v2.y);
      atomicAdd(
          grad_face_verts + f * 9 + 8, grad_zbuf_upstream * bary.z + dz2_persp);
    }
  }
}
// Host entry point for the backward pass of mesh rasterization on the GPU.
// Allocates a zero-initialized (F, 3, 3) gradient tensor and launches
// RasterizeMeshesBackwardCudaKernel, which accumulates per-vertex gradients
// across pixels via atomicAdd.
//
// Returns:
//   grad_face_verts: d(loss)/d(face_verts), shape (F, 3, 3).
torch::Tensor RasterizeMeshesBackwardCuda(
    const torch::Tensor& face_verts, // (F, 3, 3)
    const torch::Tensor& pix_to_face, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
    const torch::Tensor& grad_dists, // (N, H, W, K)
    bool perspective_correct) {
  const int F = face_verts.size(0);
  const int N = pix_to_face.size(0);
  const int H = pix_to_face.size(1);
  const int W = pix_to_face.size(2);
  const int K = pix_to_face.size(3);
  // Zero-init so the kernel can accumulate with atomicAdd.
  torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizeMeshesBackwardCudaKernel<<<blocks, threads>>>(
      face_verts.contiguous().data<float>(),
      pix_to_face.contiguous().data<int64_t>(),
      perspective_correct,
      N,
      F,
      H,
      W,
      K,
      grad_zbuf.contiguous().data<float>(),
      grad_bary.contiguous().data<float>(),
      grad_dists.contiguous().data<float>(),
      grad_face_verts.contiguous().data<float>());
  return grad_face_verts;
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
// Coarse rasterization kernel: for each (mesh, chunk of faces), mark in a
// shared-memory bitmask which spatial bins each face's blurred bounding box
// overlaps, then append those face indices to the global bin_faces lists.
// Each block processes whole chunks (grid-stride over N * chunks_per_batch);
// within a block, threads first split the faces of the chunk, then split the
// bins when flushing the bitmask to global memory.
//
// Outputs:
//   faces_per_bin: (N, num_bins, num_bins) running face counts per bin.
//   bin_faces: (N, num_bins, num_bins, max_faces_per_bin) face indices.
__global__ void RasterizeMeshesCoarseCudaKernel(
    const float* face_verts,
    const int64_t* mesh_to_face_first_idx,
    const int64_t* num_faces_per_mesh,
    const float blur_radius,
    const int N,
    const int F,
    const int H,
    const int W,
    const int bin_size,
    const int chunk_size,
    const int max_faces_per_bin,
    int* faces_per_bin,
    int* bin_faces) {
  // Dynamic shared memory backing the per-chunk bin bitmask; the host must
  // launch with num_bins * num_bins * chunk_size / 8 bytes.
  extern __shared__ char sbuf[];
  const int M = max_faces_per_bin;
  const int num_bins = 1 + (W - 1) / bin_size; // Integer divide round up
  // NOTE(review): half_pix is derived from W but also used for the y/bin_y
  // extents below; exact only for the square images assumed so far.
  const float half_pix = 1.0f / W; // Size of half a pixel in NDC units
  // This is a boolean array of shape (num_bins, num_bins, chunk_size)
  // stored in shared memory that will track whether each point in the chunk
  // falls into each bin of the image.
  BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
  // Have each block handle a chunk of faces
  const int chunks_per_batch = 1 + (F - 1) / chunk_size;
  const int num_chunks = N * chunks_per_batch;
  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
    const int batch_idx = chunk / chunks_per_batch; // batch index
    const int chunk_idx = chunk % chunks_per_batch;
    const int face_start_idx = chunk_idx * chunk_size;
    // Reset the bitmask before marking faces for this chunk.
    binmask.block_clear();
    const int64_t mesh_face_start_idx = mesh_to_face_first_idx[batch_idx];
    const int64_t mesh_face_stop_idx =
        mesh_face_start_idx + num_faces_per_mesh[batch_idx];
    // Have each thread handle a different face within the chunk
    for (int f = threadIdx.x; f < chunk_size; f += blockDim.x) {
      const int f_idx = face_start_idx + f;
      // Check if face index corresponds to the mesh in the batch given by
      // batch_idx
      if (f_idx >= mesh_face_stop_idx || f_idx < mesh_face_start_idx) {
        continue;
      }
      // Get xyz coordinates of the three face vertices.
      const auto v012 = GetSingleFaceVerts(face_verts, f_idx);
      const float3 v0 = thrust::get<0>(v012);
      const float3 v1 = thrust::get<1>(v012);
      const float3 v2 = thrust::get<2>(v012);
      // Compute screen-space bbox for the triangle expanded by blur.
      float xmin = FloatMin3(v0.x, v1.x, v2.x) - sqrt(blur_radius);
      float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
      float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
      float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
      float zmax = FloatMax3(v0.z, v1.z, v2.z);
      if (zmax < 0) {
        continue; // Face is behind the camera.
      }
      // Brute-force search over all bins; TODO(T54294966) something smarter.
      for (int by = 0; by < num_bins; ++by) {
        // Y coordinate of the top and bottom of the bin.
        // PixToNdc gives the location of the center of each pixel, so we
        // need to add/subtract a half pixel to get the true extent of the bin.
        const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
        const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
        const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
        for (int bx = 0; bx < num_bins; ++bx) {
          // X coordinate of the left and right of the bin.
          const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
          const float bin_x_max =
              PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
          const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
          if (y_overlap && x_overlap) {
            binmask.set(by, bx, f);
          }
        }
      }
    }
    // All face bits for this chunk must be set before any thread counts them.
    __syncthreads();
    // Now we have processed every face in the current chunk. We need to
    // count the number of faces in each bin so we can write the indices
    // out to global memory. We have each thread handle a different bin.
    for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
      const int by = byx / num_bins;
      const int bx = byx % num_bins;
      const int count = binmask.count(by, bx);
      const int faces_per_bin_idx =
          batch_idx * num_bins * num_bins + by * num_bins + bx;
      // This atomically increments the (global) number of faces found
      // in the current bin, and gets the previous value of the counter;
      // this effectively allocates space in the bin_faces array for the
      // faces in the current chunk that fall into this bin.
      const int start = atomicAdd(faces_per_bin + faces_per_bin_idx, count);
      // Now loop over the binmask and write the active bits for this bin
      // out to bin_faces.
      int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M +
          bx * M + start;
      for (int f = 0; f < chunk_size; ++f) {
        if (binmask.get(by, bx, f)) {
          // TODO(T54296346) find the correct method for handling errors in
          // CUDA. Throw an error if num_faces_per_bin > max_faces_per_bin.
          // Either decrease bin size or increase max_faces_per_bin
          bin_faces[next_idx] = face_start_idx + f;
          next_idx++;
        }
      }
    }
    // Finish flushing this chunk's bitmask before the next chunk clears it.
    __syncthreads();
  }
}
// Host entry point for coarse rasterization on the GPU.
// Validates shapes, allocates the per-bin counters and face lists, and
// launches RasterizeMeshesCoarseCudaKernel with the shared memory required
// by its per-chunk bin bitmask.
//
// Returns:
//   bin_faces: (N, num_bins, num_bins, max_faces_per_bin) int32 tensor of
//   face indices falling into each bin, padded with -1.
torch::Tensor RasterizeMeshesCoarseCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    const int image_size,
    const float blur_radius,
    const int bin_size,
    const int max_faces_per_bin) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  const int W = image_size;
  const int H = image_size;
  const int F = face_verts.size(0);
  const int N = num_faces_per_mesh.size(0);
  const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
  const int M = max_faces_per_bin;
  if (num_bins >= 22) {
    // The kernel's shared-memory bitmask uses num_bins^2 * chunk_size / 8
    // bytes, so the bin grid must stay small. Give an actionable message
    // instead of the old "that's too many!".
    std::stringstream ss;
    ss << "Got " << num_bins << " bins per image dimension; "
       << "the maximum allowed is 21. Try increasing bin_size.";
    AT_ERROR(ss.str());
  }
  auto opts = face_verts.options().dtype(torch::kInt32);
  torch::Tensor faces_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
  torch::Tensor bin_faces = torch::full({N, num_bins, num_bins, M}, -1, opts);
  const int chunk_size = 512;
  // One bit per (bin_y, bin_x, face-in-chunk) entry.
  const size_t shared_size = num_bins * num_bins * chunk_size / 8;
  const size_t blocks = 64;
  const size_t threads = 512;
  RasterizeMeshesCoarseCudaKernel<<<blocks, threads, shared_size>>>(
      face_verts.contiguous().data<float>(),
      mesh_to_face_first_idx.contiguous().data<int64_t>(),
      num_faces_per_mesh.contiguous().data<int64_t>(),
      blur_radius,
      N,
      F,
      H,
      W,
      bin_size,
      chunk_size,
      M,
      faces_per_bin.contiguous().data<int32_t>(),
      bin_faces.contiguous().data<int32_t>());
  return bin_faces;
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
// Fine rasterization kernel: like the naive kernel, but each pixel only
// tests the faces that coarse rasterization placed into its bin. One thread
// per (bin, within-bin pixel), grid-stride over all of them.
__global__ void RasterizeMeshesFineCudaKernel(
    const float* face_verts, // (F, 3, 3)
    const int32_t* bin_faces, // (N, B, B, T)
    const float blur_radius,
    const int bin_size,
    const bool perspective_correct,
    const int N,
    const int F,
    const int B,
    const int M,
    const int H,
    const int W,
    const int K,
    int64_t* face_idxs, // (N, S, S, K)
    float* zbuf, // (N, S, S, K)
    float* pix_dists, // (N, S, S, K)
    float* bary // (N, S, S, K, 3)
) {
  // This can be more than S^2 if S % bin_size != 0
  int num_pixels = N * B * B * bin_size * bin_size;
  int num_threads = gridDim.x * blockDim.x;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    // Convert linear index into bin and pixel indices. We make the within
    // block pixel ids move the fastest, so that adjacent threads will fall
    // into the same bin; this should give them coalesced memory reads when
    // they read from faces and bin_faces.
    int i = pid;
    const int n = i / (B * B * bin_size * bin_size);
    i %= B * B * bin_size * bin_size;
    const int by = i / (B * bin_size * bin_size);
    i %= B * bin_size * bin_size;
    const int bx = i / (bin_size * bin_size);
    i %= bin_size * bin_size;
    const int yi = i / bin_size + by * bin_size;
    const int xi = i % bin_size + bx * bin_size;
    // Bins can extend past the image edge when the image size is not a
    // multiple of bin_size, so skip out-of-image pixels.
    if (yi >= H || xi >= W)
      continue;
    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const float2 pxy = make_float2(xf, yf);
    // This part looks like the naive rasterization kernel, except we use
    // bin_faces to only look at a subset of faces already known to fall
    // in this bin. TODO abstract out this logic into some data structure
    // that is shared by both kernels?
    Pixel q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;
    for (int m = 0; m < M; m++) {
      const int f = bin_faces[n * B * B * M + by * B * M + bx * M + m];
      if (f < 0) {
        continue; // bin_faces uses -1 as a sentinel value.
      }
      // Check if the pixel pxy is inside the face bounding box and if it is,
      // update q, q_size, q_max_z and q_max_idx in place.
      CheckPixelInsideFace(
          face_verts,
          f,
          q_size,
          q_max_z,
          q_max_idx,
          q,
          blur_radius,
          pxy,
          K,
          perspective_correct);
    }
    // Now we've looked at all the faces for this bin, so we can write
    // output for the current pixel.
    // TODO: make sorting an option as only top k is needed, not sorted values.
    BubbleSort(q, q_size);
    // First slot for this pixel in the (N, H, W, K) outputs; the row stride
    // is W * K (was H * K, valid only for square images).
    const int pix_idx = n * H * W * K + yi * W * K + xi * K;
    for (int k = 0; k < q_size; k++) {
      face_idxs[pix_idx + k] = q[k].idx;
      zbuf[pix_idx + k] = q[k].z;
      pix_dists[pix_idx + k] = q[k].dist;
      bary[(pix_idx + k) * 3 + 0] = q[k].bary.x;
      bary[(pix_idx + k) * 3 + 1] = q[k].bary.y;
      bary[(pix_idx + k) * 3 + 2] = q[k].bary.z;
    }
  }
}
// Host entry point for fine mesh rasterization on the GPU.
// Validates shapes, allocates the (N, H, W, K) outputs padded with -1, and
// launches RasterizeMeshesFineCudaKernel using the bin assignments produced
// by coarse rasterization.
//
// Returns:
//   (pix_to_face, zbuf, barycentric_coords, dists); see the header for the
//   full description of each output.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& bin_faces,
    const int image_size,
    const float blur_radius,
    const int bin_size,
    const int faces_per_pixel,
    bool perspective_correct) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (bin_faces.ndimension() != 4) {
    AT_ERROR("bin_faces must have 4 dimensions");
  }
  const int F = face_verts.size(0);
  const int N = bin_faces.size(0);
  const int B = bin_faces.size(1);
  const int M = bin_faces.size(3);
  const int K = faces_per_pixel;
  const int H = image_size; // Assume square images only.
  const int W = image_size;
  if (K > kMaxPointsPerPixel) {
    // Report the actual limit instead of a hard-coded "8", consistent with
    // the check in RasterizeMeshesNaiveCuda.
    std::stringstream ss;
    ss << "Must have faces_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }
  auto long_opts = face_verts.options().dtype(torch::kInt64);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
  // -1 padding marks output slots hit by fewer than K faces.
  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor bary = torch::full({N, H, W, K, 3}, -1, float_opts);
  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizeMeshesFineCudaKernel<<<blocks, threads>>>(
      face_verts.contiguous().data<float>(),
      bin_faces.contiguous().data<int32_t>(),
      blur_radius,
      bin_size,
      perspective_correct,
      N,
      F,
      B,
      M,
      H,
      W,
      K,
      face_idxs.contiguous().data<int64_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>(),
      bary.contiguous().data<float>());
  return std::make_tuple(face_idxs, zbuf, bary, pix_dists);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
// ****************************************************************************
// * FORWARD PASS *
// ****************************************************************************
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int faces_per_pixel,
bool perspective_correct);
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int num_closest,
bool perspective_correct);
// Forward pass for rasterizing a batch of meshes.
//
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
// faces_verts of the first face in each mesh in
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
// Assume square images only.
//  blur_radius: float distance in NDC coordinates used to expand the face
//               bounding boxes for the rasterization. Set to 0.0 if no blur
//               is required.
//  faces_per_pixel: the number of closest faces to rasterize per pixel.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns:
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// zbuf: float32 Tensor of shape (N, H, W, K) giving the depth of each of
// the closest faces for each pixel.
// barycentric_coords: float tensor of shape (N, H, W, K, 3) giving
// barycentric coordinates of the pixel with respect to
// each of the closest faces along the z axis, padded
// with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// dists: float tensor of shape (N, H, W, K) giving the euclidean distance
// in the (NDC) x/y plane between each pixel and its K closest
// faces along the z axis padded with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// Dispatch naive mesh rasterization to the CUDA or CPU implementation,
// based on the device of the input tensors. Argument and return value
// semantics are documented in the comment block above.
inline std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaive(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int faces_per_pixel,
    bool perspective_correct) {
  // TODO: Better type checking.
  if (face_verts.type().is_cuda()) {
    return RasterizeMeshesNaiveCuda(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        faces_per_pixel,
        perspective_correct);
  }
  return RasterizeMeshesNaiveCpu(
      face_verts,
      mesh_to_face_first_idx,
      num_faces_per_mesh,
      image_size,
      blur_radius,
      faces_per_pixel,
      perspective_correct);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// NOTE(review): these declarations previously named grad_bary before
// grad_zbuf — the opposite of both the CUDA definition and the positional
// order in which RasterizeMeshesBackward passes the arguments. The names are
// corrected here to match; declaration parameter names have no ABI effect
// since the types are identical. (CPU definition order presumed the same —
// confirm against rasterize_meshes_cpu.cpp.)
torch::Tensor RasterizeMeshesBackwardCpu(
    const torch::Tensor& face_verts,
    const torch::Tensor& pix_to_face,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_bary,
    const torch::Tensor& grad_dists,
    bool perspective_correct);
torch::Tensor RasterizeMeshesBackwardCuda(
    const torch::Tensor& face_verts,
    const torch::Tensor& pix_to_face,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_bary,
    const torch::Tensor& grad_dists,
    bool perspective_correct);
// Args:
// face_verts: float32 Tensor of shape (F, 3, 3) (from forward pass) giving
// (packed) vertex positions for faces in all the meshes in
// the batch.
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// grad_zbuf: Tensor of shape (N, H, W, K) giving upstream gradients
// d(loss)/d(zbuf) of the zbuf tensor from the forward pass.
// grad_bary: Tensor of shape (N, H, W, K, 3) giving upstream gradients
// d(loss)/d(bary) of the barycentric_coords tensor returned by
// the forward pass.
// grad_dists: Tensor of shape (N, H, W, K) giving upstream gradients
// d(loss)/d(dists) of the dists tensor from the forward pass.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns:
// grad_face_verts: float32 Tensor of shape (F, 3, 3) giving downstream
// gradients for the face vertices.
// Dispatch the rasterization backward pass to the CUDA or CPU
// implementation, based on the device of the input tensors. Arguments and
// the returned gradient tensor are documented in the comment block above.
torch::Tensor RasterizeMeshesBackward(
    const torch::Tensor& face_verts,
    const torch::Tensor& pix_to_face,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_bary,
    const torch::Tensor& grad_dists,
    bool perspective_correct) {
  if (face_verts.type().is_cuda()) {
    return RasterizeMeshesBackwardCuda(
        face_verts,
        pix_to_face,
        grad_zbuf,
        grad_bary,
        grad_dists,
        perspective_correct);
  }
  return RasterizeMeshesBackwardCpu(
      face_verts,
      pix_to_face,
      grad_zbuf,
      grad_bary,
      grad_dists,
      perspective_correct);
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh,
int image_size,
float blur_radius,
int bin_size,
int max_faces_per_bin);
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
// faces_verts of the first face in each mesh in
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
//  blur_radius: float distance in NDC coordinates used to expand the face
//               bounding boxes for the rasterization. Set to 0.0 if no blur
//               is required.
// bin_size: Size of each bin within the image (in pixels)
// max_faces_per_bin: Maximum number of faces to count in each bin.
//
// Returns:
// bin_face_idxs: Tensor of shape (N, num_bins, num_bins, K) giving the
// indices of faces that fall into each bin.
// Dispatch coarse rasterization to the CUDA or CPU implementation, based on
// the device of the input tensors. Arguments and the returned per-bin face
// index tensor are documented in the comment block above.
torch::Tensor RasterizeMeshesCoarse(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int bin_size,
    int max_faces_per_bin) {
  if (face_verts.type().is_cuda()) {
    return RasterizeMeshesCoarseCuda(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        bin_size,
        max_faces_per_bin);
  }
  return RasterizeMeshesCoarseCpu(
      face_verts,
      mesh_to_face_first_idx,
      num_faces_per_mesh,
      image_size,
      blur_radius,
      bin_size,
      max_faces_per_bin);
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda(
const torch::Tensor& face_verts,
const torch::Tensor& bin_faces,
int image_size,
float blur_radius,
int bin_size,
int faces_per_pixel,
bool perspective_correct);
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// bin_faces: int32 Tensor of shape (N, B, B, M) giving the indices of faces
// that fall into each bin (output from coarse rasterization).
// image_size: Size in pixels of the output image to be rasterized.
//  blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
// bin_size: Size of each bin within the image (in pixels)
//  faces_per_pixel: the number of closest faces to rasterize per pixel.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns (same as rasterize_meshes):
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// zbuf: float32 Tensor of shape (N, H, W, K) giving the depth of each of
// the closest faces for each pixel.
// barycentric_coords: float tensor of shape (N, H, W, K, 3) giving
// barycentric coordinates of the pixel with respect to
// each of the closest faces along the z axis, padded
// with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// dists: float tensor of shape (N, H, W, K) giving the euclidean distance
// in the (NDC) x/y plane between each pixel and its K closest
// faces along the z axis padded with -1 for pixels hit by fewer than
// faces_per_pixel faces.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine(
    const torch::Tensor& face_verts,
    const torch::Tensor& bin_faces,
    int image_size,
    float blur_radius,
    int bin_size,
    int faces_per_pixel,
    bool perspective_correct) {
  // Only a CUDA implementation of fine rasterization exists; reject CPU
  // tensors up front.
  if (!face_verts.type().is_cuda()) {
    AT_ERROR("NOT IMPLEMENTED");
  }
  return RasterizeMeshesFineCuda(
      face_verts,
      bin_faces,
      image_size,
      blur_radius,
      bin_size,
      faces_per_pixel,
      perspective_correct);
}
// ****************************************************************************
// * MAIN ENTRY POINT *
// ****************************************************************************
// This is the main entry point for the forward pass of the mesh rasterizer;
// it uses either naive or coarse-to-fine rasterization based on bin_size.
//
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
// faces in all the meshes in the batch. Concretely,
// face_verts[f, i] = [x, y, z] gives the coordinates for the
// ith vertex of the fth face. These vertices are expected to be
// in NDC coordinates in the range [-1, 1].
// mesh_to_face_first_idx: LongTensor of shape (N) giving the index in
// faces_verts of the first face in each mesh in
// the batch where N is the batch size.
// num_faces_per_mesh: LongTensor of shape (N) giving the number of faces
// for each mesh in the batch.
// image_size: Size in pixels of the output image to be rasterized.
//  blur_radius: float distance in NDC coordinates used to expand the face
// bounding boxes for the rasterization. Set to 0.0 if no blur
// is required.
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
// bin_size=0 uses naive rasterization instead.
// max_faces_per_bin: The maximum number of faces allowed to fall into each
// bin when using coarse-to-fine rasterization.
// perspective_correct: Whether to apply perspective correction when
// computing barycentric coordinates. If this is True,
// then this function returns world-space barycentric
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// Returns:
// A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
// each of the closest faces to the pixel in the rasterized
// image, or -1 for pixels that are not covered by any face.
// zbuf: float32 Tensor of shape (N, H, W, K) giving the depth of each of
// the closest faces for each pixel.
// barycentric_coords: float tensor of shape (N, H, W, K, 3) giving
// barycentric coordinates of the pixel with respect to
// each of the closest faces along the z axis, padded
// with -1 for pixels hit by fewer than
// faces_per_pixel faces.
// dists: float tensor of shape (N, H, W, K) giving the euclidean distance
// in the (NDC) x/y plane between each pixel and its K closest
// faces along the z axis padded with -1 for pixels hit by fewer than
// faces_per_pixel faces.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshes(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int faces_per_pixel,
    int bin_size,
    int max_faces_per_bin,
    bool perspective_correct) {
  // A non-positive bin_size or bin capacity selects the naive per-pixel
  // implementation (see the bin_size documentation above).
  if (bin_size <= 0 || max_faces_per_bin <= 0) {
    return RasterizeMeshesNaive(
        face_verts,
        mesh_to_face_first_idx,
        num_faces_per_mesh,
        image_size,
        blur_radius,
        faces_per_pixel,
        perspective_correct);
  }
  // Coarse-to-fine: first bin the faces, then rasterize each bin's faces.
  auto bin_faces = RasterizeMeshesCoarse(
      face_verts,
      mesh_to_face_first_idx,
      num_faces_per_mesh,
      image_size,
      blur_radius,
      bin_size,
      max_faces_per_bin);
  return RasterizeMeshesFine(
      face_verts,
      bin_faces,
      image_size,
      blur_radius,
      bin_size,
      faces_per_pixel,
      perspective_correct);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <algorithm>
#include <list>
#include <queue>
#include <tuple>
#include "geometry_utils.h"
#include "vec2.h"
#include "vec3.h"
// Map a pixel index 0 <= i < S to the NDC coordinate of the pixel center.
// The NDC range [-1, 1] is divided into S equal pixels; the returned value
// is the center of the ith pixel.
float PixToNdc(int i, int S) {
  // pixel center offset from the left NDC edge: i * (2/S) + (1/S).
  const float center = (2 * i + 1.0f) / S;
  return center - 1;
}
// Unpack the (x, y, z) coordinates of one vertex of a (3, 3) face into a
// tuple.
template <typename Face>
auto ExtractVerts(const Face& face, const int vertex_index) {
  const auto& vert = face[vertex_index];
  return std::make_tuple(vert[0], vert[1], vert[2]);
}
// Compute the xy/z bounding box of every face.
// Returns a (total_faces, 6) float tensor laid out per face as
// [x_min, y_min, x_max, y_max, z_min, z_max].
auto ComputeFaceBoundingBoxes(const torch::Tensor& face_verts) {
  const int total_F = face_verts.size(0);
  const auto float_opts = face_verts.options().dtype(torch::kFloat32);
  auto verts_a = face_verts.accessor<float, 3>();
  // Initialized to -2.0, which lies outside the [-1, 1] NDC range.
  torch::Tensor face_bboxes = torch::full({total_F, 6}, -2.0, float_opts);
  for (int f = 0; f < total_F; ++f) {
    const auto& face = verts_a[f];
    float x0, x1, x2, y0, y1, y2, z0, z1, z2;
    std::tie(x0, y0, z0) = ExtractVerts(face, 0);
    std::tie(x1, y1, z1) = ExtractVerts(face, 1);
    std::tie(x2, y2, z2) = ExtractVerts(face, 2);
    // Min/max over the three vertices on each axis.
    face_bboxes[f][0] = std::min(x0, std::min(x1, x2)); // x_min
    face_bboxes[f][1] = std::min(y0, std::min(y1, y2)); // y_min
    face_bboxes[f][2] = std::max(x0, std::max(x1, x2)); // x_max
    face_bboxes[f][3] = std::max(y0, std::max(y1, y2)); // y_max
    face_bboxes[f][4] = std::min(z0, std::min(z1, z2)); // z_min
    face_bboxes[f][5] = std::max(z0, std::max(z1, z2)); // z_max
  }
  return face_bboxes;
}
// Check whether the point (px, py) lies outside the face bounding box
// face_bbox after the box has been expanded by blur_radius on every side.
// Returns true when the point is outside the expanded box.
template <typename Face>
bool CheckPointOutsideBoundingBox(
    const Face& face_bbox,
    float blur_radius,
    float px,
    float py) {
  // bbox layout: [x_min, y_min, x_max, y_max, ...]; pad by the blur radius.
  const float x_min = face_bbox[0] - blur_radius;
  const float y_min = face_bbox[1] - blur_radius;
  const float x_max = face_bbox[2] + blur_radius;
  const float y_max = face_bbox[3] + blur_radius;
  // Outside if the point misses the padded box on either axis.
  const bool outside_x = (px > x_max) || (px < x_min);
  const bool outside_y = (py > y_max) || (py < y_min);
  return outside_x || outside_y;
}
// Compute the signed xy-plane area (via the edge function) of every face.
// Returns a (total_faces,) float tensor; entries start at -1 and are
// overwritten with each face's edge-function value.
auto ComputeFaceAreas(const torch::Tensor& face_verts) {
  const int total_F = face_verts.size(0);
  const auto float_opts = face_verts.options().dtype(torch::kFloat32);
  auto verts_a = face_verts.accessor<float, 3>();
  torch::Tensor face_areas = torch::full({total_F}, -1, float_opts);
  for (int f = 0; f < total_F; ++f) {
    const auto& face = verts_a[f];
    float x0, x1, x2, y0, y1, y2, z0, z1, z2;
    std::tie(x0, y0, z0) = ExtractVerts(face, 0);
    std::tie(x1, y1, z1) = ExtractVerts(face, 1);
    std::tie(x2, y2, z2) = ExtractVerts(face, 2);
    // Only the xy projection matters for the area; z is unused here.
    face_areas[f] = EdgeFunctionForward(
        vec2<float>(x0, y0), vec2<float>(x1, y1), vec2<float>(x2, y2));
  }
  return face_areas;
}
// CPU forward pass for naive (brute-force per pixel) mesh rasterization:
// for every pixel of every image, loop over all faces of the corresponding
// mesh and keep the K = faces_per_pixel closest candidate faces.
//
// Args:
//   face_verts: (F, 3, 3) float tensor of packed per-face vertex positions
//       in NDC.
//   mesh_to_face_first_idx: (N,) index of the first face of each mesh.
//   num_faces_per_mesh: (N,) number of faces in each mesh.
//   image_size: height and width of the output image (square images only).
//   blur_radius: blur amount; std::sqrt(blur_radius) pads the bbox test
//       while the raw value is compared against the point-face distance.
//       NOTE(review): this implies PointTriangleDistanceForward returns a
//       squared distance -- confirm against geometry_utils.h.
//   faces_per_pixel: K, the number of closest faces recorded per pixel.
//   perspective_correct: if true, use perspective-corrected barycentrics.
//
// Returns:
//   (face_idxs, zbuf, barycentric_coords, pix_dists) of shapes
//   (N, H, W, K), (N, H, W, K), (N, H, W, K, 3), (N, H, W, K);
//   entries for pixels hit by fewer than K faces remain -1.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesNaiveCpu(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int faces_per_pixel,
    bool perspective_correct) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  // NOTE(review): "save size" in the message below is a typo for "same
  // size"; left unchanged here because it is a runtime string.
  if (num_faces_per_mesh.size(0) != mesh_to_face_first_idx.size(0)) {
    AT_ERROR(
        "num_faces_per_mesh must have save size first dimension as mesh_to_face_first_idx");
  }
  const int32_t N = mesh_to_face_first_idx.size(0); // batch_size.
  const int H = image_size;
  const int W = image_size;
  const int K = faces_per_pixel;
  auto long_opts = face_verts.options().dtype(torch::kInt64);
  auto float_opts = face_verts.options().dtype(torch::kFloat32);
  // Initialize output tensors; -1 marks unfilled (padded) entries.
  torch::Tensor face_idxs = torch::full({N, H, W, K}, -1, long_opts);
  torch::Tensor zbuf = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, H, W, K}, -1, float_opts);
  torch::Tensor barycentric_coords =
      torch::full({N, H, W, K, 3}, -1, float_opts);
  auto face_verts_a = face_verts.accessor<float, 3>();
  auto face_idxs_a = face_idxs.accessor<int64_t, 4>();
  auto zbuf_a = zbuf.accessor<float, 4>();
  auto pix_dists_a = pix_dists.accessor<float, 4>();
  auto barycentric_coords_a = barycentric_coords.accessor<float, 5>();
  // Precompute per-face bounding boxes and signed areas once, outside the
  // pixel loops.
  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();
  auto face_areas = ComputeFaceAreas(face_verts);
  auto face_areas_a = face_areas.accessor<float, 1>();
  for (int n = 0; n < N; ++n) {
    // Loop through each mesh in the batch.
    // Get the start index of the faces in faces_packed and the num faces
    // in the mesh to avoid having to loop through all the faces.
    const int face_start_idx = mesh_to_face_first_idx[n].item().to<int32_t>();
    const int face_stop_idx =
        (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
    // Iterate through the horizontal lines of the image from top to bottom.
    for (int yi = 0; yi < H; ++yi) {
      // Y coordinate (NDC) of the center of this pixel row.
      const float yf = PixToNdc(yi, H);
      // Iterate through pixels on this horizontal line, left to right.
      for (int xi = 0; xi < W; ++xi) {
        // X coordinate (NDC) of the center of this pixel.
        const float xf = PixToNdc(xi, W);
        // Max-heap keyed primarily on z holding tuples of
        // (z, idx, signed_dist, bary.x, bary.y, bary.z); popping whenever
        // the size exceeds K discards the farthest face, so the heap keeps
        // the K closest candidates.
        std::priority_queue<std::tuple<float, int, float, float, float, float>>
            q;
        // Loop through the faces in the mesh.
        for (int f = face_start_idx; f < face_stop_idx; ++f) {
          // Get coordinates of three face vertices.
          const auto& face = face_verts_a[f];
          float x0, x1, x2, y0, y1, y2, z0, z1, z2;
          std::tie(x0, y0, z0) = ExtractVerts(face, 0);
          std::tie(x1, y1, z1) = ExtractVerts(face, 1);
          std::tie(x2, y2, z2) = ExtractVerts(face, 2);
          const vec2<float> v0(x0, y0);
          const vec2<float> v1(x1, y1);
          const vec2<float> v2(x2, y2);
          // Skip faces with (near) zero area.
          const float face_area = face_areas_a[f];
          if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) {
            continue;
          }
          // Skip if point is outside the blur-expanded face bounding box.
          // NOTE(review): std::sqrt(blur_radius) is loop-invariant and is
          // recomputed here for every face/pixel pair; it could be hoisted.
          const auto face_bbox = face_bboxes_a[f];
          const bool outside_bbox = CheckPointOutsideBoundingBox(
              face_bbox, std::sqrt(blur_radius), xf, yf);
          if (outside_bbox) {
            continue;
          }
          // Compute barycentric coordinates and use this to get the
          // depth of the point on the triangle.
          const vec2<float> pxy(xf, yf);
          const vec3<float> bary0 =
              BarycentricCoordinatesForward(pxy, v0, v1, v2);
          const vec3<float> bary = !perspective_correct
              ? bary0
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
          // Use barycentric coordinates to get the depth of the current pixel
          const float pz = (bary.x * z0 + bary.y * z1 + bary.z * z2);
          if (pz < 0) {
            continue; // Point is behind the image plane so ignore.
          }
          // Compute absolute distance of the point to the triangle.
          // If the point is inside the triangle then the distance
          // is negative.
          const float dist = PointTriangleDistanceForward(pxy, v0, v1, v2);
          // Use the bary coordinates to determine if the point is
          // inside the face.
          const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
          const float signed_dist = inside ? -dist : dist;
          // Check if pixel is outside blur region
          if (!inside && dist >= blur_radius) {
            continue;
          }
          // Record this face as a candidate for the pixel (it is inside the
          // face or within the blur region around it).
          q.emplace(pz, f, signed_dist, bary.x, bary.y, bary.z);
          if (static_cast<int>(q.size()) > K) {
            q.pop();
          }
        }
        // Drain the heap farthest-first; writing each element at index
        // q.size() after popping fills slot K-1 down to 0, so the outputs
        // end up sorted by increasing z.
        while (!q.empty()) {
          auto t = q.top();
          q.pop();
          const int i = q.size();
          zbuf_a[n][yi][xi][i] = std::get<0>(t);
          face_idxs_a[n][yi][xi][i] = std::get<1>(t);
          pix_dists_a[n][yi][xi][i] = std::get<2>(t);
          barycentric_coords_a[n][yi][xi][i][0] = std::get<3>(t);
          barycentric_coords_a[n][yi][xi][i][1] = std::get<4>(t);
          barycentric_coords_a[n][yi][xi][i][2] = std::get<5>(t);
        }
      }
    }
  }
  return std::make_tuple(face_idxs, zbuf, barycentric_coords, pix_dists);
}
// CPU backward pass for mesh rasterization. Accumulates gradients w.r.t.
// the packed face vertices from the upstream gradients of zbuf, barycentric
// coordinates and pixel-face distances produced by the forward pass.
//
// Args:
//   face_verts: (F, 3, 3) packed face vertices used in the forward pass.
//   pix_to_face: (N, H, W, K) face index per pixel slot, -1 for padding.
//   grad_zbuf: (N, H, W, K) upstream gradient for the depth buffer.
//   grad_bary: (N, H, W, K, 3) upstream gradient for barycentrics.
//   grad_dists: (N, H, W, K) upstream gradient for pixel-face distances.
//   perspective_correct: must match the value used in the forward pass.
//
// Returns:
//   grad_face_verts: (F, 3, 3) gradient w.r.t. face_verts.
torch::Tensor RasterizeMeshesBackwardCpu(
    const torch::Tensor& face_verts, // (F, 3, 3)
    const torch::Tensor& pix_to_face, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_bary, // (N, H, W, K, 3)
    const torch::Tensor& grad_dists, // (N, H, W, K)
    bool perspective_correct) {
  const int F = face_verts.size(0);
  const int N = pix_to_face.size(0);
  const int H = pix_to_face.size(1);
  const int W = pix_to_face.size(2);
  const int K = pix_to_face.size(3);
  torch::Tensor grad_face_verts = torch::zeros({F, 3, 3}, face_verts.options());
  auto face_verts_a = face_verts.accessor<float, 3>();
  auto pix_to_face_a = pix_to_face.accessor<int64_t, 4>();
  auto grad_dists_a = grad_dists.accessor<float, 4>();
  auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
  auto grad_bary_a = grad_bary.accessor<float, 5>();
  for (int n = 0; n < N; ++n) {
    // Iterate through the horizontal lines of the image from top to bottom.
    for (int y = 0; y < H; ++y) {
      // Y coordinate (NDC) of the center of this pixel row.
      const float yf = PixToNdc(y, H);
      // Iterate through pixels on this horizontal line, left to right.
      for (int x = 0; x < W; ++x) {
        // X coordinate (NDC) of the center of this pixel.
        const float xf = PixToNdc(x, W);
        const vec2<float> pxy(xf, yf);
        // Iterate through the faces that hit this pixel.
        for (int k = 0; k < K; ++k) {
          // Get face index from forward pass output.
          // (Narrowing int64 -> int; assumes F fits in 32 bits.)
          const int f = pix_to_face_a[n][y][x][k];
          if (f < 0) {
            continue; // padded face.
          }
          // Get coordinates of the three face vertices.
          const auto face_verts_f = face_verts_a[f];
          const float x0 = face_verts_f[0][0];
          const float y0 = face_verts_f[0][1];
          const float z0 = face_verts_f[0][2];
          const float x1 = face_verts_f[1][0];
          const float y1 = face_verts_f[1][1];
          const float z1 = face_verts_f[1][2];
          const float x2 = face_verts_f[2][0];
          const float y2 = face_verts_f[2][1];
          const float z2 = face_verts_f[2][2];
          const vec2<float> v0xy(x0, y0);
          const vec2<float> v1xy(x1, y1);
          const vec2<float> v2xy(x2, y2);
          // Get upstream gradients for the face.
          const float grad_dist_upstream = grad_dists_a[n][y][x][k];
          const float grad_zbuf_upstream = grad_zbuf_a[n][y][x][k];
          const auto grad_bary_upstream_w012 = grad_bary_a[n][y][x][k];
          const float grad_bary_upstream_w0 = grad_bary_upstream_w012[0];
          const float grad_bary_upstream_w1 = grad_bary_upstream_w012[1];
          const float grad_bary_upstream_w2 = grad_bary_upstream_w012[2];
          const vec3<float> grad_bary_upstream(
              grad_bary_upstream_w0,
              grad_bary_upstream_w1,
              grad_bary_upstream_w2);
          // Recompute the forward-pass barycentrics for this pixel/face.
          const vec3<float> bary0 =
              BarycentricCoordinatesForward(pxy, v0xy, v1xy, v2xy);
          const vec3<float> bary = !perspective_correct
              ? bary0
              : BarycentricPerspectiveCorrectionForward(bary0, z0, z1, z2);
          // Distances inside the face are negative so get the
          // correct sign to apply to the upstream gradient.
          const bool inside = bary.x > 0.0f && bary.y > 0.0f && bary.z > 0.0f;
          const float sign = inside ? -1.0f : 1.0f;
          // TODO(T52813608) Add support for non-square images.
          const auto grad_dist_f = PointTriangleDistanceBackward(
              pxy, v0xy, v1xy, v2xy, sign * grad_dist_upstream);
          const auto ddist_d_v0 = std::get<1>(grad_dist_f);
          const auto ddist_d_v1 = std::get<2>(grad_dist_f);
          const auto ddist_d_v2 = std::get<3>(grad_dist_f);
          // Upstream gradient for barycentric coords from zbuf calculation:
          // zbuf = bary_w0 * z0 + bary_w1 * z1 + bary_w2 * z2
          // Therefore
          // d_zbuf/d_bary_w0 = z0
          // d_zbuf/d_bary_w1 = z1
          // d_zbuf/d_bary_w2 = z2
          const vec3<float> d_zbuf_d_bary(z0, z1, z2);
          // Total upstream barycentric gradients are the sum of
          // external upstream gradients and contribution from zbuf.
          vec3<float> grad_bary_f_sum =
              (grad_bary_upstream + grad_zbuf_upstream * d_zbuf_d_bary);
          vec3<float> grad_bary0 = grad_bary_f_sum;
          if (perspective_correct) {
            // Route gradients through the perspective correction; this also
            // yields gradients w.r.t. the three vertex z values.
            auto perspective_grads = BarycentricPerspectiveCorrectionBackward(
                bary0, z0, z1, z2, grad_bary_f_sum);
            grad_bary0 = std::get<0>(perspective_grads);
            grad_face_verts[f][0][2] += std::get<1>(perspective_grads);
            grad_face_verts[f][1][2] += std::get<2>(perspective_grads);
            grad_face_verts[f][2][2] += std::get<3>(perspective_grads);
          }
          auto grad_bary_f =
              BarycentricCoordsBackward(pxy, v0xy, v1xy, v2xy, grad_bary0);
          const vec2<float> dbary_d_v0 = std::get<1>(grad_bary_f);
          const vec2<float> dbary_d_v1 = std::get<2>(grad_bary_f);
          const vec2<float> dbary_d_v2 = std::get<3>(grad_bary_f);
          // Update output gradient buffer (accumulate across all pixels
          // that reference face f).
          grad_face_verts[f][0][0] += dbary_d_v0.x + ddist_d_v0.x;
          grad_face_verts[f][0][1] += dbary_d_v0.y + ddist_d_v0.y;
          grad_face_verts[f][0][2] += grad_zbuf_upstream * bary.x;
          grad_face_verts[f][1][0] += dbary_d_v1.x + ddist_d_v1.x;
          grad_face_verts[f][1][1] += dbary_d_v1.y + ddist_d_v1.y;
          grad_face_verts[f][1][2] += grad_zbuf_upstream * bary.y;
          grad_face_verts[f][2][0] += dbary_d_v2.x + ddist_d_v2.x;
          grad_face_verts[f][2][1] += dbary_d_v2.y + ddist_d_v2.y;
          grad_face_verts[f][2][2] += grad_zbuf_upstream * bary.z;
        }
      }
    }
  }
  return grad_face_verts;
}
// CPU implementation of coarse rasterization: bin the faces of each mesh
// into a coarse grid of bins so that fine rasterization only needs to
// consider the faces overlapping each bin.
//
// Args:
//   face_verts: (F, 3, 3) float tensor of packed per-face vertices in NDC.
//   mesh_to_face_first_idx: (N,) index of the first face of each mesh.
//   num_faces_per_mesh: (N,) number of faces in each mesh.
//   image_size: image height/width in pixels (square images only).
//   blur_radius: blur amount; its square root pads each face bbox.
//   bin_size: bin side length in pixels.
//   max_faces_per_bin: capacity M of each bin; exceeding it is an error.
//
// Returns:
//   bin_faces: (N, BH, BW, M) int32 tensor of face indices overlapping each
//   bin, padded with -1.
torch::Tensor RasterizeMeshesCoarseCpu(
    const torch::Tensor& face_verts,
    const torch::Tensor& mesh_to_face_first_idx,
    const torch::Tensor& num_faces_per_mesh,
    int image_size,
    float blur_radius,
    int bin_size,
    int max_faces_per_bin) {
  if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
      face_verts.size(2) != 3) {
    AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
  }
  if (num_faces_per_mesh.ndimension() != 1) {
    AT_ERROR("num_faces_per_mesh can only have one dimension");
  }
  const int N = num_faces_per_mesh.size(0); // batch size.
  const int M = max_faces_per_bin;
  // Assume square images. TODO(T52813608) Support non square images.
  const float height = image_size;
  const float width = image_size;
  const int BH = 1 + (height - 1) / bin_size; // Integer division round up.
  const int BW = 1 + (width - 1) / bin_size; // Integer division round up.
  auto opts = face_verts.options().dtype(torch::kInt32);
  torch::Tensor bin_faces = torch::full({N, BH, BW, M}, -1, opts);
  auto bin_faces_a = bin_faces.accessor<int32_t, 4>();
  // Precompute all face bounding boxes.
  auto face_bboxes = ComputeFaceBoundingBoxes(face_verts);
  auto face_bboxes_a = face_bboxes.accessor<float, 2>();
  const float pixel_width = 2.0f / image_size;
  const float bin_width = pixel_width * bin_size;
  // Hoisted out of the loops below: the original recomputed
  // std::sqrt(blur_radius) four times per face for every bin of every mesh,
  // even though it is loop-invariant.
  const float sqrt_radius = std::sqrt(blur_radius);
  // Iterate through the meshes in the batch.
  for (int n = 0; n < N; ++n) {
    const int face_start_idx = mesh_to_face_first_idx[n].item().to<int32_t>();
    const int face_stop_idx =
        (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
    float bin_y_min = -1.0f;
    float bin_y_max = bin_y_min + bin_width;
    // Iterate through the horizontal bins from top to bottom.
    for (int by = 0; by < BH; ++by) {
      float bin_x_min = -1.0f;
      float bin_x_max = bin_x_min + bin_width;
      // Iterate through bins on this horizontal line, left to right.
      for (int bx = 0; bx < BW; ++bx) {
        int32_t faces_hit = 0;
        for (int32_t f = face_start_idx; f < face_stop_idx; ++f) {
          // Get bounding box and expand by blur radius.
          const float face_x_min = face_bboxes_a[f][0] - sqrt_radius;
          const float face_y_min = face_bboxes_a[f][1] - sqrt_radius;
          const float face_x_max = face_bboxes_a[f][2] + sqrt_radius;
          const float face_y_max = face_bboxes_a[f][3] + sqrt_radius;
          const float face_z_max = face_bboxes_a[f][5];
          if (face_z_max < 0) {
            continue; // Face is behind the camera.
          }
          // Use a half-open interval so that faces exactly on the
          // boundary between bins will fall into exactly one bin.
          bool x_overlap =
              (face_x_min <= bin_x_max) && (bin_x_min < face_x_max);
          bool y_overlap =
              (face_y_min <= bin_y_max) && (bin_y_min < face_y_max);
          if (x_overlap && y_overlap) {
            // Got too many faces for this bin, so throw an error.
            if (faces_hit >= max_faces_per_bin) {
              AT_ERROR("Got too many faces per bin");
            }
            // The current face overlaps the current bin, so record it.
            bin_faces_a[n][by][bx][faces_hit] = f;
            faces_hit++;
          }
        }
        // Shift the bin to the right for the next loop iteration.
        bin_x_min = bin_x_max;
        bin_x_max = bin_x_min + bin_width;
      }
      // Shift the bin down for the next loop iteration.
      bin_y_min = bin_y_max;
      bin_y_max = bin_y_min + bin_width;
    }
  }
  return bin_faces;
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <type_traits>
// A minimal fixed-size 2-vector (restricted to float or double) with basic
// arithmetic, used for 2D screen-space coordinates.
// TODO: switch to Eigen if more functionality is needed.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec2 {
  typedef T scalar_t;
  T x, y;
  vec2(T px, T py) : x(px), y(py) {}
};
template <typename T>
inline vec2<T> operator+(const vec2<T>& a, const vec2<T>& b) {
return vec2<T>(a.x + b.x, a.y + b.y);
}
template <typename T>
inline vec2<T> operator-(const vec2<T>& a, const vec2<T>& b) {
return vec2<T>(a.x - b.x, a.y - b.y);
}
template <typename T>
inline vec2<T> operator*(const T a, const vec2<T>& b) {
return vec2<T>(a * b.x, a * b.y);
}
template <typename T>
inline vec2<T> operator/(const vec2<T>& a, const T b) {
if (b == 0.0) {
AT_ERROR(
"denominator in vec2 division is 0"); // prevent divide by 0 errors.
}
return vec2<T>(a.x / b, a.y / b);
}
template <typename T>
inline T dot(const vec2<T>& a, const vec2<T>& b) {
return a.x * b.x + a.y * b.y;
}
template <typename T>
inline T norm(const vec2<T>& a, const vec2<T>& b) {
const vec2<T> ba = b - a;
return sqrt(dot(ba, ba));
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const vec2<T>& v) {
os << "vec2(" << v.x << ", " << v.y << ")";
return os;
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// A minimal fixed-size 3-vector (restricted to float or double) with basic
// arithmetic, used for 3D coordinates and barycentric weights.
// TODO: switch to Eigen if more functionality is needed.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec3 {
  typedef T scalar_t;
  T x, y, z;
  vec3(T px, T py, T pz) : x(px), y(py), z(pz) {}
};
template <typename T>
inline vec3<T> operator+(const vec3<T>& a, const vec3<T>& b) {
return vec3<T>(a.x + b.x, a.y + b.y, a.z + b.z);
}
template <typename T>
inline vec3<T> operator-(const vec3<T>& a, const vec3<T>& b) {
return vec3<T>(a.x - b.x, a.y - b.y, a.z - b.z);
}
template <typename T>
inline vec3<T> operator/(const vec3<T>& a, const T b) {
if (b == 0.0) {
AT_ERROR(
"denominator in vec3 division is 0"); // prevent divide by 0 errors.
}
return vec3<T>(a.x / b, a.y / b, a.z / b);
}
template <typename T>
inline vec3<T> operator*(const T a, const vec3<T>& b) {
return vec3<T>(a * b.x, a * b.y, a * b.z);
}
template <typename T>
inline vec3<T> operator*(const vec3<T>& a, const vec3<T>& b) {
return vec3<T>(a.x * b.x, a.y * b.y, a.z * b.z);
}
template <typename T>
inline T dot(const vec3<T>& a, const vec3<T>& b) {
return a.x * b.x + a.y * b.y + a.z * b.z;
}
template <typename T>
inline vec3<T> cross(const vec3<T>& a, const vec3<T>& b) {
return vec3<T>(
a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const vec3<T>& v) {
os << "vec3(" << v.x << ", " << v.y << ", " << v.z << ")";
return os;
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#define BINMASK_H
// A BitMask represents a bool array of shape (H, W, N). We pack values into
// the bits of unsigned ints; a single unsigned int has B = 32 bits, so to hold
// all values we use H * W * (N / B) = H * W * D values. We want to store
// BitMasks in shared memory, so we assume that the memory has already been
// allocated for it elsewhere.
class BitMask {
 public:
  // data must point to at least H * W * D unsigned ints; all bits are
  // cleared cooperatively by the whole thread block during construction.
  __device__ BitMask(unsigned int* data, int H, int W, int N)
      : data(data), H(H), W(W), B(8 * sizeof(unsigned int)), D(N / B) {
    // TODO: check if the data is null.
    // NOTE(review): D = N / B truncates, so trailing bits are dropped when
    // N is not a multiple of 32; the next line only overwrites the local
    // parameter N and has no further effect in this constructor. Looks like
    // a leftover -- confirm callers always pass N divisible by 32.
    N = ceilf(N % 32); // take ceil in case N % 32 != 0
    block_clear(); // clear the data
  }
  // Use all threads in the current block to clear all bits of this BitMask
  __device__ void block_clear() {
    for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
      data[i] = 0;
    }
    __syncthreads();
  }
  // Index of the unsigned int that holds bit (y, x, d).
  __device__ int _get_elem_idx(int y, int x, int d) {
    return y * W * D + x * D + d / B;
  }
  // Position of bit d within its unsigned int.
  __device__ int _get_bit_idx(int d) {
    return d % B;
  }
  // Turn on a single bit (y, x, d)
  __device__ void set(int y, int x, int d) {
    int elem_idx = _get_elem_idx(y, x, d);
    int bit_idx = _get_bit_idx(d);
    const unsigned int mask = 1U << bit_idx;
    atomicOr(data + elem_idx, mask);
  }
  // Turn off a single bit (y, x, d)
  __device__ void unset(int y, int x, int d) {
    int elem_idx = _get_elem_idx(y, x, d);
    int bit_idx = _get_bit_idx(d);
    const unsigned int mask = ~(1U << bit_idx);
    atomicAnd(data + elem_idx, mask);
  }
  // Check whether the bit (y, x, d) is on or off
  __device__ bool get(int y, int x, int d) {
    int elem_idx = _get_elem_idx(y, x, d);
    int bit_idx = _get_bit_idx(d);
    return (data[elem_idx] >> bit_idx) & 1U;
  }
  // Compute the number of bits set in the row (y, x, :)
  __device__ int count(int y, int x) {
    int total = 0;
    for (int i = 0; i < D; ++i) {
      int elem_idx = y * W * D + x * D + i;
      unsigned int elem = data[elem_idx];
      total += __popc(elem); // population count of one 32-bit word
    }
    return total;
  }
 private:
  unsigned int* data;
  int H, W, B, D;
};
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. We divide the NDC range into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// Keep in sync with the identical host-side PixToNdc.
__device__ inline float PixToNdc(int i, int S) {
  // NDC x-offset + (i * pixel_width + half_pixel_width)
  return -1 + (2 * i + 1.0f) / S;
}
// The maximum number of points per pixel that we can return. Since we use
// thread-local arrays to hold and sort points, the maximum size of the array
// needs to be known at compile time. There might be some fancy template magic
// we could use to make this more dynamic, but for now just fix a constant.
// NOTE(review): nearby comments still refer to a limit of 8, but the value
// here is 150; larger values increase per-thread local memory usage and the
// cost of the per-pixel sort.
const int32_t kMaxPointsPerPixel = 150;
// Sort a tiny thread-local array in place via bubble sort. This is
// deliberate: for small per-thread arrays warp divergence matters more
// than asymptotic complexity.
template <typename T>
__device__ inline void BubbleSort(T* arr, int n) {
  for (int pass = 0; pass < n - 1; ++pass) {
    // After each pass the largest remaining element has bubbled to the end.
    for (int k = 0; k + 1 < n - pass; ++k) {
      if (arr[k + 1] < arr[k]) {
        const T tmp = arr[k];
        arr[k] = arr[k + 1];
        arr[k + 1] = tmp;
      }
    }
  }
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include <sstream>
#include <tuple>
#include "rasterize_points/bitmask.cuh"
#include "rasterize_points/rasterization_utils.cuh"
namespace {
// A little structure for holding details about a pixel.
struct Pix {
  float z; // Depth of the reference point.
  int32_t idx; // Index of the reference point.
  float dist2; // Euclidean distance square to the reference point.
};
// Order Pix by depth so that sorting an array of Pix puts the closest
// point first.
__device__ inline bool operator<(const Pix& a, const Pix& b) {
  return a.z < b.z;
}
// This function checks if a pixel given by xy location pxy lies within the
// point with index p and batch index n. One of the inputs is a list (q)
// which contains Pixel structs with the indices of the points which intersect
// with this pixel sorted by closest z distance. If the pixel pxy lies in the
// point, the list (q) is updated and re-orderered in place. In addition
// the auxillary variables q_size, q_max_z and q_max_idx are also modified.
// This code is shared between RasterizePointsNaiveCudaKernel and
// RasterizePointsFineCudaKernel.
template <typename PointQ>
__device__ void CheckPixelInsidePoint(
    const float* points, // (N, P, 3)
    const int p, // index of the candidate point within the batch element
    int& q_size, // current number of entries in q (in/out)
    float& q_max_z, // largest z currently stored in q (in/out)
    int& q_max_idx, // index in q of that largest-z entry (in/out)
    PointQ& q, // fixed-capacity array of Pix with capacity >= K
    const float radius2, // squared point radius (caller passes radius^2)
    const float xf, // pixel x in NDC
    const float yf, // pixel y in NDC
    const int n, // batch index
    const int P, // number of points per batch element
    const int K) { // max number of points kept per pixel
  const float px = points[n * P * 3 + p * 3 + 0];
  const float py = points[n * P * 3 + p * 3 + 1];
  const float pz = points[n * P * 3 + p * 3 + 2];
  if (pz < 0)
    return; // Don't render points behind the camera
  const float dx = xf - px;
  const float dy = yf - py;
  const float dist2 = dx * dx + dy * dy;
  // The pixel is covered by the point if it lies within radius in xy.
  if (dist2 < radius2) {
    if (q_size < K) {
      // Just insert it
      q[q_size] = {pz, p, dist2};
      if (pz > q_max_z) {
        q_max_z = pz;
        q_max_idx = q_size;
      }
      q_size++;
    } else if (pz < q_max_z) {
      // Overwrite the old max, and find the new max
      q[q_max_idx] = {pz, p, dist2};
      q_max_z = pz;
      for (int i = 0; i < K; i++) {
        if (q[i].z > q_max_z) {
          q_max_z = q[i].z;
          q_max_idx = i;
        }
      }
    }
  }
}
} // namespace
// ****************************************************************************
// * NAIVE RASTERIZATION *
// ****************************************************************************
// Naive point rasterization kernel: one thread per output pixel; each
// thread scans every point of its batch element and keeps the K closest
// points covering the pixel. Launched with a 1D grid of 1D blocks; the
// grid-stride loop below makes any launch configuration valid.
__global__ void RasterizePointsNaiveCudaKernel(
    const float* points, // (N, P, 3)
    const float radius,
    const int N,
    const int P,
    const int S,
    const int K,
    int32_t* point_idxs, // (N, S, S, K)
    float* zbuf, // (N, S, S, K)
    float* pix_dists) { // (N, S, S, K)
  // Simple version: One thread per output pixel
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockDim.x * blockIdx.x + threadIdx.x;
  // Compare squared distances to avoid a sqrt per point.
  const float radius2 = radius * radius;
  for (int i = tid; i < N * S * S; i += num_threads) {
    // Convert linear index to 3D index
    const int n = i / (S * S); // Batch index
    const int pix_idx = i % (S * S);
    const int yi = pix_idx / S;
    const int xi = pix_idx % S;
    const float xf = PixToNdc(xi, S);
    const float yf = PixToNdc(yi, S);
    // For keeping track of the K closest points we want a data structure
    // that (1) gives O(1) access to the closest point for easy comparisons,
    // and (2) allows insertion of new elements. In the CPU version we use
    // std::priority_queue; then (2) is O(log K). We can't use STL
    // containers in CUDA; we could roll our own max heap in an array, but
    // that would likely have a lot of warp divergence so we do something
    // simpler instead: keep the elements in an unsorted array, but keep
    // track of the max value and the index of the max value. Then (1) is
    // still O(1) time, while (2) is O(K) with a clean loop. Since K is
    // bounded by kMaxPointsPerPixel this should be fast enough for our
    // purposes.
    // TODO(jcjohns) Abstract this out into a standalone data structure
    Pix q[kMaxPointsPerPixel];
    int q_size = 0;
    // Sentinel below any stored depth (points with z < 0 are skipped).
    float q_max_z = -1000;
    int q_max_idx = -1;
    for (int p = 0; p < P; ++p) {
      CheckPixelInsidePoint(
          points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
    }
    // Sort the surviving points by depth (closest first) before writing.
    BubbleSort(q, q_size);
    int idx = n * S * S * K + yi * S * K + xi * K;
    for (int k = 0; k < q_size; ++k) {
      point_idxs[idx + k] = q[k].idx;
      zbuf[idx + k] = q[k].z;
      pix_dists[idx + k] = q[k].dist2;
    }
  }
}
// Host entry point for naive point rasterization on the GPU.
// Validates points_per_pixel, allocates outputs initialized to -1 (so
// missing hits are marked), and launches the naive kernel with a fixed
// grid of 1024 blocks x 64 threads.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsNaiveCuda(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel) {
  if (points_per_pixel > kMaxPointsPerPixel) {
    std::stringstream ss;
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }
  const int N = points.size(0);
  const int P = points.size(1);
  const int S = image_size;
  const int K = points_per_pixel;
  const auto int_opts = points.options().dtype(torch::kInt32);
  const auto float_opts = points.options().dtype(torch::kFloat32);
  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
  const size_t kBlocks = 1024;
  const size_t kThreadsPerBlock = 64;
  RasterizePointsNaiveCudaKernel<<<kBlocks, kThreadsPerBlock>>>(
      points.contiguous().data<float>(),
      radius,
      N,
      P,
      S,
      K,
      point_idxs.contiguous().data<int32_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>());
  return std::make_tuple(point_idxs, zbuf, pix_dists);
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
// Coarse rasterization kernel: bin points into square tiles of the image.
//
// Each block processes chunks of up to chunk_size points and marks, in a
// shared-memory bitmask of shape (num_bins, num_bins, chunk_size), which
// points overlap which bins. After a block-wide barrier, one thread per bin
// counts the hits, reserves a contiguous range of output slots with a
// single atomicAdd on points_per_bin, and writes the point indices into
// bin_points.
//
// Args:
//   points: (N, P, 3) packed xyz coordinates in NDC.
//   radius: point radius in NDC units.
//   N, P: batch size and number of points per cloud.
//   S: image side length in pixels; bin_size: bin side length in pixels.
//   chunk_size: points processed per block iteration (the shared-memory
//     bitmask is sized for this many points).
//   max_points_per_bin: capacity M of each bin in bin_points.
//   points_per_bin: (N, num_bins, num_bins) hit counters; must be
//     zero-initialized by the caller. May exceed M if a bin overflows.
//   bin_points: (N, num_bins, num_bins, M) point indices, padded with -1
//     (caller initializes); points past a bin's capacity are dropped.
__global__ void RasterizePointsCoarseCudaKernel(
    const float* points,
    const float radius,
    const int N,
    const int P,
    const int S,
    const int bin_size,
    const int chunk_size,
    const int max_points_per_bin,
    int* points_per_bin,
    int* bin_points) {
  extern __shared__ char sbuf[];
  const int M = max_points_per_bin;
  const int num_bins = 1 + (S - 1) / bin_size; // Integer divide round up
  const float half_pix = 1.0f / S; // Size of half a pixel in NDC units
  // Boolean array of shape (num_bins, num_bins, chunk_size) in shared
  // memory tracking which points of the chunk fall into which bins.
  BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
  const int chunks_per_batch = 1 + (P - 1) / chunk_size;
  const int num_chunks = N * chunks_per_batch;
  for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
    const int batch_idx = chunk / chunks_per_batch;
    const int chunk_idx = chunk % chunks_per_batch;
    const int point_start_idx = chunk_idx * chunk_size;
    binmask.block_clear();
    // Phase 1: each thread handles a different point within the chunk.
    for (int p = threadIdx.x; p < chunk_size; p += blockDim.x) {
      const int p_idx = point_start_idx + p;
      if (p_idx >= P)
        break;
      const float px = points[batch_idx * P * 3 + p_idx * 3 + 0];
      const float py = points[batch_idx * P * 3 + p_idx * 3 + 1];
      const float pz = points[batch_idx * P * 3 + p_idx * 3 + 2];
      if (pz < 0)
        continue; // Don't render points behind the camera
      // Axis-aligned bounding box of the point in NDC.
      const float px0 = px - radius;
      const float px1 = px + radius;
      const float py0 = py - radius;
      const float py1 = py + radius;
      // Brute-force search over all bins; TODO something smarter?
      // For example we could compute the exact bin where the point falls,
      // then check neighboring bins. This way we wouldn't have to check
      // all bins (however then we might have more warp divergence?)
      for (int by = 0; by < num_bins; ++by) {
        // Get y extent for the bin. PixToNdc gives us the location of
        // the center of each pixel, so we need to add/subtract a half
        // pixel to get the true extent of the bin.
        const float by0 = PixToNdc(by * bin_size, S) - half_pix;
        const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
        const bool y_overlap = (py0 <= by1) && (by0 <= py1);
        if (!y_overlap) {
          continue;
        }
        for (int bx = 0; bx < num_bins; ++bx) {
          // Get x extent for the bin; again we need to adjust the
          // output of PixToNdc by half a pixel.
          const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
          const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
          const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
          if (x_overlap) {
            binmask.set(by, bx, p);
          }
        }
      }
    }
    __syncthreads();
    // Phase 2: every point in the current chunk has been processed. Count
    // the points in each bin and write the indices out to global memory,
    // with one thread per bin.
    for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
      const int by = byx / num_bins;
      const int bx = byx % num_bins;
      const int count = binmask.count(by, bx);
      const int points_per_bin_idx =
          batch_idx * num_bins * num_bins + by * num_bins + bx;
      // Atomically reserve `count` slots in this bin; `start` is the first
      // reserved slot, so different chunks write to disjoint ranges and
      // the plain stores below need no further synchronization.
      const int start = atomicAdd(points_per_bin + points_per_bin_idx, count);
      // Write the indices of the active bits for this bin out to
      // bin_points. Bug fix: slots past the bin capacity M are now dropped
      // instead of being written out of bounds, which previously corrupted
      // neighboring bins (or arbitrary global memory). Overflow remains
      // detectable on the host because points_per_bin then exceeds M.
      const int bin_start = batch_idx * num_bins * num_bins * M +
          by * num_bins * M + bx * M;
      int offset = start;
      for (int p = 0; p < chunk_size; ++p) {
        if (binmask.get(by, bx, p)) {
          if (offset < M) {
            bin_points[bin_start + offset] = point_start_idx + p;
          }
          offset++;
        }
      }
    }
    __syncthreads();
  }
}
// Host entry point for the coarse (binning) rasterization phase on the
// GPU. Returns bin_points of shape (N, num_bins, num_bins,
// max_points_per_bin) holding per-bin point indices, padded with -1.
torch::Tensor RasterizePointsCoarseCuda(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int max_points_per_bin) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
  const int M = max_points_per_bin;
  if (num_bins >= 22) {
    // The per-chunk bitmask lives in shared memory and grows with
    // num_bins^2, so refuse configurations that would not fit.
    std::stringstream ss;
    ss << "Got " << num_bins << "; that's too many!";
    AT_ERROR(ss.str());
  }
  const auto opts = points.options().dtype(torch::kInt32);
  torch::Tensor points_per_bin = torch::zeros({N, num_bins, num_bins}, opts);
  torch::Tensor bin_points = torch::full({N, num_bins, num_bins, M}, -1, opts);
  const int chunk_size = 512;
  // One bit per (bin_y, bin_x, point-in-chunk) triple.
  const size_t shared_size = num_bins * num_bins * chunk_size / 8;
  const size_t kBlocks = 64;
  const size_t kThreadsPerBlock = 512;
  RasterizePointsCoarseCudaKernel<<<kBlocks, kThreadsPerBlock, shared_size>>>(
      points.contiguous().data<float>(),
      radius,
      N,
      P,
      image_size,
      bin_size,
      chunk_size,
      M,
      points_per_bin.contiguous().data<int32_t>(),
      bin_points.contiguous().data<int32_t>());
  return bin_points;
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
// Fine rasterization kernel: like the naive kernel, but each pixel only
// tests the candidate points that the coarse kernel assigned to its bin.
//
// Args:
//   points: (N, P, 3) packed xyz coordinates in NDC.
//   bin_points: (N, B, B, M) point indices per bin; -1 entries are padding.
//   radius: point radius in NDC units; bin_size: bin side length in pixels.
//   N, P: batch size and points per cloud; B: bins per image side;
//   M: bin capacity; S: image side in pixels; K: points kept per pixel.
//   point_idxs, zbuf, pix_dists: (N, S, S, K) outputs; slots for pixels hit
//   by fewer than K points keep the caller's initial value (-1).
__global__ void RasterizePointsFineCudaKernel(
    const float* points, // (N, P, 3)
    const int32_t* bin_points, // (N, B, B, T)
    const float radius,
    const int bin_size,
    const int N,
    const int P,
    const int B,
    const int M,
    const int S,
    const int K,
    int32_t* point_idxs, // (N, S, S, K)
    float* zbuf, // (N, S, S, K)
    float* pix_dists) { // (N, S, S, K)
  // This can be more than S^2 if S is not divisible by bin_size.
  const int num_pixels = N * B * B * bin_size * bin_size;
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const float radius2 = radius * radius;
  // Grid-stride loop: one iteration per (batch, bin, pixel-in-bin) slot.
  for (int pid = tid; pid < num_pixels; pid += num_threads) {
    // Convert linear index into bin and pixel indices. We make the within
    // block pixel ids move the fastest, so that adjacent threads will fall
    // into the same bin; this should give them coalesced memory reads when
    // they read from points and bin_points.
    int i = pid;
    const int n = i / (B * B * bin_size * bin_size);
    i %= B * B * bin_size * bin_size;
    const int by = i / (B * bin_size * bin_size);
    i %= B * bin_size * bin_size;
    const int bx = i / (bin_size * bin_size);
    i %= bin_size * bin_size;
    // Image-space coordinates of this thread's pixel.
    const int yi = i / bin_size + by * bin_size;
    const int xi = i % bin_size + bx * bin_size;
    // Bins on the right/bottom edge may extend past the image boundary.
    if (yi >= S || xi >= S)
      continue;
    const float xf = PixToNdc(xi, S);
    const float yf = PixToNdc(yi, S);
    // This part looks like the naive rasterization kernel, except we use
    // bin_points to only look at a subset of points already known to fall
    // in this bin. TODO abstract out this logic into some data structure
    // that is shared by both kernels?
    Pix q[kMaxPointsPerPixel];
    int q_size = 0;
    float q_max_z = -1000;
    int q_max_idx = -1;
    for (int m = 0; m < M; ++m) {
      const int p = bin_points[n * B * B * M + by * B * M + bx * M + m];
      if (p < 0) {
        // bin_points uses -1 as a sentinel value
        continue;
      }
      // Inserts p into q (evicting the current max-z entry if full) when
      // the point covers this pixel.
      CheckPixelInsidePoint(
          points, p, q_size, q_max_z, q_max_idx, q, radius2, xf, yf, n, P, K);
    }
    // Now we've looked at all the points for this bin, so we can write
    // output for the current pixel.
    BubbleSort(q, q_size);
    const int pix_idx = n * S * S * K + yi * S * K + xi * K;
    for (int k = 0; k < q_size; ++k) {
      point_idxs[pix_idx + k] = q[k].idx;
      zbuf[pix_idx + k] = q[k].z;
      pix_dists[pix_idx + k] = q[k].dist2;
    }
  }
}
// Host entry point for the fine rasterization phase on the GPU.
//
// Bug fix: the points_per_pixel bound was previously reported with the
// hardcoded message "Must have num_closest <= 8", which named the wrong
// parameter and would silently go stale if kMaxPointsPerPixel changed.
// The message now matches RasterizePointsNaiveCuda and uses the constant.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
    const torch::Tensor& points,
    const torch::Tensor& bin_points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int points_per_pixel) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int B = bin_points.size(1);
  const int M = bin_points.size(3);
  const int S = image_size;
  const int K = points_per_pixel;
  if (K > kMaxPointsPerPixel) {
    std::stringstream ss;
    ss << "Must have points_per_pixel <= " << kMaxPointsPerPixel;
    AT_ERROR(ss.str());
  }
  // Outputs are filled with -1 so slots with fewer than K hits are marked.
  auto int_opts = points.options().dtype(torch::kInt32);
  auto float_opts = points.options().dtype(torch::kFloat32);
  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
  const size_t blocks = 1024;
  const size_t threads = 64;
  RasterizePointsFineCudaKernel<<<blocks, threads>>>(
      points.contiguous().data<float>(),
      bin_points.contiguous().data<int32_t>(),
      radius,
      bin_size,
      N,
      P,
      B,
      M,
      S,
      K,
      point_idxs.contiguous().data<int32_t>(),
      zbuf.contiguous().data<float>(),
      pix_dists.contiguous().data<float>());
  return std::make_tuple(point_idxs, zbuf, pix_dists);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
// TODO(T55115174) Add more documentation for backward kernel.
// Backward pass for point rasterization. Parallelized (grid-stride) over
// the N*H*W*K output slots of the forward pass; each thread reads the
// point index stored in its slot and scatters gradients back to that
// point. The same point may appear in many pixels, so writes into
// grad_points use atomicAdd.
//
// Gradient derivation: in the forward pass dists = dx*dx + dy*dy with
// dx = px - xf and dy = py - yf, so d(dists)/d(px) = 2*dx and
// d(dists)/d(py) = 2*dy; zbuf stores pz directly, so grad_zbuf flows
// through unchanged to the point's z coordinate.
__global__ void RasterizePointsBackwardCudaKernel(
    const float* points, // (N, P, 3)
    const int32_t* idxs, // (N, H, W, K)
    const int N,
    const int P,
    const int H,
    const int W,
    const int K,
    const float* grad_zbuf, // (N, H, W, K)
    const float* grad_dists, // (N, H, W, K)
    float* grad_points) { // (N, P, 3)
  // Parallelized over each of K points per pixel, for each pixel in images of
  // size H * W, for each image in the batch of size N.
  int num_threads = gridDim.x * blockDim.x;
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  for (int i = tid; i < N * H * W * K; i += num_threads) {
    // Decompose the linear index into (n, yi, xi, k).
    const int n = i / (H * W * K);
    const int yxk = i % (H * W * K);
    const int yi = yxk / (W * K);
    const int xk = yxk % (W * K);
    const int xi = xk / K;
    // k = xk % K (We don't actually need k, but this would be it.)
    const float xf = PixToNdc(xi, W);
    const float yf = PixToNdc(yi, H);
    const int p = idxs[i];
    if (p < 0)
      continue; // -1 marks a slot with no point; nothing to propagate.
    const float grad_dist2 = grad_dists[i];
    const int p_ind = n * P * 3 + p * 3; // offset of points[n][p]
    const float px = points[p_ind];
    const float py = points[p_ind + 1];
    const float dx = px - xf;
    const float dy = py - yf;
    const float grad_px = 2.0f * grad_dist2 * dx;
    const float grad_py = 2.0f * grad_dist2 * dy;
    const float grad_pz = grad_zbuf[i];
    // Accumulate atomically: several pixels can reference the same point.
    atomicAdd(grad_points + p_ind, grad_px);
    atomicAdd(grad_points + p_ind + 1, grad_py);
    atomicAdd(grad_points + p_ind + 2, grad_pz);
  }
}
// Host entry point for the rasterization backward pass on the GPU.
// Allocates a zero-initialized gradient tensor (the kernel accumulates
// into it with atomics) and launches the backward kernel.
torch::Tensor RasterizePointsBackwardCuda(
    const torch::Tensor& points, // (N, P, 3)
    const torch::Tensor& idxs, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_dists) { // (N, H, W, K)
  const int N = points.size(0);
  const int P = points.size(1);
  const int H = idxs.size(1);
  const int W = idxs.size(2);
  const int K = idxs.size(3);
  torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());
  const size_t kBlocks = 1024;
  const size_t kThreadsPerBlock = 64;
  RasterizePointsBackwardCudaKernel<<<kBlocks, kThreadsPerBlock>>>(
      points.contiguous().data<float>(),
      idxs.contiguous().data<int32_t>(),
      N,
      P,
      H,
      W,
      K,
      grad_zbuf.contiguous().data<float>(),
      grad_dists.contiguous().data<float>(),
      grad_points.contiguous().data<float>());
  return grad_points;
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
#include <cstdio>
#include <tuple>
// ****************************************************************************
// * NAIVE RASTERIZATION *
// ****************************************************************************
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
const torch::Tensor& points,
const int image_size,
const float radius,
const int points_per_pixel);
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
RasterizePointsNaiveCuda(
const torch::Tensor& points,
const int image_size,
const float radius,
const int points_per_pixel);
// Naive (forward) pointcloud rasterization: For each pixel, for each point,
// check whether that point hits the pixel.
//
// Args:
// points: Tensor of shape (N, P, 3) (in NDC)
// radius: Radius of each point (in NDC units)
// image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number closest of points to return for each pixel
//
// Returns:
//   idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
//       closest K points along the z-axis for each pixel, padded with -1
//       for pixels hit by fewer than K points.
// zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
// closest point for each pixel.
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
// distance in the (NDC) x/y plane between each pixel and its K closest
// points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel) {
  // Dispatch to the device-appropriate implementation of naive
  // rasterization based on where `points` lives.
  if (points.type().is_cuda()) {
    return RasterizePointsNaiveCuda(
        points, image_size, radius, points_per_pixel);
  }
  return RasterizePointsNaiveCpu(points, image_size, radius, points_per_pixel);
}
// ****************************************************************************
// * COARSE RASTERIZATION *
// ****************************************************************************
torch::Tensor RasterizePointsCoarseCpu(
const torch::Tensor& points,
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin);
torch::Tensor RasterizePointsCoarseCuda(
const torch::Tensor& points,
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin);
// Args:
// points: Tensor of shape (N, P, 3)
// radius: Radius of points to rasterize (in NDC units)
// image_size: Size of the image to generate (in pixels)
// bin_size: Size of each bin within the image (in pixels)
//
// Returns:
// points_per_bin: Tensor of shape (N, num_bins, num_bins) giving the number
// of points that fall in each bin
// bin_points: Tensor of shape (N, num_bins, num_bins, K) giving the indices
// of points that fall into each bin.
torch::Tensor RasterizePointsCoarse(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int max_points_per_bin) {
  // Dispatch to the device-appropriate implementation of coarse
  // rasterization based on where `points` lives.
  if (points.type().is_cuda()) {
    return RasterizePointsCoarseCuda(
        points, image_size, radius, bin_size, max_points_per_bin);
  }
  return RasterizePointsCoarseCpu(
      points, image_size, radius, bin_size, max_points_per_bin);
}
// ****************************************************************************
// * FINE RASTERIZATION *
// ****************************************************************************
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const torch::Tensor& points,
const torch::Tensor& bin_points,
const int image_size,
const float radius,
const int bin_size,
const int points_per_pixel);
// Args:
// points: float32 Tensor of shape (N, P, 3)
// bin_points: int32 Tensor of shape (N, B, B, M) giving the indices of points
// that fall into each bin (output from coarse rasterization)
// image_size: Size of image to generate (in pixels)
// radius: Radius of points to rasterize (NDC units)
// bin_size: Size of each bin (in pixels)
// points_per_pixel: How many points to rasterize for each pixel
//
// Returns (same as rasterize_points):
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the closest
// points_per_pixel points along the z-axis for each pixel, padded with
// -1 for pixels hit by fewer than points_per_pixel points
//   zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
// closest point for each pixel
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
// distance in the (NDC) x/y plane between each pixel and its K closest
// points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
    const torch::Tensor& points,
    const torch::Tensor& bin_points,
    const int image_size,
    const float radius,
    const int bin_size,
    const int points_per_pixel) {
  // Fine rasterization only has a CUDA implementation; fail fast for CPU
  // tensors.
  if (!points.type().is_cuda()) {
    AT_ERROR("NOT IMPLEMENTED");
  }
  return RasterizePointsFineCuda(
      points, bin_points, image_size, radius, bin_size, points_per_pixel);
}
// ****************************************************************************
// * BACKWARD PASS *
// ****************************************************************************
torch::Tensor RasterizePointsBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& idxs,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists);
torch::Tensor RasterizePointsBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& idxs,
const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists);
// Args:
// points: float32 Tensor of shape (N, P, 3)
// idxs: int32 Tensor of shape (N, H, W, K) (from forward pass)
// grad_zbuf: float32 Tensor of shape (N, H, W, K) giving upstream gradient
// d(loss)/d(zbuf) of the distances from each pixel to its nearest
// points.
// grad_dists: Tensor of shape (N, H, W, K) giving upstream gradient
// d(loss)/d(dists) of the dists tensor returned by the forward
// pass.
//
// Returns:
// grad_points: float32 Tensor of shape (N, P, 3) giving downstream gradients
torch::Tensor RasterizePointsBackward(
    const torch::Tensor& points,
    const torch::Tensor& idxs,
    const torch::Tensor& grad_zbuf,
    const torch::Tensor& grad_dists) {
  // Dispatch to the device-appropriate backward implementation based on
  // where `points` lives.
  if (points.type().is_cuda()) {
    return RasterizePointsBackwardCuda(points, idxs, grad_zbuf, grad_dists);
  }
  return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
}
// ****************************************************************************
// * MAIN ENTRY POINT *
// ****************************************************************************
// This is the main entry point for the forward pass of the point rasterizer;
// it uses either naive or coarse-to-fine rasterization based on bin_size.
//
// Args:
// points: Tensor of shape (N, P, 3) (in NDC)
// radius: Radius of each point (in NDC units)
// image_size: (S) Size of the image to return (in pixels)
// points_per_pixel: (K) The number of points to return for each pixel
// bin_size: Bin size (in pixels) for coarse-to-fine rasterization. Setting
// bin_size=0 uses naive rasterization instead.
// max_points_per_bin: The maximum number of points allowed to fall into each
// bin when using coarse-to-fine rasterization.
//
// Returns:
// idxs: int32 Tensor of shape (N, S, S, K) giving the indices of the
// closest points_per_pixel points along the z-axis for each pixel,
// padded with -1 for pixels hit by fewer than points_per_pixel points
//   zbuf: float32 Tensor of shape (N, S, S, K) giving the depth of each
// closest point for each pixel
// dists: float32 Tensor of shape (N, S, S, K) giving squared Euclidean
// distance in the (NDC) x/y plane between each pixel and its K closest
// points along the z axis.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePoints(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel,
    const int bin_size,
    const int max_points_per_bin) {
  // bin_size == 0 selects the naive per-pixel path; any other value runs
  // the two-phase coarse-to-fine rasterizer.
  if (bin_size != 0) {
    const auto bin_points = RasterizePointsCoarse(
        points, image_size, radius, bin_size, max_points_per_bin);
    return RasterizePointsFine(
        points, bin_points, image_size, radius, bin_size, points_per_pixel);
  }
  return RasterizePointsNaive(points, image_size, radius, points_per_pixel);
}
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <queue>
#include <tuple>
// Given a pixel coordinate 0 <= i < S, convert it to a normalized device
// coordinate in the range [-1, 1]. The NDC range is divided into S evenly-sized
// pixels, and assume that each pixel falls in the *center* of its range.
// Map a pixel index 0 <= i < S to the NDC coordinate of the pixel *center*.
// The NDC range [-1, 1] is split into S evenly-sized pixels, so pixel i
// spans [-1 + 2i/S, -1 + 2(i+1)/S] and its center sits halfway between.
inline float PixToNdc(const int i, const int S) {
  return (2 * i + 1.0f) / S - 1.0f;
}
// CPU implementation of naive point rasterization: for every pixel, scan
// all points, keep the K hits with smallest z in a max-heap, then write
// them out ordered by increasing depth.
//
// Args / Returns: see RasterizePointsNaive; output slots for pixels hit by
// fewer than K points are left at -1.
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaiveCpu(
    const torch::Tensor& points,
    const int image_size,
    const float radius,
    const int points_per_pixel) {
  const int N = points.size(0);
  const int P = points.size(1);
  const int S = image_size;
  const int K = points_per_pixel;
  // NOTE(review): unlike the CUDA version there is no K <= kMaxPointsPerPixel
  // check here; the std::priority_queue below grows dynamically so none is
  // required for correctness.
  auto int_opts = points.options().dtype(torch::kInt32);
  auto float_opts = points.options().dtype(torch::kFloat32);
  // Outputs are filled with -1 so unused slots are marked.
  torch::Tensor point_idxs = torch::full({N, S, S, K}, -1, int_opts);
  torch::Tensor zbuf = torch::full({N, S, S, K}, -1, float_opts);
  torch::Tensor pix_dists = torch::full({N, S, S, K}, -1, float_opts);
  auto points_a = points.accessor<float, 3>();
  auto point_idxs_a = point_idxs.accessor<int32_t, 4>();
  auto zbuf_a = zbuf.accessor<float, 4>();
  auto pix_dists_a = pix_dists.accessor<float, 4>();
  const float radius2 = radius * radius;
  for (int n = 0; n < N; ++n) {
    for (int yi = 0; yi < S; ++yi) {
      float yf = PixToNdc(yi, S);
      for (int xi = 0; xi < S; ++xi) {
        float xf = PixToNdc(xi, S);
        // Use a priority queue to hold (z, idx, r).
        // std::priority_queue is a max-heap ordered lexicographically, so
        // the element with the largest z sits on top.
        std::priority_queue<std::tuple<float, int, float>> q;
        for (int p = 0; p < P; ++p) {
          const float px = points_a[n][p][0];
          const float py = points_a[n][p][1];
          const float pz = points_a[n][p][2];
          if (pz < 0) {
            continue; // Don't render points behind the camera.
          }
          const float dx = px - xf;
          const float dy = py - yf;
          const float dist2 = dx * dx + dy * dy;
          if (dist2 < radius2) {
            // The current point hit the current pixel
            q.emplace(pz, p, dist2);
            if ((int)q.size() > K) {
              // Evict the farthest point so only the K nearest remain.
              q.pop();
            }
          }
        }
        // Now all the points have been seen, so pop elements off the queue
        // one by one and write them into the output tensors. Popping yields
        // decreasing z, and i == q.size() after each pop, so slot 0 ends up
        // holding the nearest point.
        while (!q.empty()) {
          auto t = q.top();
          q.pop();
          int i = q.size();
          zbuf_a[n][yi][xi][i] = std::get<0>(t);
          point_idxs_a[n][yi][xi][i] = std::get<1>(t);
          pix_dists_a[n][yi][xi][i] = std::get<2>(t);
        }
      }
    }
  }
  return std::make_tuple(point_idxs, zbuf, pix_dists);
}
std::tuple<torch::Tensor, torch::Tensor> RasterizePointsCoarseCpu(
const torch::Tensor& points,
const int image_size,
const float radius,
const int bin_size,
const int max_points_per_bin) {
const int N = points.size(0);
const int P = points.size(1);
const int B = 1 + (image_size - 1) / bin_size; // Integer division round up
const int M = max_points_per_bin;
auto opts = points.options().dtype(torch::kInt32);
torch::Tensor points_per_bin = torch::zeros({N, B, B}, opts);
torch::Tensor bin_points = torch::full({N, B, B, M}, -1, opts);
auto points_a = points.accessor<float, 3>();
auto points_per_bin_a = points_per_bin.accessor<int32_t, 3>();
auto bin_points_a = bin_points.accessor<int32_t, 4>();
const float pixel_width = 2.0f / image_size;
const float bin_width = pixel_width * bin_size;
for (int n = 0; n < N; ++n) {
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
for (int by = 0; by < B; by++) {
float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width;
for (int bx = 0; bx < B; bx++) {
int32_t points_hit = 0;
for (int32_t p = 0; p < P; p++) {
float px = points_a[n][p][0];
float py = points_a[n][p][1];
float pz = points_a[n][p][2];
if (pz < 0) {
continue;
}
float point_x_min = px - radius;
float point_x_max = px + radius;
float point_y_min = py - radius;
float point_y_max = py + radius;
// Use a half-open interval so that points exactly on the
// boundary between bins will fall into exactly one bin.
bool x_hit = (point_x_min <= bin_x_max) && (bin_x_min <= point_x_max);
bool y_hit = (point_y_min <= bin_y_max) && (bin_y_min <= point_y_max);
if (x_hit && y_hit) {
// Got too many points for this bin, so throw an error.
if (points_hit >= max_points_per_bin) {
AT_ERROR("Got too many points per bin");
}
// The current point falls in the current bin, so
// record it.
bin_points_a[n][by][bx][points_hit] = p;
points_hit++;
}
}
// Record the number of points found in this bin
points_per_bin_a[n][by][bx] = points_hit;
// Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width;
}
// Shift the bin down for the next loop iteration
bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width;
}
}
return std::make_tuple(points_per_bin, bin_points);
}
// CPU backward pass: scatter gradients from each (pixel, k) output slot
// back to the point that produced it. Since dists = dx^2 + dy^2 with
// dx = px - xf and dy = py - yf, d(dists)/d(px) = 2*dx and
// d(dists)/d(py) = 2*dy; zbuf passes its gradient straight through to pz.
torch::Tensor RasterizePointsBackwardCpu(
    const torch::Tensor& points, // (N, P, 3)
    const torch::Tensor& idxs, // (N, H, W, K)
    const torch::Tensor& grad_zbuf, // (N, H, W, K)
    const torch::Tensor& grad_dists) { // (N, H, W, K)
  const int N = points.size(0);
  const int P = points.size(1);
  const int H = idxs.size(1);
  const int W = idxs.size(2);
  const int K = idxs.size(3);
  // For now only support square images.
  // TODO(jcjohns): Extend to non-square images.
  if (H != W) {
    AT_ERROR("RasterizePointsBackwardCpu only supports square images");
  }
  torch::Tensor grad_points = torch::zeros({N, P, 3}, points.options());
  auto points_a = points.accessor<float, 3>();
  auto idxs_a = idxs.accessor<int32_t, 4>();
  auto grad_zbuf_a = grad_zbuf.accessor<float, 4>();
  auto grad_dists_a = grad_dists.accessor<float, 4>();
  auto grad_points_a = grad_points.accessor<float, 3>();
  for (int n = 0; n < N; ++n) { // images in the batch
    for (int yi = 0; yi < H; ++yi) { // rows
      const float yf = PixToNdc(yi, H);
      for (int xi = 0; xi < W; ++xi) { // pixels within the row
        const float xf = PixToNdc(xi, W);
        for (int k = 0; k < K; ++k) { // points recorded for the pixel
          const int p = idxs_a[n][yi][xi][k];
          if (p < 0) {
            break; // first -1 terminates the list for this pixel
          }
          const float grad_dist2 = grad_dists_a[n][yi][xi][k];
          const float dx = points_a[n][p][0] - xf;
          const float dy = points_a[n][p][1] - yf;
          // Remember: dists[n][yi][xi][k] = dx * dx + dy * dy
          const float grad_px = 2.0f * grad_dist2 * dx;
          const float grad_py = 2.0f * grad_dist2 * dy;
          grad_points_a[n][p][0] += grad_px;
          grad_points_a[n][p][1] += grad_py;
          grad_points_a[n][p][2] += grad_zbuf_a[n][yi][xi][k];
        }
      }
    }
  }
  return grad_points;
}
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .obj_io import load_obj, save_obj
from .ply_io import load_ply, save_ply
__all__ = [k for k in globals().keys() if not k.startswith("_")]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""This module implements utility functions for loading and saving meshes."""
import numpy as np
import os
import pathlib
import warnings
from collections import namedtuple
from typing import List
import torch
from fvcore.common.file_io import PathManager
from PIL import Image
def _read_image(file_name: str, format=None):
"""
Read an image from a file using Pillow.
Args:
file_name: image file path.
format: one of ["RGB", "BGR"]
Returns:
image: an image of shape (H, W, C).
"""
if format not in ["RGB", "BGR"]:
raise ValueError("format can only be one of [RGB, BGR]; got %s", format)
with PathManager.open(file_name, "rb") as f:
image = Image.open(f)
if format is not None:
# PIL only supports RGB. First convert to RGB and flip channels
# below for BGR.
image = image.convert("RGB")
image = np.asarray(image).astype(np.float32)
if format == "BGR":
image = image[:, :, ::-1]
return image
# Faces & Aux type returned from load_obj function.
# Faces fields (see load_obj docstring): verts_idx / normals_idx /
# textures_idx are (F, 3) index tensors; materials_idx is a per-face
# material index list (-1 when a face has no material).
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
# Aux fields (see load_obj docstring): normals (N, 3), verts_uvs (T, 2),
# material_colors and texture_images are dicts keyed by material name.
# NOTE(review): the namedtuple's typename is "Properties" while the alias
# is _Aux — presumably intentional; confirm before renaming.
_Aux = namedtuple(
    "Properties", "normals verts_uvs material_colors texture_images"
)
def _format_faces_indices(faces_indices, max_index):
"""
Format indices and check for invalid values. Indices can refer to
values in one of the face properties: vertices, textures or normals.
See comments of the load_obj function for more details.
Args:
faces_indices: List of ints of indices.
max_index: Max index for the face property.
Returns:
faces_indices: List of ints of indices.
Raises:
ValueError if indices are not in a valid range.
"""
faces_indices = torch.tensor(faces_indices, dtype=torch.int64)
# Change to 0 based indexing.
faces_indices[(faces_indices > 0)] -= 1
# Negative indexing counts from the end.
faces_indices[(faces_indices < 0)] += max_index
# Check indices are valid.
if not (
torch.all(faces_indices < max_index) and torch.all(faces_indices >= 0)
):
raise ValueError("Faces have invalid indices.")
return faces_indices
def _open_file(f):
new_f = False
if isinstance(f, str):
new_f = True
f = open(f, "r")
elif isinstance(f, pathlib.Path):
new_f = True
f = f.open("r")
return f, new_f
def load_obj(f_obj):
    """
    Load a mesh and (optionally) textures from a .obj file and any
    associated .mtl file.

    Handles verts, faces, vertex texture uv coordinates, normals, texture
    images and material reflectivity values. Note that .obj files are
    1-indexed while every tensor returned from this function is 0-indexed.
    OBJ spec reference: http://www.martinreddy.net/gfx/3d/OBJ.spec

    Line types in a .obj file:
    ::
        - v is a vertex
        - vt is the texture coordinate of one vertex
        - vn is the normal of one vertex
        - f is a face

    A face entry such as 5/2/1 gives, in order, the (1-based) indices of a
    vertex, a texture coordinate and a normal. Faces with more than 3
    vertices are subdivided into a fan of triangles
    (v0, v1, v2), (v0, v2, v3), ...; polygonal faces are assumed to have
    their vertices ordered counter-clockwise so the (right-handed) normal
    points into the screen.

    Args:
        f_obj: A file-like object (with methods read, readline, tell, and
            seek), a pathlib path or a string containing a file name. When
            a path is given, any material files referenced by the obj are
            resolved relative to its directory.

    Returns:
        6-element tuple containing

        - **verts**: FloatTensor of shape (V, 3).
        - **faces**: NamedTuple with fields:
            - verts_idx: LongTensor of vertex indices, shape (F, 3).
            - normals_idx: (optional) LongTensor of normal indices,
              shape (F, 3).
            - textures_idx: (optional) LongTensor of texture indices,
              shape (F, 3). This can be used to index into verts_uvs.
            - materials_idx: (optional) List of indices indicating which
              material each face uses; -1 if a face has no material. Can
              be used to retrieve values from material_colors /
              texture_images once converted to tensors or
              Materials/Textures data structures (see textures.py and
              materials.py).
        - **aux**: NamedTuple with fields:
            - normals: FloatTensor of shape (N, 3).
            - verts_uvs: FloatTensor of shape (T, 2) giving a uv
              coordinate per vertex. A vertex shared between two faces may
              have a different uv in each, so possibly T > V.
            - material_colors: dict of material names to their properties
              (ambient_color / diffuse_color / specular_color tensors of
              shape (1, 3) and a shininess tensor of shape (1)); an empty
              dict for materials with no properties.
            - texture_images: dict of material names to (H, W, 3) texture
              images.
    """
    # Material files are looked up next to the obj when a path was given;
    # for file-like inputs fall back to the current directory.
    if isinstance(f_obj, (str, bytes, os.PathLike)):
        data_dir = os.path.dirname(f_obj)
    else:
        data_dir = "./"
    f_obj, opened_here = _open_file(f_obj)
    try:
        return _load(f_obj, data_dir)
    finally:
        # Only close handles that this function itself opened.
        if opened_here:
            f_obj.close()
def _parse_face(
line,
material_idx,
faces_verts_idx,
faces_normals_idx,
faces_textures_idx,
faces_materials_idx,
):
face = line.split(" ")[1:]
face_list = [f.split("/") for f in face]
face_verts = []
face_normals = []
face_textures = []
for vert_props in face_list:
# Vertex index.
face_verts.append(int(vert_props[0]))
if len(vert_props) > 1:
if vert_props[1] != "":
# Texture index is present e.g. f 4/1/1.
face_textures.append(int(vert_props[1]))
if len(vert_props) > 2:
# Normal index present e.g. 4/1/1 or 4//1.
face_normals.append(int(vert_props[2]))
if len(vert_props) > 3:
raise ValueError(
"Face vertices can ony have 3 properties. \
Face vert %s, Line: %s"
% (str(vert_props), str(line))
)
# Triplets must be consistent for all vertices in a face e.g.
# legal statement: f 4/1/1 3/2/1 2/1/1.
# illegal statement: f 4/1/1 3//1 2//1.
if len(face_normals) > 0:
if not (len(face_verts) == len(face_normals)):
raise ValueError(
"Face %s is an illegal statement. \
Vertex properties are inconsistent. Line: %s"
% (str(face), str(line))
)
if len(face_textures) > 0:
if not (len(face_verts) == len(face_textures)):
raise ValueError(
"Face %s is an illegal statement. \
Vertex properties are inconsistent. Line: %s"
% (str(face), str(line))
)
# Subdivide faces with more than 3 vertices. See comments of the
# load_obj function for more details.
for i in range(len(face_verts) - 2):
faces_verts_idx.append(
(face_verts[0], face_verts[i + 1], face_verts[i + 2])
)
if len(face_normals) > 0:
faces_normals_idx.append(
(face_normals[0], face_normals[i + 1], face_normals[i + 2])
)
if len(face_textures) > 0:
faces_textures_idx.append(
(face_textures[0], face_textures[i + 1], face_textures[i + 2])
)
faces_materials_idx.append(material_idx)
def _load(f_obj, data_dir):
    """
    Load a mesh from a file-like object. See the load_obj function for more
    details on the expected .obj format and returned values.
    Any material files associated with the obj are expected to be in the
    directory given by data_dir.

    Args:
        f_obj: file-like object yielding the lines of the .obj file.
        data_dir: directory in which any .mtl file referenced by the obj
            is expected to be found.

    Returns:
        (verts, faces, aux) as described in load_obj.
    """
    lines = [line.strip() for line in f_obj]
    verts = []
    normals = []
    verts_uvs = []
    faces_verts_idx = []
    faces_normals_idx = []
    faces_textures_idx = []
    material_names = []
    faces_materials_idx = []
    f_mtl = None
    # Index into material_names for the material currently in effect;
    # -1 means no "usemtl" statement has been seen yet.
    materials_idx = -1

    # startswith expects each line to be a string. If the file is read in as
    # bytes then first decode to strings.
    if isinstance(lines[0], bytes):
        lines = [l.decode("utf-8") for l in lines]

    for line in lines:
        if line.startswith("mtllib"):
            if len(line.split()) < 2:
                raise ValueError("material file name is not specified")
            # NOTE: this assumes only one mtl file per .obj.
            f_mtl = os.path.join(data_dir, line.split()[1])
        elif len(line.split()) != 0 and line.split()[0] == "usemtl":
            material_name = line.split()[1]
            material_names.append(material_name)
            # Faces parsed after this line are assigned this material index.
            materials_idx = len(material_names) - 1
        elif line.startswith("v "):
            # Line is a vertex.
            vert = [float(x) for x in line.split()[1:4]]
            if len(vert) != 3:
                msg = "Vertex %s does not have 3 values. Line: %s"
                raise ValueError(msg % (str(vert), str(line)))
            verts.append(vert)
        elif line.startswith("vt "):
            # Line is a texture (uv coordinate).
            tx = [float(x) for x in line.split()[1:3]]
            if len(tx) != 2:
                raise ValueError(
                    "Texture %s does not have 2 values. Line: %s"
                    % (str(tx), str(line))
                )
            verts_uvs.append(tx)
        elif line.startswith("vn "):
            # Line is a normal.
            norm = [float(x) for x in line.split()[1:4]]
            if len(norm) != 3:
                msg = "Normal %s does not have 3 values. Line: %s"
                raise ValueError(msg % (str(norm), str(line)))
            normals.append(norm)
        elif line.startswith("f "):
            # Line is a face; appends triangles to the four index lists.
            _parse_face(
                line,
                materials_idx,
                faces_verts_idx,
                faces_normals_idx,
                faces_textures_idx,
                faces_materials_idx,
            )

    verts = torch.tensor(verts)  # (V, 3)
    normals = torch.tensor(normals)  # (N, 3)
    verts_uvs = torch.tensor(verts_uvs)  # (T, 2)

    # NOTE(review): _format_faces_indices presumably converts the 1-based
    # obj indices collected above to 0-based index tensors — confirm in
    # its definition.
    faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])

    # Repeat for normals and textures if present.
    if len(faces_normals_idx) > 0:
        faces_normals_idx = _format_faces_indices(
            faces_normals_idx, normals.shape[0]
        )
    if len(faces_textures_idx) > 0:
        faces_textures_idx = _format_faces_indices(
            faces_textures_idx, verts_uvs.shape[0]
        )
    if len(faces_materials_idx) > 0:
        faces_materials_idx = torch.tensor(
            faces_materials_idx, dtype=torch.int64
        )

    # Load materials (colors and texture images) from the mtl file, if any.
    material_colors, texture_images = None, None
    if (len(material_names) > 0) and (f_mtl is not None):
        if os.path.isfile(f_mtl):
            material_colors, texture_images = load_mtl(
                f_mtl, material_names, data_dir
            )
        else:
            warnings.warn(f"Mtl file does not exist: {f_mtl}")
    elif len(material_names) > 0:
        warnings.warn("No mtl file provided")

    faces = _Faces(
        verts_idx=faces_verts_idx,
        normals_idx=faces_normals_idx,
        textures_idx=faces_textures_idx,
        materials_idx=faces_materials_idx,
    )
    aux = _Aux(
        normals=normals if len(normals) > 0 else None,
        verts_uvs=verts_uvs if len(verts_uvs) > 0 else None,
        material_colors=material_colors,
        texture_images=texture_images,
    )
    return verts, faces, aux
def load_mtl(f_mtl, material_names: List, data_dir: str):
    """
    Load texture images and material reflectivity values for ambient,
    diffuse and specular light (Ka, Kd, Ks, Ns).

    Args:
        f_mtl: a file-like object (or path) with the material information.
        material_names: a list of the material names found in the .obj file.
        data_dir: the directory where the material texture files are located.

    Returns:
        material_colors: dict mapping each referenced material name to a
            dict of its properties ("ambient_color", "diffuse_color",
            "specular_color", "shininess"); a material with no properties
            maps to an empty dict.
        texture_images: dict mapping material names to (H, W, 3) texture
            images.
    """
    # Maps mtl statement keywords to the property names we expose.
    color_keys = {
        "Kd": "diffuse_color",
        "Ka": "ambient_color",
        "Ks": "specular_color",
        "Ns": "shininess",
    }
    texture_files = {}
    material_colors = {}
    material_properties = {}
    texture_images = {}
    material_name = ""

    f_mtl, new_f = _open_file(f_mtl)
    for raw_line in f_mtl:
        tokens = raw_line.strip().split()
        if not tokens:
            continue
        keyword = tokens[0]
        if keyword == "newmtl":
            # Start a new (initially empty) material definition.
            material_name = tokens[1]
            material_colors[material_name] = {}
        elif keyword == "map_Kd":
            # Texture map file name for the current material.
            texture_files[material_name] = tokens[1]
        elif keyword in color_keys:
            # Numeric property: 3 floats for Kd/Ka/Ks, 1 float for Ns.
            values = np.array(tokens[1:4]).astype(np.float32)
            material_colors[material_name][color_keys[keyword]] = (
                torch.from_numpy(values)
            )
    if new_f:
        f_mtl.close()

    # Only keep the materials referenced in the obj.
    for name in material_names:
        if name in texture_files:
            # Load the texture image if the file exists.
            filename_texture = os.path.join(data_dir, texture_files[name])
            if os.path.isfile(filename_texture):
                image = _read_image(filename_texture, format="RGB") / 255.0
                texture_images[name] = torch.from_numpy(image)
            else:
                warnings.warn(
                    f"Texture file does not exist: {filename_texture}"
                )
        if name in material_colors:
            material_properties[name] = material_colors[name]

    return material_properties, texture_images
def save_obj(f, verts, faces, decimal_places: int = None):
    """
    Save a mesh to an .obj file.

    Args:
        f: File (or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
    """
    # Open the file ourselves when given a path; remember to close it.
    opened_here = False
    if isinstance(f, str):
        f = open(f, "w")
        opened_here = True
    elif isinstance(f, pathlib.Path):
        f = f.open("w")
        opened_here = True
    try:
        return _save(f, verts, faces, decimal_places)
    finally:
        if opened_here:
            f.close()
# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: int = None):
if verts.dim() != 2 or verts.size(1) != 3:
raise ValueError("Argument 'verts' should be of shape (num_verts, 3).")
if faces.dim() != 2 or faces.size(1) != 3:
raise ValueError("Argument 'faces' should be of shape (num_faces, 3).")
verts, faces = verts.cpu(), faces.cpu()
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
lines = ""
V, D = verts.shape
for i in range(V):
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if i + 1 < F:
lines += "f %s\n" % " ".join(face)
elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
f.write(lines)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements utility functions for loading and saving meshes."""
import numpy as np
import pathlib
import struct
import sys
import warnings
from collections import namedtuple
from typing import Optional, Tuple
import torch
# Metadata for one PLY scalar type: its size in bytes, the corresponding
# struct-module format character, and the matching numpy dtype.
_PlyTypeData = namedtuple("_PlyTypeData", "size struct_char np_type")

# Mapping from the scalar type names used in PLY headers to their metadata.
_PLY_TYPES = {
    "char": _PlyTypeData(1, "b", np.byte),
    "uchar": _PlyTypeData(1, "B", np.ubyte),
    "short": _PlyTypeData(2, "h", np.short),
    "ushort": _PlyTypeData(2, "H", np.ushort),
    "int": _PlyTypeData(4, "i", np.int32),
    "uint": _PlyTypeData(4, "I", np.uint32),
    "float": _PlyTypeData(4, "f", np.float32),
    "double": _PlyTypeData(8, "d", np.float64),
}

# One property of a PLY element: its name, scalar data type, and (for list
# properties) the scalar type of the list length. list_size_type is None
# for non-list properties.
_Property = namedtuple("_Property", "name data_type list_size_type")
class _PlyElementType:
"""
Description of an element of a Ply file.
Members:
self.properties: (List[_Property]) description of all the properties.
Each one contains a name and data type.
self.count: (int) number of such elements in the file
self.name: (str) name of the element
"""
def __init__(self, name: str, count: int):
self.name = name
self.count = count
self.properties = []
def add_property(
self, name: str, data_type: str, list_size_type: Optional[str] = None
):
"""Adds a new property.
Args:
name: (str) name of the property.
data_type: (str) PLY data type.
list_size_type: (str) PLY data type of the list size, or None if not
a list.
"""
for property in self.properties:
if property.name == name:
msg = "Cannot have two properties called %s in %s."
raise ValueError(msg % (name, self.name))
self.properties.append(_Property(name, data_type, list_size_type))
def is_fixed_size(self) -> bool:
"""Return whether the Element has no list properties
Returns:
True if none of the properties are lists.
"""
for property in self.properties:
if property.list_size_type is not None:
return False
return True
def is_constant_type_fixed_size(self) -> bool:
"""Return whether the Element has all properties of the same non-list
type.
Returns:
True if none of the properties are lists and all the properties
share a type.
"""
if not self.is_fixed_size():
return False
first_type = self.properties[0].data_type
for property in self.properties:
if property.data_type != first_type:
return False
return True
def try_constant_list(self) -> bool:
"""Whether the element is just a single list, which might have a
constant size, and therefore we could try to parse quickly with numpy.
Returns:
True if the only property is a list.
"""
if len(self.properties) != 1:
return False
if self.properties[0].list_size_type is None:
return False
return True
class _PlyHeader:
    def __init__(self, f):
        """
        Load a header of a Ply file from a file-like object.

        Members:
            self.elements: (List[_PlyElementType]) element description
            self.ascii: (bool) Whether in ascii format
            self.big_endian: (bool) (if not ascii) whether big endian
            self.obj_info: (dict) arbitrary extra data

        Args:
            f: file-like object.

        Raises:
            ValueError: if the header is malformed.
        """
        # First line must be the "ply" magic string; f may yield bytes or
        # str depending on how it was opened.
        if f.readline() not in [b"ply\n", b"ply\r\n", "ply\n"]:
            raise ValueError("Invalid file header.")
        seen_format = False
        self.elements = []
        self.obj_info = {}
        while True:
            line = f.readline()
            if isinstance(line, bytes):
                line = line.decode("ascii")
            line = line.strip()
            if line == "end_header":
                # Validate the header is complete before stopping.
                if not self.elements:
                    raise ValueError("No elements found.")
                if not self.elements[-1].properties:
                    raise ValueError("Found an element with no properties.")
                if not seen_format:
                    raise ValueError("No format line found.")
                break
            if not seen_format:
                # The format line must appear before elements/properties.
                if line == "format ascii 1.0":
                    seen_format = True
                    self.ascii = True
                    continue
                if line == "format binary_little_endian 1.0":
                    seen_format = True
                    self.ascii = False
                    self.big_endian = False
                    continue
                if line == "format binary_big_endian 1.0":
                    seen_format = True
                    self.ascii = False
                    self.big_endian = True
                    continue
                if line.startswith("format"):
                    raise ValueError("Invalid format line.")
            if line.startswith("comment") or len(line) == 0:
                continue
            if line.startswith("element"):
                self._parse_element(line)
                continue
            if line.startswith("obj_info"):
                # Arbitrary key/value metadata.
                items = line.split(" ")
                if len(items) != 3:
                    raise ValueError("Invalid line: %s" % line)
                self.obj_info[items[1]] = items[2]
                continue
            if line.startswith("property"):
                self._parse_property(line)
                continue
            raise ValueError("Invalid line: %s." % line)

    def _parse_property(self, line: str):
        """
        Decode a ply file header property line and attach the property to
        the most recently declared element.

        Args:
            line: (str) the ply file's line.

        Raises:
            ValueError: if the line is malformed or precedes any element.
        """
        if not self.elements:
            raise ValueError("Encountered property before any element.")
        items = line.split(" ")
        # 3 items: "property <type> <name>";
        # 5 items: "property list <size_type> <type> <name>".
        if len(items) not in [3, 5]:
            raise ValueError("Invalid line: %s" % line)
        datatype = items[1]
        name = items[-1]
        if datatype == "list":
            datatype = items[3]
            list_size_type = items[2]
            if list_size_type not in _PLY_TYPES:
                raise ValueError("Invalid datatype: %s" % list_size_type)
        else:
            list_size_type = None
        if datatype not in _PLY_TYPES:
            raise ValueError("Invalid datatype: %s" % datatype)
        self.elements[-1].add_property(name, datatype, list_size_type)

    def _parse_element(self, line: str):
        """
        Decode a ply file header element line and start a new element.

        Args:
            line: (str) the ply file's line.

        Raises:
            ValueError: if the line is malformed or the previous element
                had no properties.
        """
        if self.elements and not self.elements[-1].properties:
            raise ValueError("Found an element with no properties.")
        items = line.split(" ")
        if len(items) != 3:
            raise ValueError("Invalid line: %s" % line)
        try:
            count = int(items[2])
        except ValueError:
            msg = "Number of items for %s was not a number."
            raise ValueError(msg % items[1])
        self.elements.append(_PlyElementType(items[1], count))
def _read_ply_fixed_size_element_ascii(f, definition: _PlyElementType):
    """
    Read the data for an element whose properties are all scalars of a
    single type, from the ascii body of a ply file.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.

    Returns:
        2D numpy array: one row per occurrence of the element, one column
        per property.

    Raises:
        ValueError: if the data has the wrong number of columns or too
            few rows.
    """
    dtype = _PLY_TYPES[definition.properties[0].data_type].np_type
    data = np.loadtxt(
        f, dtype=dtype, comments=None, ndmin=2, max_rows=definition.count
    )
    n_rows, n_cols = data.shape
    if n_cols != len(definition.properties):
        raise ValueError("Inconsistent data for %s." % definition.name)
    if n_rows != definition.count:
        raise ValueError("Not enough data for %s." % definition.name)
    return data
def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
    """
    If definition is an element which is a single list, attempt to read the
    corresponding data assuming every value has the same length.
    If the data is ragged, return None and leave f undisturbed.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.

    Returns:
        If every element has the same size, 2D numpy array corresponding to
        the data (with the per-row length column stripped). The rows are
        the different values. Otherwise None.
    """
    np_type = _PLY_TYPES[definition.properties[0].data_type].np_type
    # Remember the position so we can rewind if the rows turn out ragged.
    start_point = f.tell()
    try:
        with warnings.catch_warnings():
            # loadtxt warns rather than raising on an empty input; the
            # row-count check below handles that case.
            warnings.filterwarnings(
                "ignore", message=".* Empty input file.*", category=UserWarning
            )
            data = np.loadtxt(
                f,
                dtype=np_type,
                comments=None,
                ndmin=2,
                max_rows=definition.count,
            )
    except ValueError:
        # Ragged rows make loadtxt fail; rewind so the caller can re-parse
        # the same data line by line.
        f.seek(start_point)
        return None
    # The first column holds each row's declared list length; it must
    # match the actual number of values found on the row.
    if (data.shape[1] - 1 != data[:, 0]).any():
        msg = "A line of %s data did not have the specified length."
        raise ValueError(msg % definition.name)
    if data.shape[0] != definition.count:
        raise ValueError("Not enough data for %s." % definition.name)
    return data[:, 1:]
def _parse_heterogenous_property_ascii(datum, line_iter, property: _Property):
"""
Read a general data property from an ascii .ply file.
Args:
datum: list to append the single value to. That value will be a numpy
array if the property is a list property, otherwise an int or
float.
line_iter: iterator to words on the line from which we read.
property: the property object describing the property we are reading.
"""
value = next(line_iter, None)
if value is None:
raise ValueError("Too little data for an element.")
if property.list_size_type is None:
try:
if property.data_type in ["double", "float"]:
datum.append(float(value))
else:
datum.append(int(value))
except ValueError:
raise ValueError("Bad numerical data.")
else:
try:
length = int(value)
except ValueError:
raise ValueError("A list length was not a number.")
list_value = np.zeros(
length, dtype=_PLY_TYPES[property.data_type].np_type
)
for i in range(length):
inner_value = next(line_iter, None)
if inner_value is None:
raise ValueError("Too little data for an element.")
try:
list_value[i] = float(inner_value)
except ValueError:
raise ValueError("Bad numerical data.")
datum.append(list_value)
def _read_ply_element_ascii(f, definition: _PlyElementType):
    """
    Decode all instances of a single element from an ascii .ply file.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.

    Returns:
        In simple cases where every element has the same size, 2D numpy
        array corresponding to the data. The rows are the different values.
        Otherwise a list of lists of values, where the outer list is
        each occurence of the element, and the inner lists have one value
        per property.
    """
    if definition.is_constant_type_fixed_size():
        # Fast path: all scalar properties of one type -> single loadtxt.
        return _read_ply_fixed_size_element_ascii(f, definition)
    if definition.try_constant_list():
        # Fast path: a single list property with (hopefully) constant length.
        data = _try_read_ply_constant_list_ascii(f, definition)
        if data is not None:
            return data
    # We failed to read the element as a lump, must process each line manually.
    data = []
    for _i in range(definition.count):
        line_string = f.readline()
        if line_string == "":
            raise ValueError("Not enough data for %s." % definition.name)
        datum = []
        line_iter = iter(line_string.strip().split())
        for property in definition.properties:
            _parse_heterogenous_property_ascii(datum, line_iter, property)
        data.append(datum)
        # All words on the line must have been consumed by the properties.
        if next(line_iter, None) is not None:
            raise ValueError("Too much data for an element.")
    return data
def _read_ply_fixed_size_element_binary(
    f, definition: _PlyElementType, big_endian: bool
):
    """
    Read the data for an element whose properties are all scalars of one
    type, from the binary body of a ply file.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        big_endian: (bool) whether the document is encoded as big endian.

    Returns:
        2D numpy array: one row per occurrence of the element, one column
        per property.

    Raises:
        ValueError: if the file ends before all the data is read.
    """
    ply_type = _PLY_TYPES[definition.properties[0].data_type]
    n_columns = len(definition.properties)
    needed_bytes = definition.count * n_columns * ply_type.size
    raw = f.read(needed_bytes)
    if len(raw) != needed_bytes:
        raise ValueError("Not enough data for %s." % definition.name)
    data = np.frombuffer(raw, dtype=ply_type.np_type)
    # Swap byte order if the file's endianness differs from the host's.
    if (sys.byteorder == "big") != big_endian:
        data = data.byteswap()
    return data.reshape(definition.count, n_columns)
def _read_ply_element_struct(f, definition: _PlyElementType, endian_str: str):
    """
    Read the data for an element which has no list properties (but may mix
    scalar types), using the struct library.

    Note: It looks like struct would also support lists where
    type=size_type=char, but it is hard to know how much data to read in
    that case.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        endian_str: ">" or "<" according to whether the document is big or
            little endian.

    Returns:
        List of tuples, one per occurrence of the element, with one value
        per property.

    Raises:
        ValueError: if the file ends before all the data is read.
    """
    chars = [
        _PLY_TYPES[prop.data_type].struct_char
        for prop in definition.properties
    ]
    pattern = struct.Struct(endian_str + "".join(chars))
    row_size = pattern.size
    needed_bytes = row_size * definition.count
    raw = f.read(needed_bytes)
    if len(raw) != needed_bytes:
        raise ValueError("Not enough data for %s." % definition.name)
    return [
        pattern.unpack_from(raw, i * row_size)
        for i in range(definition.count)
    ]
def _try_read_ply_constant_list_binary(
    f, definition: _PlyElementType, big_endian: bool
):
    """
    If definition is an element which is a single list, attempt to read the
    corresponding data assuming every value has the same length.
    If the data is ragged, return None and leave f undisturbed.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        big_endian: (bool) whether the document is encoded as big endian.

    Returns:
        If every element has the same size, 2D numpy array corresponding to
        the data. The rows are the different values. Otherwise None.
    """
    property = definition.properties[0]
    endian_str = ">" if big_endian else "<"
    length_format = endian_str + _PLY_TYPES[property.list_size_type].struct_char
    length_struct = struct.Struct(length_format)

    def get_length():
        # Read and decode one list-length value from the stream.
        bytes_data = f.read(length_struct.size)
        if len(bytes_data) != length_struct.size:
            raise ValueError("Not enough data for %s." % definition.name)
        [length] = length_struct.unpack(bytes_data)
        return length

    # Remember the position so we can rewind if a row has a different
    # length from the first one.
    start_point = f.tell()
    length = get_length()
    # NOTE(review): if definition.count == 0 the length read above has
    # already consumed bytes that belong to the next element — confirm
    # whether zero-count elements can reach this function.
    np_type = _PLY_TYPES[definition.properties[0].data_type].np_type
    type_size = _PLY_TYPES[definition.properties[0].data_type].size
    data_size = type_size * length
    output = np.zeros((definition.count, length), dtype=np_type)
    for i in range(definition.count):
        bytes_data = f.read(data_size)
        if len(bytes_data) != data_size:
            raise ValueError("Not enough data for %s" % definition.name)
        output[i] = np.frombuffer(bytes_data, dtype=np_type)
        if i + 1 == definition.count:
            # No further length value follows the final row.
            break
        if length != get_length():
            # Ragged data: rewind and let the caller parse row by row.
            f.seek(start_point)
            return None
    # Swap byte order if the file's endianness differs from the host's.
    if (sys.byteorder == "big") != big_endian:
        output = output.byteswap()
    return output
def _read_ply_element_binary(
    f, definition: _PlyElementType, big_endian: bool
) -> list:
    """
    Decode all instances of a single element from a binary .ply file.

    Args:
        f: file-like object being read.
        definition: The element object which describes what we are reading.
        big_endian: (bool) whether the document is encoded as big endian.

    Returns:
        In simple cases where every element has the same size, 2D numpy
        array corresponding to the data. The rows are the different values.
        Otherwise a list of lists/tuples of values, where the outer list is
        each occurence of the element, and the inner lists have one value
        per property.
    """
    endian_str = ">" if big_endian else "<"
    # Fast paths, in decreasing order of uniformity.
    if definition.is_constant_type_fixed_size():
        return _read_ply_fixed_size_element_binary(f, definition, big_endian)
    if definition.is_fixed_size():
        return _read_ply_element_struct(f, definition, endian_str)
    if definition.try_constant_list():
        data = _try_read_ply_constant_list_binary(f, definition, big_endian)
        if data is not None:
            return data
    # We failed to read the element as a lump, must process each line manually.
    property_structs = []
    for property in definition.properties:
        # For a list property, the first value on disk is the list length.
        initial_type = property.list_size_type or property.data_type
        property_structs.append(
            struct.Struct(endian_str + _PLY_TYPES[initial_type].struct_char)
        )
    data = []
    for _i in range(definition.count):
        datum = []
        for property, property_struct in zip(
            definition.properties, property_structs
        ):
            size = property_struct.size
            initial_data = f.read(size)
            if len(initial_data) != size:
                raise ValueError("Not enough data for %s" % definition.name)
            [initial] = property_struct.unpack(initial_data)
            if property.list_size_type is None:
                # Scalar property: the value itself.
                datum.append(initial)
            else:
                # List property: read `initial` further values.
                type_size = _PLY_TYPES[property.data_type].size
                needed_bytes = type_size * initial
                list_data = f.read(needed_bytes)
                if len(list_data) != needed_bytes:
                    raise ValueError("Not enough data for %s" % definition.name)
                np_type = _PLY_TYPES[property.data_type].np_type
                list_np = np.frombuffer(list_data, dtype=np_type)
                if (sys.byteorder == "big") != big_endian:
                    list_np = list_np.byteswap()
                datum.append(list_np)
        data.append(datum)
    return data
def _load_ply_raw_stream(f) -> Tuple[_PlyHeader, dict]:
    """
    Implementation for _load_ply_raw which takes a stream.

    Args:
        f: A binary or text file-like object.

    Returns:
        header: A _PlyHeader object describing the metadata in the ply file.
        elements: A dictionary of element names to values. Regular elements
            (no lists, or one uniformly-sized list) map to a 2D numpy
            array; otherwise to a list of the property values.

    Raises:
        ValueError: if there is data left over after all elements are read.
    """
    header = _PlyHeader(f)
    elements = {}
    for element in header.elements:
        if header.ascii:
            elements[element.name] = _read_ply_element_ascii(f, element)
        else:
            elements[element.name] = _read_ply_element_binary(
                f, element, header.big_endian
            )
    # Anything after the declared elements means a malformed file.
    leftover = f.read().strip()
    if len(leftover) != 0:
        raise ValueError("Extra data at end of file: " + str(leftover[:20]))
    return header, elements
def _load_ply_raw(f) -> Tuple[_PlyHeader, dict]:
    """
    Load the data from a .ply file.

    Args:
        f: A binary or text file-like object (with methods read, readline,
            tell and seek), a pathlib path or a string containing a file
            name. If the ply file is binary, a text stream is not
            supported. It is recommended to use a binary stream.

    Returns:
        header: A _PlyHeader object describing the metadata in the ply file.
        elements: A dictionary of element names to values. Regular elements
            (no lists, or one uniformly-sized list) map to a 2D numpy
            array; otherwise to a list of the property values.
    """
    # Open the file ourselves when given a path; remember to close it.
    opened_here = False
    if isinstance(f, str):
        f = open(f, "rb")
        opened_here = True
    elif isinstance(f, pathlib.Path):
        f = f.open("rb")
        opened_here = True
    try:
        header, elements = _load_ply_raw_stream(f)
    finally:
        if opened_here:
            f.close()
    return header, elements
def load_ply(f):
    """
    Load the data from a .ply file.

    Example .ply file format:

    ply
    format ascii 1.0           { ascii/binary, format version number }
    comment made by Greg Turk  { comments keyword specified, like all lines }
    comment this file is a cube
    element vertex 8           { define "vertex" element, 8 of them in file }
    property float x           { vertex contains float "x" coordinate }
    property float y           { y coordinate is also a vertex property }
    property float z           { z coordinate, too }
    element face 6             { there are 6 "face" elements in the file }
    property list uchar int vertex_index { "vertex_indices" is a list of ints }
    end_header                 { delimits the end of the header }
    0 0 0                      { start of vertex list }
    0 0 1
    0 1 1
    0 1 0
    1 0 0
    1 0 1
    1 1 1
    1 1 0
    4 0 1 2 3                  { start of face list }
    4 7 6 5 4
    4 0 4 5 1
    4 1 5 6 2
    4 2 6 7 3
    4 3 7 4 0

    Args:
        f: A binary or text file-like object (with methods read, readline,
            tell and seek), a pathlib path or a string containing a file
            name. If the ply file is in the binary ply format rather than
            the text ply format, then a text stream is not supported.
            It is easiest to use a binary stream in all cases.

    Returns:
        verts: FloatTensor of shape (V, 3).
        faces: LongTensor of vertex indices, shape (F, 3). Faces with more
            than 3 vertices are fan-triangulated.

    Raises:
        ValueError: if the file lacks vertex/face elements or the data is
            malformed.
    """
    header, elements = _load_ply_raw(f)
    vertex = elements.get("vertex", None)
    if vertex is None:
        raise ValueError("The ply file has no vertex element.")
    face = elements.get("face", None)
    if face is None:
        raise ValueError("The ply file has no face element.")
    # Vertices must have parsed into a regular (V, 3) array.
    if (
        not isinstance(vertex, np.ndarray)
        or vertex.ndim != 2
        or vertex.shape[1] != 3
    ):
        raise ValueError("Invalid vertices in file.")
    verts = torch.tensor(vertex, dtype=torch.float32)
    # The face element must consist of exactly one list property.
    face_head = next(head for head in header.elements if head.name == "face")
    if (
        len(face_head.properties) != 1
        or face_head.properties[0].list_size_type is None
    ):
        raise ValueError("Unexpected form of faces data.")
    # face_head.properties[0].name is usually "vertex_index" or "vertex_indices"
    # but we don't need to enforce this.
    if isinstance(face, np.ndarray) and face.ndim == 2:
        # Fast path: all faces have the same number of vertices;
        # fan-triangulate polygons with more than 3 vertices.
        if face.shape[1] < 3:
            raise ValueError("Faces must have at least 3 vertices.")
        face_arrays = [
            face[:, [0, i + 1, i + 2]] for i in range(face.shape[1] - 2)
        ]
        faces = torch.tensor(np.vstack(face_arrays), dtype=torch.int64)
    else:
        # Ragged faces: triangulate each polygon individually.
        face_list = []
        for face_item in face:
            if face_item.ndim != 1:
                raise ValueError("Bad face data.")
            if face_item.shape[0] < 3:
                raise ValueError("Faces must have at least 3 vertices.")
            for i in range(face_item.shape[0] - 2):
                face_list.append(
                    [face_item[0], face_item[i + 1], face_item[i + 2]]
                )
        faces = torch.tensor(face_list, dtype=torch.int64)
    return verts, faces
def _save_ply(f, verts, faces, decimal_places: Optional[int]):
"""
Internal implementation for saving a mesh to a .ply file.
Args:
f: File object to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
"""
print("ply\nformat ascii 1.0", file=f)
print(f"element vertex {verts.shape[0]}", file=f)
print("property float x", file=f)
print("property float y", file=f)
print("property float z", file=f)
print(f"element face {faces.shape[0]}", file=f)
print("property list uchar int vertex_index", file=f)
print("end_header", file=f)
if decimal_places is None:
float_str = "%f"
else:
float_str = "%" + ".%df" % decimal_places
np.savetxt(f, verts.detach().numpy(), float_str)
np.savetxt(f, faces.detach().numpy(), "3 %d %d %d")
def save_ply(f, verts, faces, decimal_places: Optional[int] = None):
    """
    Save a mesh to a .ply file.

    Args:
        f: File (or path) to which the mesh should be written.
        verts: FloatTensor of shape (V, 3) giving vertex coordinates.
        faces: LongTensor of shape (F, 3) giving faces.
        decimal_places: Number of decimal places for saving.
    """
    # Open the file ourselves when given a path; remember to close it.
    opened_here = False
    if isinstance(f, str):
        f = open(f, "w")
        opened_here = True
    elif isinstance(f, pathlib.Path):
        f = f.open("w")
        opened_here = True
    try:
        _save_ply(f, verts, faces, decimal_places)
    finally:
        if opened_here:
            f.close()
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .chamfer import chamfer_distance
from .mesh_edge_loss import mesh_edge_loss
from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
from .mesh_normal_consistency import mesh_normal_consistency

# Export every public (non-underscore) name imported above.
__all__ = [k for k in globals().keys() if not k.startswith("_")]
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
def _validate_chamfer_reduction_inputs(
batch_reduction: str, point_reduction: str
):
"""Check the requested reductions are valid.
Args:
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["none", "mean", "sum"].
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["none", "mean", "sum"].
"""
if batch_reduction not in ["none", "mean", "sum"]:
raise ValueError(
'batch_reduction must be one of ["none", "mean", "sum"]'
)
if point_reduction not in ["none", "mean", "sum"]:
raise ValueError(
'point_reduction must be one of ["none", "mean", "sum"]'
)
if batch_reduction == "none" and point_reduction == "none":
raise ValueError(
'batch_reduction and point_reduction cannot both be "none".'
)
def chamfer_distance(
    x,
    y,
    x_normals=None,
    y_normals=None,
    weights=None,
    batch_reduction: str = "mean",
    point_reduction: str = "mean",
):
    """
    Chamfer distance between two pointclouds x and y.

    Args:
        x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
            with P1 points in each batch element, batch size N and feature
            dimension D.
        y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
            with P2 points in each batch element, batch size N and feature
            dimension D.
        x_normals: Optional FloatTensor of shape (N, P1, D).
        y_normals: Optional FloatTensor of shape (N, P2, D).
        weights: Optional FloatTensor of shape (N,) giving weights for
            batch elements for reduction operation.
        batch_reduction: Reduction operation to apply for the loss across the
            batch, can be one of ["none", "mean", "sum"].
        point_reduction: Reduction operation to apply for the loss across the
            points, can be one of ["none", "mean", "sum"].

    Returns:
        2-element tuple containing

        - **loss**: Tensor giving the reduced distance between the pointclouds
          in x and the pointclouds in y.
        - **loss_normals**: Tensor giving the reduced cosine distance of normals
          between pointclouds in x and pointclouds in y. Returns None if
          x_normals and y_normals are None.

    Raises:
        ValueError: if the reduction modes are invalid, if y's shape is
            inconsistent with x's, or if weights has the wrong shape or
            contains negative entries.
    """
    _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

    N, P1, D = x.shape
    P2 = y.shape[1]
    if y.shape[0] != N or y.shape[2] != D:
        raise ValueError("y does not have the correct shape.")
    if weights is not None:
        if weights.size(0) != N:
            raise ValueError("weights must be of shape (N,).")
        if not (weights >= 0).all():
            # The check rejects *negative* weights; the previous message
            # ("weights can not be nonnegative.") stated the opposite.
            raise ValueError("weights cannot be negative.")
        if weights.sum() == 0.0:
            # All weights are zero: the loss is identically zero. Multiply
            # by the inputs (times 0.0) instead of returning a constant so
            # gradients still flow back to x and weights.
            weights = weights.view(N, 1)
            if batch_reduction in ["mean", "sum"]:
                return (
                    (x.sum((1, 2)) * weights).sum() * 0.0,
                    (x.sum((1, 2)) * weights).sum() * 0.0,
                )
            return (
                (x.sum((1, 2)) * weights) * 0.0,
                (x.sum((1, 2)) * weights) * 0.0,
            )

    return_normals = x_normals is not None and y_normals is not None
    cham_norm_x = x.new_zeros(())
    cham_norm_y = x.new_zeros(())

    # Nearest neighbor of each x point in y (and vice versa), with the
    # corresponding gathered normals when normals are provided.
    x_near, xidx_near, x_normals_near = nn_points_idx(x, y, y_normals)
    y_near, yidx_near, y_normals_near = nn_points_idx(y, x, x_normals)

    # Squared L2 distance from each point to its nearest neighbor.
    cham_x = (x - x_near).norm(dim=2, p=2) ** 2.0  # (N, P1)
    cham_y = (y - y_near).norm(dim=2, p=2) ** 2.0  # (N, P2)

    if weights is not None:
        cham_x *= weights.view(N, 1)
        cham_y *= weights.view(N, 1)

    if return_normals:
        # Cosine distance between matched normals; abs() makes the term
        # invariant to normal orientation (flipped normals score the same).
        cham_norm_x = 1 - torch.abs(
            F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)
        )
        cham_norm_y = 1 - torch.abs(
            F.cosine_similarity(y_normals, y_normals_near, dim=2, eps=1e-6)
        )
        if weights is not None:
            cham_norm_x *= weights.view(N, 1)
            cham_norm_y *= weights.view(N, 1)

    if point_reduction != "none":
        # If not 'none' then either 'sum' or 'mean'.
        cham_x = cham_x.sum(1)  # (N,)
        cham_y = cham_y.sum(1)  # (N,)
        if return_normals:
            cham_norm_x = cham_norm_x.sum(1)  # (N,)
            cham_norm_y = cham_norm_y.sum(1)  # (N,)
        if point_reduction == "mean":
            cham_x /= P1
            cham_y /= P2
            if return_normals:
                cham_norm_x /= P1
                cham_norm_y /= P2

    if batch_reduction != "none":
        cham_x = cham_x.sum()
        cham_y = cham_y.sum()
        if return_normals:
            cham_norm_x = cham_norm_x.sum()
            cham_norm_y = cham_norm_y.sum()
        if batch_reduction == "mean":
            # Weighted mean: normalize by total weight when weights are given.
            div = weights.sum() if weights is not None else N
            cham_x /= div
            cham_y /= div
            if return_normals:
                cham_norm_x /= div
                cham_norm_y /= div

    cham_dist = cham_x + cham_y
    cham_normals = cham_norm_x + cham_norm_y if return_normals else None
    return cham_dist, cham_normals
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
def mesh_edge_loss(meshes, target_length: float = 0.0):
    """
    Computes mesh edge length regularization loss averaged across all meshes
    in a batch. Each edge contributes equally to the final loss regardless of
    how many edges its mesh has: every edge's loss is scaled by the inverse
    edge count of its mesh. For example, if mesh 3 (out of N) has only E=4
    edges, each of its edges is weighted by 1/4.

    Args:
        meshes: Meshes object with a batch of meshes.
        target_length: Resting value for the edge length.

    Returns:
        loss: Average loss across the batch. Returns 0 if meshes contains
        no meshes or all empty meshes.
    """
    if meshes.isempty():
        # No geometry at all: return a differentiable zero tensor.
        return torch.tensor(
            [0.0], dtype=torch.float32, device=meshes.device, requires_grad=True
        )

    num_meshes = len(meshes)
    edges_packed = meshes.edges_packed()  # (sum(E_n), 2)
    verts_packed = meshes.verts_packed()  # (sum(V_n), 3)
    edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()  # (sum(E_n), )
    num_edges_per_mesh = meshes.num_edges_per_mesh()  # N

    # Per-edge weight = 1 / (number of edges in the edge's mesh), so every
    # mesh contributes equally regardless of its tessellation density.
    # TODO (nikhilar) Find a faster way of computing the weights for each edge
    # as this is currently a bottleneck for meshes with a large number of faces.
    per_edge_weight = (
        1.0 / num_edges_per_mesh.gather(0, edge_to_mesh_idx).float()
    )

    # Gather both endpoints of every edge, then penalize the squared
    # deviation of each edge length from the target length.
    endpoints = verts_packed[edges_packed]  # (sum(E_n), 2, 3)
    edge_lengths = (endpoints[:, 0] - endpoints[:, 1]).norm(dim=1, p=2)
    per_edge_loss = (edge_lengths - target_length) ** 2.0

    return (per_edge_loss * per_edge_weight).sum() / num_meshes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment