Commit 15c72be4 authored by Nikhila Ravi, committed by Facebook Github Bot

Fix coordinate system conventions in renderer

Summary:
## Updates

- Defined the world and camera coordinates according to this figure. The world coordinates are defined as having +Y up, +X left and +Z in.

{F230888499}

- Removed all flipping from blending functions.
- Updated the rasterizer to return images with +Y up and +X left.
- Updated all the mesh rasterizer tests
    - The expected values are now defined in terms of the default +Y up, +X left
    - Added tests where the triangles in the meshes are non-symmetrical so that it is clear which directions +X and +Y point
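
As a rough illustration of the new image convention (a sketch, not the rasterizer's actual code path): `pix_to_ndc` mirrors the helper added to the Python rasterizer in this commit, while `flipped_pix_to_ndc` is a hypothetical name for the index reversal applied in the kernels. With both indices reversed, row 0 of the output maps to the largest NDC y and column 0 to the largest NDC x, which is what "+Y up, +X left" means for the returned image.

```python
def pix_to_ndc(i, S):
    # NDC offset + (i * pixel_width + half_pixel_width), as in this commit.
    return -1 + (2 * i + 1.0) / S


def flipped_pix_to_ndc(yi, xi, H, W):
    # Hypothetical helper: reverse both pixel indices before converting,
    # so that +Y points up and +X points left in the output image.
    yf = pix_to_ndc(H - 1 - yi, H)
    xf = pix_to_ndc(W - 1 - xi, W)
    return yf, xf


H = W = 4
top_left = flipped_pix_to_ndc(0, 0, H, W)              # first row, first column
bottom_right = flipped_pix_to_ndc(H - 1, W - 1, H, W)  # last row, last column
```

Here the top-left pixel has the largest (y, x) in NDC and the bottom-right pixel the smallest.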

## Questions:
- Should we have **scene settings** instead of raster settings?
    - To be more correct, we should be [z clipping in the rasterizer based on the far/near clipping planes](https://github.com/ShichenLiu/SoftRas/blob/master/soft_renderer/cuda/soft_rasterize_cuda_kernel.cu#L400). These values are also required in the blending functions, so should we make them scene-level parameters and have a scene settings tuple that is available to both the rasterizer and the shader?

Reviewed By: gkioxari

Differential Revision: D20208604

fbshipit-source-id: 55787301b1bffa0afa9618f0a0886cc681da51f3
parent 767d68a3
...@@ -34,19 +34,22 @@ The differentiable renderer API is experimental and subject to change!. ...@@ -34,19 +34,22 @@ The differentiable renderer API is experimental and subject to change!.
### Coordinate transformation conventions ### Coordinate transformation conventions
Rendering requires transformations between several different coordinate frames: world space, view/camera space, NDC space and screen space. At each step it is important to know where the camera is located, how the x,y,z axes are aligned and the possible range of values. The following figure outlines the conventions used in PyTorch3d. Rendering requires transformations between several different coordinate frames: world space, view/camera space, NDC space and screen space. At each step it is important to know where the camera is located, how the +X, +Y, +Z axes are aligned and the possible range of values. The following figure outlines the conventions used in PyTorch3d.
<img src="assets/transformations_overview.png" width="1000"> <img src="assets/transformations_overview.png" width="1000">
For example, given a teapot mesh, the world coordinate frame, camera coordinate frame and image are shown in the figure below. Note that the world and camera coordinate frames have the +z direction pointing in to the page.
<img src="assets/world_camera_image.png" width="1000">
--- ---
**NOTE: PyTorch3d vs OpenGL** **NOTE: PyTorch3d vs OpenGL**
While we tried to emulate several aspects of OpenGL, the NDC coordinate system in PyTorch3d is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness). While we tried to emulate several aspects of OpenGL, there are differences in the coordinate frame conventions.
- The default world coordinate frame in PyTorch3D has +Z pointing in to the screen whereas in OpenGL, +Z is pointing out of the screen. Both are right handed.
In OpenGL, the camera at the origin is looking along `-z` axis in camera space, but it is looking along the `+z` axis in NDC space. - The NDC coordinate system in PyTorch3d is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness).
<img align="center" src="assets/opengl_coordframes.png" width="300"> <img align="center" src="assets/opengl_coordframes.png" width="300">
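
One way to sanity-check the handedness claims above is a cross product: a frame is right-handed when x × y = z. The sketch below writes each frame's axes in a viewer-facing basis (right, up, towards the viewer). That basis and all names here are assumptions made only for this illustration.

```python
def cross(a, b):
    # Standard 3D cross product on plain tuples.
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def is_right_handed(x_axis, y_axis, z_axis):
    # A frame is right-handed when x cross y equals z.
    return cross(x_axis, y_axis) == z_axis


# Axes expressed as (right, up, towards the viewer).
pytorch3d_world = ((-1, 0, 0), (0, 1, 0), (0, 0, -1))  # +X left, +Y up, +Z in
opengl_world = ((1, 0, 0), (0, 1, 0), (0, 0, 1))       # +X right, +Y up, +Z out
opengl_ndc = ((1, 0, 0), (0, 1, 0), (0, 0, -1))        # +X right, +Y up, +Z in

pytorch3d_world_rh = is_right_handed(*pytorch3d_world)
opengl_world_rh = is_right_handed(*opengl_world)
opengl_ndc_rh = is_right_handed(*opengl_ndc)
```

Both world frames come out right-handed, while the OpenGL NDC frame does not, which is the handedness switch the note refers to.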
...@@ -60,7 +63,7 @@ A renderer in PyTorch3d is composed of a **rasterizer** and a **shader**. Create ...@@ -60,7 +63,7 @@ A renderer in PyTorch3d is composed of a **rasterizer** and a **shader**. Create
from pytorch3d.renderer import ( from pytorch3d.renderer import (
OpenGLPerspectiveCameras, look_at_view_transform, OpenGLPerspectiveCameras, look_at_view_transform,
RasterizationSettings, BlendParams, RasterizationSettings, BlendParams,
MeshRenderer, MeshRasterizer, PhongShader MeshRenderer, MeshRasterizer, HardPhongShader
) )
# Initialize an OpenGL perspective camera. # Initialize an OpenGL perspective camera.
...@@ -81,7 +84,7 @@ raster_settings = RasterizationSettings( ...@@ -81,7 +84,7 @@ raster_settings = RasterizationSettings(
# PhongShader, passing in the device on which to initialize the default parameters # PhongShader, passing in the device on which to initialize the default parameters
renderer = MeshRenderer( renderer = MeshRenderer(
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
shader=PhongShader(device=device, cameras=cameras) shader=HardPhongShader(device=device, cameras=cameras)
) )
``` ```
......
...@@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel( ...@@ -189,12 +189,12 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const float* face_verts, const float* face_verts,
const int64_t* mesh_to_face_first_idx, const int64_t* mesh_to_face_first_idx,
const int64_t* num_faces_per_mesh, const int64_t* num_faces_per_mesh,
float blur_radius, const float blur_radius,
bool perspective_correct, const bool perspective_correct,
int N, const int N,
int H, const int H,
int W, const int W,
int K, const int K,
int64_t* face_idxs, int64_t* face_idxs,
float* zbuf, float* zbuf,
float* pix_dists, float* pix_dists,
...@@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel( ...@@ -207,8 +207,10 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// Convert linear index to 3D index // Convert linear index to 3D index
const int n = i / (H * W); // batch index. const int n = i / (H * W); // batch index.
const int pix_idx = i % (H * W); const int pix_idx = i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W; // Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
// screen coordinates to ndc coordiantes of pixel. // screen coordinates to ndc coordiantes of pixel.
const float xf = PixToNdc(xi, W); const float xf = PixToNdc(xi, W);
...@@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel( ...@@ -254,7 +256,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
// TODO: make sorting an option as only top k is needed, not sorted values. // TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size); BubbleSort(q, q_size);
int idx = n * H * W * K + yi * H * K + xi * K; int idx = n * H * W * K + pix_idx * K;
for (int k = 0; k < q_size; ++k) { for (int k = 0; k < q_size; ++k) {
face_idxs[idx + k] = q[k].idx; face_idxs[idx + k] = q[k].idx;
zbuf[idx + k] = q[k].z; zbuf[idx + k] = q[k].z;
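
The indexing change above can be checked with a small Python sketch (helper names are illustrative, not part of the codebase). For an output tensor of shape (N, H, W, K), the row-major offset of the pixel at (row, col) is `n * H * W * K + (row * W + col) * K`, and `pix_idx` is already `row * W + col`. The previous form, `yi * H * K + xi * K`, uses `H` as the row stride and so agrees only when `H == W`.

```python
def flat_index(n, row, col, H, W, K):
    # Row-major offset of element (n, row, col, 0) in an (N, H, W, K) tensor.
    return ((n * H + row) * W + col) * K


def new_kernel_index(n, pix_idx, H, W, K):
    # Form used after this commit: pix_idx is the row-major pixel position.
    return n * H * W * K + pix_idx * K


def old_kernel_index(n, yi, xi, H, W, K):
    # Form used before this commit, with H as the row stride.
    return n * H * W * K + yi * H * K + xi * K


N, H, W, K = 2, 3, 5, 2  # deliberately non-square to expose the stride issue
new_matches = all(
    new_kernel_index(n, row * W + col, H, W, K) == flat_index(n, row, col, H, W, K)
    for n in range(N)
    for row in range(H)
    for col in range(W)
)
old_matches = all(
    old_kernel_index(n, row, col, H, W, K) == flat_index(n, row, col, H, W, K)
    for n in range(N)
    for row in range(H)
    for col in range(W)
)
```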
...@@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda( ...@@ -274,7 +276,7 @@ RasterizeMeshesNaiveCuda(
const int image_size, const int image_size,
const float blur_radius, const float blur_radius,
const int num_closest, const int num_closest,
bool perspective_correct) { const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) { face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
...@@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda( ...@@ -331,12 +333,12 @@ RasterizeMeshesNaiveCuda(
__global__ void RasterizeMeshesBackwardCudaKernel( __global__ void RasterizeMeshesBackwardCudaKernel(
const float* face_verts, // (F, 3, 3) const float* face_verts, // (F, 3, 3)
const int64_t* pix_to_face, // (N, H, W, K) const int64_t* pix_to_face, // (N, H, W, K)
bool perspective_correct, const bool perspective_correct,
int N, const int N,
int F, const int F,
int H, const int H,
int W, const int W,
int K, const int K,
const float* grad_zbuf, // (N, H, W, K) const float* grad_zbuf, // (N, H, W, K)
const float* grad_bary, // (N, H, W, K, 3) const float* grad_bary, // (N, H, W, K, 3)
const float* grad_dists, // (N, H, W, K) const float* grad_dists, // (N, H, W, K)
...@@ -351,8 +353,11 @@ __global__ void RasterizeMeshesBackwardCudaKernel( ...@@ -351,8 +353,11 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Convert linear index to 3D index // Convert linear index to 3D index
const int n = t_i / (H * W); // batch index. const int n = t_i / (H * W); // batch index.
const int pix_idx = t_i % (H * W); const int pix_idx = t_i % (H * W);
const int yi = pix_idx / H;
const int xi = pix_idx % W; // Determine ordering based on axis convention.
const int yi = H - 1 - pix_idx / W;
const int xi = W - 1 - pix_idx % W;
const float xf = PixToNdc(xi, W); const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H); const float yf = PixToNdc(yi, H);
const float2 pxy = make_float2(xf, yf); const float2 pxy = make_float2(xf, yf);
...@@ -360,8 +365,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel( ...@@ -360,8 +365,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Loop over all the faces for this pixel. // Loop over all the faces for this pixel.
for (int k = 0; k < K; k++) { for (int k = 0; k < K; k++) {
// Index into (N, H, W, K, :) grad tensors // Index into (N, H, W, K, :) grad tensors
const int i = // pixel index + top k index
n * H * W * K + yi * H * K + xi * K + k; // pixel index + face index int i = n * H * W * K + pix_idx * K + k;
const int f = pix_to_face[i]; const int f = pix_to_face[i];
if (f < 0) { if (f < 0) {
...@@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda( ...@@ -451,7 +456,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_zbuf, // (N, H, W, K) const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3) const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K) const torch::Tensor& grad_dists, // (N, H, W, K)
bool perspective_correct) { const bool perspective_correct) {
const int F = face_verts.size(0); const int F = face_verts.size(0);
const int N = pix_to_face.size(0); const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1); const int H = pix_to_face.size(1);
...@@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel( ...@@ -509,6 +514,7 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Have each block handle a chunk of faces // Have each block handle a chunk of faces
const int chunks_per_batch = 1 + (F - 1) / chunk_size; const int chunks_per_batch = 1 + (F - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch; const int num_chunks = N * chunks_per_batch;
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) { for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch; const int chunk_idx = chunk % chunks_per_batch;
...@@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel( ...@@ -551,17 +557,21 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// Y coordinate of the top and bottom of the bin. // Y coordinate of the top and bottom of the bin.
// PixToNdc gives the location of the center of each pixel, so we // PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin. // need to add/subtract a half pixel to get the true extent of the bin.
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix; // Reverse ordering of Y axis so that +Y is upwards in the image.
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix; const int yidx = num_bins - by;
float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax); const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
for (int bx = 0; bx < num_bins; ++bx) { for (int bx = 0; bx < num_bins; ++bx) {
// X coordinate of the left and right of the bin. // X coordinate of the left and right of the bin.
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix; // Reverse ordering of x axis so that +X is left.
const float bin_x_max = const int xidx = num_bins - bx;
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix; float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax); float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) { if (y_overlap && x_overlap) {
binmask.set(by, bx, f); binmask.set(by, bx, f);
} }
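
To see what the reversed bin ordering does, the following Python sketch recomputes the Y extents of each bin using the same expressions as the new kernel code. It assumes `half_pix` is half a pixel in NDC units (`1 / H`), which is not shown in this hunk, and `bin_y_extent` is a hypothetical helper name.

```python
def pix_to_ndc(i, S):
    # Center of pixel i in NDC for an image with S pixels along this axis.
    return -1 + (2 * i + 1.0) / S


def bin_y_extent(by, num_bins, bin_size, H):
    # Assumption: half a pixel in NDC units, since NDC spans 2 units over H pixels.
    half_pix = 1.0 / H
    # Reverse the bin order so that bin 0 is the top of the image (+Y up).
    yidx = num_bins - by
    bin_y_max = pix_to_ndc(yidx * bin_size - 1, H) + half_pix
    bin_y_min = pix_to_ndc((yidx - 1) * bin_size, H) - half_pix
    return bin_y_min, bin_y_max


H, bin_size = 8, 4
num_bins = H // bin_size
extents = [bin_y_extent(by, num_bins, bin_size, H) for by in range(num_bins)]
```

With these values, bin 0 covers the upper half of NDC space and bin 1 the lower half.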
...@@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda( ...@@ -654,7 +664,6 @@ torch::Tensor RasterizeMeshesCoarseCuda(
// **************************************************************************** // ****************************************************************************
// * FINE RASTERIZATION * // * FINE RASTERIZATION *
// **************************************************************************** // ****************************************************************************
__global__ void RasterizeMeshesFineCudaKernel( __global__ void RasterizeMeshesFineCudaKernel(
const float* face_verts, // (F, 3, 3) const float* face_verts, // (F, 3, 3)
const int32_t* bin_faces, // (N, B, B, T) const int32_t* bin_faces, // (N, B, B, T)
...@@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel( ...@@ -695,8 +704,14 @@ __global__ void RasterizeMeshesFineCudaKernel(
if (yi >= H || xi >= W) if (yi >= H || xi >= W)
continue; continue;
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H); // Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const float xf = PixToNdc(xidx, W);
const float yf = PixToNdc(yidx, H);
const float2 pxy = make_float2(xf, yf); const float2 pxy = make_float2(xf, yf);
// This part looks like the naive rasterization kernel, except we use // This part looks like the naive rasterization kernel, except we use
...@@ -751,7 +766,7 @@ RasterizeMeshesFineCuda( ...@@ -751,7 +766,7 @@ RasterizeMeshesFineCuda(
const float blur_radius, const float blur_radius,
const int bin_size, const int bin_size,
const int faces_per_pixel, const int faces_per_pixel,
bool perspective_correct) { const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) { face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
......
...@@ -14,10 +14,10 @@ RasterizeMeshesNaiveCpu( ...@@ -14,10 +14,10 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int faces_per_pixel, const int faces_per_pixel,
bool perspective_correct); const bool perspective_correct);
#ifdef WITH_CUDA #ifdef WITH_CUDA
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
...@@ -25,10 +25,10 @@ RasterizeMeshesNaiveCuda( ...@@ -25,10 +25,10 @@ RasterizeMeshesNaiveCuda(
const at::Tensor& face_verts, const at::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx, const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh, const at::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int num_closest, const int num_closest,
bool perspective_correct); const bool perspective_correct);
#endif #endif
// Forward pass for rasterizing a batch of meshes. // Forward pass for rasterizing a batch of meshes.
// //
...@@ -77,10 +77,10 @@ RasterizeMeshesNaive( ...@@ -77,10 +77,10 @@ RasterizeMeshesNaive(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int faces_per_pixel, const int faces_per_pixel,
bool perspective_correct) { const bool perspective_correct) {
// TODO: Better type checking. // TODO: Better type checking.
if (face_verts.type().is_cuda()) { if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
...@@ -117,7 +117,7 @@ torch::Tensor RasterizeMeshesBackwardCpu( ...@@ -117,7 +117,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_bary, const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf, const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists, const torch::Tensor& grad_dists,
bool perspective_correct); const bool perspective_correct);
#ifdef WITH_CUDA #ifdef WITH_CUDA
torch::Tensor RasterizeMeshesBackwardCuda( torch::Tensor RasterizeMeshesBackwardCuda(
...@@ -126,7 +126,7 @@ torch::Tensor RasterizeMeshesBackwardCuda( ...@@ -126,7 +126,7 @@ torch::Tensor RasterizeMeshesBackwardCuda(
const torch::Tensor& grad_bary, const torch::Tensor& grad_bary,
const torch::Tensor& grad_zbuf, const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_dists, const torch::Tensor& grad_dists,
bool perspective_correct); const bool perspective_correct);
#endif #endif
// Args: // Args:
...@@ -159,7 +159,7 @@ torch::Tensor RasterizeMeshesBackward( ...@@ -159,7 +159,7 @@ torch::Tensor RasterizeMeshesBackward(
const torch::Tensor& grad_zbuf, const torch::Tensor& grad_zbuf,
const torch::Tensor& grad_bary, const torch::Tensor& grad_bary,
const torch::Tensor& grad_dists, const torch::Tensor& grad_dists,
bool perspective_correct) { const bool perspective_correct) {
if (face_verts.type().is_cuda()) { if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return RasterizeMeshesBackwardCuda( return RasterizeMeshesBackwardCuda(
...@@ -191,20 +191,20 @@ torch::Tensor RasterizeMeshesCoarseCpu( ...@@ -191,20 +191,20 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const at::Tensor& mesh_to_face_first_idx, const at::Tensor& mesh_to_face_first_idx,
const at::Tensor& num_faces_per_mesh, const at::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int bin_size, const int bin_size,
int max_faces_per_bin); const int max_faces_per_bin);
#ifdef WITH_CUDA #ifdef WITH_CUDA
torch::Tensor RasterizeMeshesCoarseCuda( torch::Tensor RasterizeMeshesCoarseCuda(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int bin_size, const int bin_size,
int max_faces_per_bin); const int max_faces_per_bin);
#endif #endif
// Args: // Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for // face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
...@@ -232,10 +232,10 @@ torch::Tensor RasterizeMeshesCoarse( ...@@ -232,10 +232,10 @@ torch::Tensor RasterizeMeshesCoarse(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int bin_size, const int bin_size,
int max_faces_per_bin) { const int max_faces_per_bin) {
if (face_verts.type().is_cuda()) { if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return RasterizeMeshesCoarseCuda( return RasterizeMeshesCoarseCuda(
...@@ -270,11 +270,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> ...@@ -270,11 +270,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFineCuda( RasterizeMeshesFineCuda(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& bin_faces, const torch::Tensor& bin_faces,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int bin_size, const int bin_size,
int faces_per_pixel, const int faces_per_pixel,
bool perspective_correct); const bool perspective_correct);
#endif #endif
// Args: // Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for // face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
...@@ -317,11 +317,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> ...@@ -317,11 +317,11 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>
RasterizeMeshesFine( RasterizeMeshesFine(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& bin_faces, const torch::Tensor& bin_faces,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int bin_size, const int bin_size,
int faces_per_pixel, const int faces_per_pixel,
bool perspective_correct) { const bool perspective_correct) {
if (face_verts.type().is_cuda()) { if (face_verts.type().is_cuda()) {
#ifdef WITH_CUDA #ifdef WITH_CUDA
return RasterizeMeshesFineCuda( return RasterizeMeshesFineCuda(
...@@ -373,6 +373,7 @@ RasterizeMeshesFine( ...@@ -373,6 +373,7 @@ RasterizeMeshesFine(
// this function instead returns screen-space // this function instead returns screen-space
// barycentric coordinates for each pixel. // barycentric coordinates for each pixel.
// //
//
// Returns: // Returns:
// A 4 element tuple of: // A 4 element tuple of:
// pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of // pix_to_face: int64 tensor of shape (N, H, W, K) giving the face index of
...@@ -394,12 +395,12 @@ RasterizeMeshes( ...@@ -394,12 +395,12 @@ RasterizeMeshes(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int faces_per_pixel, const int faces_per_pixel,
int bin_size, const int bin_size,
int max_faces_per_bin, const int max_faces_per_bin,
bool perspective_correct) { const bool perspective_correct) {
if (bin_size > 0 && max_faces_per_bin > 0) { if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization // Use coarse-to-fine rasterization
auto bin_faces = RasterizeMeshesCoarse( auto bin_faces = RasterizeMeshesCoarse(
......
...@@ -105,9 +105,9 @@ RasterizeMeshesNaiveCpu( ...@@ -105,9 +105,9 @@ RasterizeMeshesNaiveCpu(
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, int image_size,
float blur_radius, const float blur_radius,
int faces_per_pixel, const int faces_per_pixel,
bool perspective_correct) { const bool perspective_correct) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) { face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
...@@ -153,12 +153,19 @@ RasterizeMeshesNaiveCpu( ...@@ -153,12 +153,19 @@ RasterizeMeshesNaiveCpu(
// Iterate through the horizontal lines of the image from top to bottom. // Iterate through the horizontal lines of the image from top to bottom.
for (int yi = 0; yi < H; ++yi) { for (int yi = 0; yi < H; ++yi) {
// Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - yi;
// Y coordinate of the top of the pixel. // Y coordinate of the top of the pixel.
const float yf = PixToNdc(yi, H); const float yf = PixToNdc(yidx, H);
// Iterate through pixels on this horizontal line, left to right. // Iterate through pixels on this horizontal line, left to right.
for (int xi = 0; xi < W; ++xi) { for (int xi = 0; xi < W; ++xi) {
// Reverse the order of xi so that +X is pointing to the left in the
// image.
const int xidx = W - 1 - xi;
// X coordinate of the left of the pixel. // X coordinate of the left of the pixel.
const float xf = PixToNdc(xi, W); const float xf = PixToNdc(xidx, W);
// Use a priority queue to hold values: // Use a priority queue to hold values:
// (z, idx, r, bary.x, bary.y. bary.z) // (z, idx, r, bary.x, bary.y. bary.z)
std::priority_queue<std::tuple<float, int, float, float, float, float>> std::priority_queue<std::tuple<float, int, float, float, float, float>>
...@@ -250,7 +257,7 @@ torch::Tensor RasterizeMeshesBackwardCpu( ...@@ -250,7 +257,7 @@ torch::Tensor RasterizeMeshesBackwardCpu(
const torch::Tensor& grad_zbuf, // (N, H, W, K) const torch::Tensor& grad_zbuf, // (N, H, W, K)
const torch::Tensor& grad_bary, // (N, H, W, K, 3) const torch::Tensor& grad_bary, // (N, H, W, K, 3)
const torch::Tensor& grad_dists, // (N, H, W, K) const torch::Tensor& grad_dists, // (N, H, W, K)
bool perspective_correct) { const bool perspective_correct) {
const int F = face_verts.size(0); const int F = face_verts.size(0);
const int N = pix_to_face.size(0); const int N = pix_to_face.size(0);
const int H = pix_to_face.size(1); const int H = pix_to_face.size(1);
...@@ -267,12 +274,19 @@ torch::Tensor RasterizeMeshesBackwardCpu( ...@@ -267,12 +274,19 @@ torch::Tensor RasterizeMeshesBackwardCpu(
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
// Iterate through the horizontal lines of the image from top to bottom. // Iterate through the horizontal lines of the image from top to bottom.
for (int y = 0; y < H; ++y) { for (int y = 0; y < H; ++y) {
// Reverse the order of yi so that +Y is pointing upwards in the image.
const int yidx = H - 1 - y;
// Y coordinate of the top of the pixel. // Y coordinate of the top of the pixel.
const float yf = PixToNdc(y, H); const float yf = PixToNdc(yidx, H);
// Iterate through pixels on this horizontal line, left to right. // Iterate through pixels on this horizontal line, left to right.
for (int x = 0; x < W; ++x) { for (int x = 0; x < W; ++x) {
// Reverse the order of xi so that +X is pointing to the left in the
// image.
const int xidx = W - 1 - x;
// X coordinate of the left of the pixel. // X coordinate of the left of the pixel.
const float xf = PixToNdc(x, W); const float xf = PixToNdc(xidx, W);
const vec2<float> pxy(xf, yf); const vec2<float> pxy(xf, yf);
// Iterate through the faces that hit this pixel. // Iterate through the faces that hit this pixel.
...@@ -376,10 +390,10 @@ torch::Tensor RasterizeMeshesCoarseCpu( ...@@ -376,10 +390,10 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const torch::Tensor& face_verts, const torch::Tensor& face_verts,
const torch::Tensor& mesh_to_face_first_idx, const torch::Tensor& mesh_to_face_first_idx,
const torch::Tensor& num_faces_per_mesh, const torch::Tensor& num_faces_per_mesh,
int image_size, const int image_size,
float blur_radius, const float blur_radius,
int bin_size, const int bin_size,
int max_faces_per_bin) { const int max_faces_per_bin) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 || if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) { face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)"); AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
...@@ -387,6 +401,7 @@ torch::Tensor RasterizeMeshesCoarseCpu( ...@@ -387,6 +401,7 @@ torch::Tensor RasterizeMeshesCoarseCpu(
if (num_faces_per_mesh.ndimension() != 1) { if (num_faces_per_mesh.ndimension() != 1) {
AT_ERROR("num_faces_per_mesh can only have one dimension"); AT_ERROR("num_faces_per_mesh can only have one dimension");
} }
const int N = num_faces_per_mesh.size(0); // batch size. const int N = num_faces_per_mesh.size(0); // batch size.
const int M = max_faces_per_bin; const int M = max_faces_per_bin;
...@@ -415,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu( ...@@ -415,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const int face_stop_idx = const int face_stop_idx =
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>()); (face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
float bin_y_min = -1.0f; float bin_y_max = 1.0f;
float bin_y_max = bin_y_min + bin_width; float bin_y_min = bin_y_max - bin_width;
// Iterate through the horizontal bins from top to bottom. // Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < BH; ++by) { for (int by = 0; by < BH; ++by) {
float bin_x_min = -1.0f; float bin_x_max = 1.0f;
float bin_x_max = bin_x_min + bin_width; float bin_x_min = bin_x_max - bin_width;
// Iterate through bins on this horizontal line, left to right. // Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < BW; ++bx) { for (int bx = 0; bx < BW; ++bx) {
...@@ -458,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu( ...@@ -458,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
} }
} }
// Shift the bin to the right for the next loop iteration. // Shift the bin down for the next loop iteration.
bin_x_min = bin_x_max; bin_x_max = bin_x_min;
bin_x_max = bin_x_min + bin_width; bin_x_min = bin_x_min - bin_width;
} }
// Shift the bin down for the next loop iteration. // Shift the bin left for the next loop iteration.
bin_y_min = bin_y_max; bin_y_max = bin_y_min;
bin_y_max = bin_y_min + bin_width; bin_y_min = bin_y_min - bin_width;
} }
} }
return bin_faces; return bin_faces;
......
...@@ -38,7 +38,7 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor: ...@@ -38,7 +38,7 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
device = fragments.pix_to_face.device device = fragments.pix_to_face.device
pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device) pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
pixel_colors[..., :3] = colors[..., 0, :] pixel_colors[..., :3] = colors[..., 0, :]
return torch.flip(pixel_colors, [1]) return pixel_colors
def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
...@@ -80,7 +80,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor: ...@@ -80,7 +80,7 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
alpha = torch.prod((1.0 - prob), dim=-1) alpha = torch.prod((1.0 - prob), dim=-1)
pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB pixel_colors[..., :3] = colors[..., 0, :] # Hard assign for RGB
pixel_colors[..., 3] = 1.0 - alpha pixel_colors[..., 3] = 1.0 - alpha
return torch.flip(pixel_colors, [1]) return pixel_colors
def softmax_rgb_blend( def softmax_rgb_blend(
...@@ -125,7 +125,7 @@ def softmax_rgb_blend( ...@@ -125,7 +125,7 @@ def softmax_rgb_blend(
N, H, W, K = fragments.pix_to_face.shape N, H, W, K = fragments.pix_to_face.shape
device = fragments.pix_to_face.device device = fragments.pix_to_face.device
pix_colors = torch.ones( pixel_colors = torch.ones(
(N, H, W, 4), dtype=colors.dtype, device=colors.device (N, H, W, 4), dtype=colors.dtype, device=colors.device
) )
background = blend_params.background_color background = blend_params.background_color
...@@ -166,7 +166,7 @@ def softmax_rgb_blend( ...@@ -166,7 +166,7 @@ def softmax_rgb_blend(
# Sum: weights * textures + background color # Sum: weights * textures + background color
weighted_colors = (weights[..., None] * colors).sum(dim=-2) weighted_colors = (weights[..., None] * colors).sum(dim=-2)
weighted_background = (delta / denom) * background weighted_background = (delta / denom) * background
pix_colors[..., :3] = weighted_colors + weighted_background pixel_colors[..., :3] = weighted_colors + weighted_background
pix_colors[..., 3] = 1.0 - alpha pixel_colors[..., 3] = 1.0 - alpha
return torch.flip(pix_colors, [1]) return pixel_colors
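For reference, the removed `torch.flip(..., [1])` call reversed the row (height) axis of the `(N, H, W, 4)` output. The sketch below imitates that on nested Python lists, without PyTorch, to show what the blending functions no longer do. The helper `flip_rows` is a hypothetical stand-in, not part of the library.

```python
def flip_rows(batch):
    """Reverse the H axis of a batch shaped (N, H, W), like torch.flip(x, [1]).

    Hypothetical pure-Python stand-in used only for illustration.
    """
    return [image[::-1] for image in batch]


# One 2x2 image; each inner list is a row of pixel values.
batch = [[[1, 2], [3, 4]]]
old_output = flip_rows(batch)  # behaviour before this change: rows reversed
new_output = batch             # behaviour after this change: rows untouched
```

Since the rasterizer now emits rows with +Y up, the blending functions can return the pixel buffer as is.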
...@@ -944,7 +944,7 @@ def camera_position_from_spherical_angles( ...@@ -944,7 +944,7 @@ def camera_position_from_spherical_angles(
azim = math.pi / 180.0 * azim azim = math.pi / 180.0 * azim
x = dist * torch.cos(elev) * torch.sin(azim) x = dist * torch.cos(elev) * torch.sin(azim)
y = dist * torch.sin(elev) y = dist * torch.sin(elev)
z = -dist * torch.cos(elev) * torch.cos(azim) z = dist * torch.cos(elev) * torch.cos(azim)
camera_position = torch.stack([x, y, z], dim=1) camera_position = torch.stack([x, y, z], dim=1)
if camera_position.dim() == 0: if camera_position.dim() == 0:
camera_position = camera_position.view(1, -1) # add batch dim. camera_position = camera_position.view(1, -1) # add batch dim.
......
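The sign change on `z` above can be checked with a scalar version of the same formula. This is a sketch using only the `math` module rather than tensors; the function name `camera_position` is hypothetical, and it assumes both angles are given in degrees.

```python
import math


def camera_position(dist, elev_deg, azim_deg):
    """Scalar sketch of the spherical-angle formula with the new +Z sign."""
    elev = math.radians(elev_deg)
    azim = math.radians(azim_deg)
    x = dist * math.cos(elev) * math.sin(azim)
    y = dist * math.sin(elev)
    # Before this change the line read: z = -dist * cos(elev) * cos(azim)
    z = dist * math.cos(elev) * math.cos(azim)
    return (x, y, z)


front = camera_position(2.0, 0.0, 0.0)   # zero elevation and azimuth
above = camera_position(2.0, 90.0, 0.0)  # looking straight down from +Y
```

With zero elevation and azimuth the camera now sits on the +Z axis at distance `dist`; under the old sign it sat on -Z.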
...@@ -208,6 +208,11 @@ class _RasterizeFaceVerts(torch.autograd.Function): ...@@ -208,6 +208,11 @@ class _RasterizeFaceVerts(torch.autograd.Function):
return grads return grads
def pix_to_ndc(i, S):
# NDC x-offset + (i * pixel_width + half_pixel_width)
return -1 + (2 * i + 1.0) / S
def rasterize_meshes_python( def rasterize_meshes_python(
meshes, meshes,
image_size: int = 256, image_size: int = 256,
...@@ -249,10 +254,6 @@ def rasterize_meshes_python( ...@@ -249,10 +254,6 @@ def rasterize_meshes_python(
(N, H, W, K), fill_value=-1, dtype=torch.float32, device=device (N, H, W, K), fill_value=-1, dtype=torch.float32, device=device
) )
# NDC is from [-1, 1]. Get pixel size using specified image size.
pixel_width = 2.0 / W
pixel_height = 2.0 / H
# Calculate all face bounding boxes. # Calculate all face bounding boxes.
x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values x_mins = torch.min(faces_verts[:, :, 0], dim=1, keepdim=True).values
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
...@@ -269,14 +270,20 @@ def rasterize_meshes_python( ...@@ -269,14 +270,20 @@ def rasterize_meshes_python(
for n in range(N): for n in range(N):
face_start_idx = mesh_to_face_first_idx[n] face_start_idx = mesh_to_face_first_idx[n]
face_stop_idx = face_start_idx + num_faces_per_mesh[n] face_stop_idx = face_start_idx + num_faces_per_mesh[n]
# Y coordinate of the top of the image.
yf = -1.0 + 0.5 * pixel_height
# Iterate through the horizontal lines of the image from top to bottom. # Iterate through the horizontal lines of the image from top to bottom.
for yi in range(H): for yi in range(H):
# X coordinate of the left of the image. # Y coordinate of one end of the image. Reverse the ordering
xf = -1.0 + 0.5 * pixel_width # of yi so that +Y is pointing up in the image.
yfix = H - 1 - yi
yf = pix_to_ndc(yfix, H)
# Iterate through pixels on this horizontal line, left to right. # Iterate through pixels on this horizontal line, left to right.
for xi in range(W): for xi in range(W):
# X coordinate of one end of the image. Reverse the ordering
# of xi so that +X is pointing to the left in the image.
xfix = W - 1 - xi
xf = pix_to_ndc(xfix, W)
top_k_points = [] top_k_points = []
# Check whether each face in the mesh affects this pixel. # Check whether each face in the mesh affects this pixel.
...@@ -347,12 +354,6 @@ def rasterize_meshes_python( ...@@ -347,12 +354,6 @@ def rasterize_meshes_python(
bary_coords[n, yi, xi, k, 2] = bary[2] bary_coords[n, yi, xi, k, 2] = bary[2]
pix_dists[n, yi, xi, k] = dist pix_dists[n, yi, xi, k] = dist
# Move to the next horizontal pixel
xf += pixel_width
# Move to the next vertical pixel
yf += pixel_height
return face_idxs, zbuf, bary_coords, pix_dists return face_idxs, zbuf, bary_coords, pix_dists
......
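The index reversal in the Python rasterizer can be tried in isolation. The sketch below copies `pix_to_ndc` from the diff and wraps the two reversals in hypothetical helpers, `row_to_ndc_y` and `col_to_ndc_x`. It passes the width for the X axis; for the square images implied by `image_size: int` this equals the height.

```python
def pix_to_ndc(i, S):
    # NDC offset (-1) plus the centre of pixel i: (i + 0.5) * (2 / S)
    return -1 + (2 * i + 1.0) / S


def row_to_ndc_y(yi, H):
    # Reverse the row index so that row 0 has the largest Y (+Y up).
    return pix_to_ndc(H - 1 - yi, H)


def col_to_ndc_x(xi, W):
    # Reverse the column index so that column 0 has the largest X (+X left).
    return pix_to_ndc(W - 1 - xi, W)


H = W = 4
ys = [row_to_ndc_y(yi, H) for yi in range(H)]
xs = [col_to_ndc_x(xi, W) for xi in range(W)]
```

For a 4x4 image the pixel centres run from 0.75 at index 0 down to -0.75 at index 3 on both axes, which matches the +Y up, +X left convention described in the commit summary.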
...@@ -53,7 +53,7 @@ class MeshRenderer(nn.Module): ...@@ -53,7 +53,7 @@ class MeshRenderer(nn.Module):
if raster_settings.blur_radius > 0.0: if raster_settings.blur_radius > 0.0:
# TODO: potentially move barycentric clipping to the rasterizer # TODO: potentially move barycentric clipping to the rasterizer
# if no downstream functions require unclipped values. # if no downstream functions require unclipped values.
# This will avoid unnecessary re-interpolation of the z buffer. # This will avoid unnecessary re-interpolation of the z buffer.
clipped_bary_coords = _clip_barycentric_coordinates( clipped_bary_coords = _clip_barycentric_coordinates(
fragments.bary_coords fragments.bary_coords
) )
...@@ -67,4 +67,5 @@ class MeshRenderer(nn.Module): ...@@ -67,4 +67,5 @@ class MeshRenderer(nn.Module):
pix_to_face=fragments.pix_to_face, pix_to_face=fragments.pix_to_face,
) )
images = self.shader(fragments, meshes_world, **kwargs) images = self.shader(fragments, meshes_world, **kwargs)
return images return images
tests/data/test_silhouette.png (binary image updated: 8.9 KB before, 8.84 KB after)
