Commit b19fe1de authored by Christoph Lassner's avatar Christoph Lassner Committed by Facebook GitHub Bot
Browse files

pulsar integration.

Summary:
This diff integrates the pulsar renderer source code into PyTorch3D as an alternative backend for the PyTorch3D point renderer. This diff is the first of a series of three diffs to complete that migration and focuses on the packaging and integration of the source code.

For more information about the pulsar backend, see the release notes and the paper (https://arxiv.org/abs/2004.07484). For information on how to use the backend, see the point cloud rendering notebook and the examples in the folder `docs/examples`.

Tasks addressed in the following diffs:
* Add the PyTorch3D interface,
* Add notebook examples and documentation (or adapt the existing ones to feature both interfaces).

Reviewed By: nikhilaravi

Differential Revision: D23947736

fbshipit-source-id: a5e77b53e6750334db22aefa89b4c079cda1b443
parent d5650323
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include "./renderer.forward.device.h"
namespace pulsar {
namespace Renderer {
template void forward<ISONDEVICE>(
Renderer* self,
const float* vert_pos,
const float* vert_col,
const float* vert_rad,
const CamInfo& cam,
const float& gamma,
float percent_allowed_difference,
const uint& max_n_hits,
const float* bg_col_d,
const float* opacity_d,
const size_t& num_balls,
const uint& mode,
cudaStream_t stream);
} // namespace Renderer
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_GET_SCREEN_AREA_DEVICE_H_
#define PULSAR_NATIVE_INCLUDE_RENDERER_GET_SCREEN_AREA_DEVICE_H_
#include "../global.h"
#include "./camera.device.h"
#include "./commands.h"
#include "./math.h"
namespace pulsar {
namespace Renderer {
/**
* Find the closest enclosing screen area rectangle in pixels that encloses a
* ball.
*
* The method returns the two x and the two y values of the boundaries. They
* are not ordered yet and you need to find min and max for the left/right and
* lower/upper boundary.
*
* The return values are floats and need to be rounded appropriately.
*/
INLINE DEVICE bool get_screen_area(
const float3& ball_center_cam,
const float3& ray_center_norm,
const float& vert_rad,
const CamInfo& cam,
const uint& idx,
/* Out variables. */
float* x_1,
float* x_2,
float* y_1,
float* y_2) {
float cos_alpha = dot(cam.sensor_dir_z, ray_center_norm);
float2 o__c_, alpha, theta;
if (cos_alpha < EPS) {
PULSAR_LOG_DEV(
PULSAR_LOG_CALC_SIGNATURE,
"signature %d|ball not visible. cos_alpha: %.9f.\n",
idx,
cos_alpha);
// No intersection, ball won't be visible.
return false;
}
// Multiply the direction vector with the camera rotation matrix
// to have the optical axis being the canonical z vector (0, 0, 1).
// TODO: optimize.
const float3 ball_center_cam_rot = rotate(
ball_center_cam,
cam.pixel_dir_x / length(cam.pixel_dir_x),
cam.pixel_dir_y / length(cam.pixel_dir_y),
cam.sensor_dir_z);
PULSAR_LOG_DEV(
PULSAR_LOG_CALC_SIGNATURE,
"signature %d|ball_center_cam_rot: %f, %f, %f.\n",
idx,
ball_center_cam.x,
ball_center_cam.y,
ball_center_cam.z);
const float pixel_size_norm_fac = FRCP(2.f * cam.half_pixel_size);
const float optical_offset_x =
(static_cast<float>(cam.aperture_width) - 1.f) * .5f;
const float optical_offset_y =
(static_cast<float>(cam.aperture_height) - 1.f) * .5f;
if (cam.orthogonal_projection) {
*x_1 =
FMA(ball_center_cam_rot.x - vert_rad,
pixel_size_norm_fac,
optical_offset_x);
*x_2 =
FMA(ball_center_cam_rot.x + vert_rad,
pixel_size_norm_fac,
optical_offset_x);
*y_1 =
FMA(ball_center_cam_rot.y - vert_rad,
pixel_size_norm_fac,
optical_offset_y);
*y_2 =
FMA(ball_center_cam_rot.y + vert_rad,
pixel_size_norm_fac,
optical_offset_y);
return true;
} else {
o__c_.x = FMAX(
FSQRT(
ball_center_cam_rot.x * ball_center_cam_rot.x +
ball_center_cam_rot.z * ball_center_cam_rot.z),
FEPS);
o__c_.y = FMAX(
FSQRT(
ball_center_cam_rot.y * ball_center_cam_rot.y +
ball_center_cam_rot.z * ball_center_cam_rot.z),
FEPS);
PULSAR_LOG_DEV(
PULSAR_LOG_CALC_SIGNATURE,
"signature %d|o__c_: %f, %f.\n",
idx,
o__c_.x,
o__c_.y);
alpha.x = sign_dir(ball_center_cam_rot.x) *
acos(FMIN(FMAX(ball_center_cam_rot.z / o__c_.x, -1.f), 1.f));
alpha.y = -sign_dir(ball_center_cam_rot.y) *
acos(FMIN(FMAX(ball_center_cam_rot.z / o__c_.y, -1.f), 1.f));
theta.x = asin(FMIN(FMAX(vert_rad / o__c_.x, -1.f), 1.f));
theta.y = asin(FMIN(FMAX(vert_rad / o__c_.y, -1.f), 1.f));
PULSAR_LOG_DEV(
PULSAR_LOG_CALC_SIGNATURE,
"signature %d|alpha.x: %f, alpha.y: %f, theta.x: %f, theta.y: %f.\n",
idx,
alpha.x,
alpha.y,
theta.x,
theta.y);
*x_1 = tan(alpha.x - theta.x) * cam.focal_length;
*x_2 = tan(alpha.x + theta.x) * cam.focal_length;
*y_1 = tan(alpha.y - theta.y) * cam.focal_length;
*y_2 = tan(alpha.y + theta.y) * cam.focal_length;
PULSAR_LOG_DEV(
PULSAR_LOG_CALC_SIGNATURE,
"signature %d|in sensor plane: x_1: %f, x_2: %f, y_1: %f, y_2: %f.\n",
idx,
*x_1,
*x_2,
*y_1,
*y_2);
*x_1 = FMA(*x_1, pixel_size_norm_fac, optical_offset_x);
*x_2 = FMA(*x_2, pixel_size_norm_fac, optical_offset_x);
*y_1 = FMA(*y_1, -pixel_size_norm_fac, optical_offset_y);
*y_2 = FMA(*y_2, -pixel_size_norm_fac, optical_offset_y);
return true;
}
};
} // namespace Renderer
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_H_
#define PULSAR_NATIVE_INCLUDE_RENDERER_H_
#include <algorithm>
#include "../global.h"
#include "./camera.h"
namespace pulsar {
namespace Renderer {
//! Remember to order struct members from larger size to smaller size
//! to avoid padding (for more info, see for example here:
//! http://www.catb.org/esr/structure-packing/).
/**
* This is the information that's needed to do a fast screen point
* intersection with one of the balls.
*
* Aim to keep this below 8 bytes (256 bytes per cache-line / 32 threads in a
* warp = 8 bytes per thread).
*/
struct IntersectInfo {
ushort2 min; /** minimum x, y in pixel coordinates. */
ushort2 max; /** maximum x, y in pixel coordinates. */
};
static_assert(
sizeof(IntersectInfo) == 8,
"The compiled size of `IntersectInfo` is wrong.");
/**
* Reduction operation to find the limits of multiple IntersectInfo objects.
*/
struct IntersectInfoMinMax {
IHD IntersectInfo
operator()(const IntersectInfo& a, const IntersectInfo& b) const {
// Treat the special case of an invalid intersect info object or one for
// a ball out of bounds.
if (b.max.x == MAX_USHORT && b.min.x == MAX_USHORT &&
b.max.y == MAX_USHORT && b.min.y == MAX_USHORT) {
return a;
}
if (a.max.x == MAX_USHORT && a.min.x == MAX_USHORT &&
a.max.y == MAX_USHORT && a.min.y == MAX_USHORT) {
return b;
}
IntersectInfo result;
result.min.x = std::min<ushort>(a.min.x, b.min.x);
result.min.y = std::min<ushort>(a.min.y, b.min.y);
result.max.x = std::max<ushort>(a.max.x, b.max.x);
result.max.y = std::max<ushort>(a.max.y, b.max.y);
return result;
}
};
/**
* All information that's needed to draw a ball.
*
* It's necessary to keep this information in float (not half) format,
* because the loss in accuracy would be too high and lead to artifacts.
*/
struct DrawInfo {
float3 ray_center_norm; /** Ray to the ball center, normalized. */
/** Ball color.
*
* This might be the full color in the case of n_channels <= 3. Otherwise,
* a pointer to the original 'color' data is stored in the following union.
*/
float first_color;
union {
float color[2];
float* ptr;
} color_union;
float t_center; /** Distance from the camera to the ball center. */
float radius; /** Ball radius. */
};
static_assert(
sizeof(DrawInfo) == 8 * 4,
"The compiled size of `DrawInfo` is wrong.");
/**
* An object to collect all associated data with the renderer.
*
* The `_d` suffixed pointers point to memory 'on-device', potentially on the
* GPU. All other variables are expected to point to CPU memory.
*/
struct Renderer {
/** Dummy initializer to make sure all pointers are set to NULL to
* be safe for the device-specific 'construct' and 'destruct' methods.
*/
inline Renderer() {
max_num_balls = 0;
result_d = NULL;
min_depth_d = NULL;
min_depth_sorted_d = NULL;
ii_d = NULL;
ii_sorted_d = NULL;
ids_d = NULL;
ids_sorted_d = NULL;
workspace_d = NULL;
di_d = NULL;
di_sorted_d = NULL;
region_flags_d = NULL;
num_selected_d = NULL;
forw_info_d = NULL;
grad_pos_d = NULL;
grad_col_d = NULL;
grad_rad_d = NULL;
grad_cam_d = NULL;
grad_opy_d = NULL;
grad_cam_buf_d = NULL;
n_grad_contributions_d = NULL;
};
/** The camera for this renderer. In world-coordinates. */
CamInfo cam;
/**
* The maximum amount of balls the renderer can handle. Resources are
* pre-allocated to account for this size. Less than this amount of balls
* can be rendered, but not more.
*/
int max_num_balls;
/** The result buffer. */
float* result_d;
/** Closest possible intersection depth per sphere w.r.t. the camera. */
float* min_depth_d;
/** Closest possible intersection depth per sphere, ordered ascending. */
float* min_depth_sorted_d;
/** The intersect infos per sphere. */
IntersectInfo* ii_d;
/** The intersect infos per sphere, ordered by their closest possible
* intersection depth (asc.). */
IntersectInfo* ii_sorted_d;
/** Original sphere IDs. */
int* ids_d;
/** Original sphere IDs, ordered by their closest possible intersection depth
* (asc.). */
int* ids_sorted_d;
/** Workspace for CUB routines. */
char* workspace_d;
/** Workspace size for CUB routines. */
size_t workspace_size;
/** The draw information structures for each sphere. */
DrawInfo* di_d;
/** The draw information structures sorted by closest possible intersection
* depth (asc.). */
DrawInfo* di_sorted_d;
/** Region association buffer. */
char* region_flags_d;
/** Num spheres in the current region. */
size_t* num_selected_d;
/** Pointer to information from the forward pass. */
float* forw_info_d;
/** Struct containing information about the min max pixels that contain
* rendered information in the image. */
IntersectInfo* min_max_pixels_d;
/** Gradients w.r.t. position. */
float3* grad_pos_d;
/** Gradients w.r.t. color. */
float* grad_col_d;
/** Gradients w.r.t. radius. */
float* grad_rad_d;
/** Gradients w.r.t. camera parameters. */
float* grad_cam_d;
/** Gradients w.r.t. opacity. */
float* grad_opy_d;
/** Camera gradient information by sphere.
*
* Here, every sphere's contribution to the camera gradients is stored. It is
* aggregated and written to grad_cam_d in a separate step. This avoids write
* conflicts when processing the spheres.
*/
CamGradInfo* grad_cam_buf_d;
/** Total of all gradient contributions for this image. */
int* n_grad_contributions_d;
/** The number of spheres to track for backpropagation. */
int n_track;
};
inline bool operator==(const Renderer& a, const Renderer& b) {
return a.cam == b.cam && a.max_num_balls == b.max_num_balls;
}
/**
* Construct a renderer.
*/
template <bool DEV>
void construct(
Renderer* self,
const size_t& max_num_balls,
const int& width,
const int& height,
const bool& orthogonal_projection,
const bool& right_handed_system,
const float& background_normalization_depth,
const uint& n_channels,
const uint& n_track);
/**
* Destruct the renderer and free the associated memory.
*/
template <bool DEV>
void destruct(Renderer* self);
/**
* Create a selection of points inside a rectangle.
*
* This write boolen values into `region_flags_d', which can
* for example be used by a CUB function to extract the selection.
*/
template <bool DEV>
GLOBAL void create_selector(
IntersectInfo const* const RESTRICT ii_sorted_d,
const uint num_balls,
const int min_x,
const int max_x,
const int min_y,
const int max_y,
/* Out variables. */
char* RESTRICT region_flags_d);
/**
* Calculate a signature for a ball.
*
* Populate the `ids_d`, `ii_d`, `di_d` and `min_depth_d` fields of the
* renderer. For spheres not visible in the image, sets the id field to -1,
* min_depth_d to MAX_FLOAT and the ii_d.min.x fields to MAX_USHORT.
*/
template <bool DEV>
GLOBAL void calc_signature(
Renderer renderer,
float3 const* const RESTRICT vert_poss,
float const* const RESTRICT vert_cols,
float const* const RESTRICT vert_rads,
const uint num_balls);
/**
* The block size for rendering.
*
* This should be as large as possible, but is limited due to the amount
* of variables we use and the memory required per thread.
*/
#define RENDER_BLOCK_SIZE 16
/**
* The buffer size of spheres to be loaded and analyzed for relevance.
*
* This must be at least RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE so that
* for every iteration through the loading loop every thread could add a
* 'hit' to the buffer.
*/
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2
/**
* The threshold after which the spheres that are in the render buffer
* are rendered and the buffer is flushed.
*
* Must be less than RENDER_BUFFER_SIZE.
*/
#define RENDER_BUFFER_LOAD_THRESH 16 * 4
/**
* The render function.
*
* Assumptions:
* * the focal length is appropriately chosen,
* * ray_dir_norm.z is > EPS.
* * to be completed...
*/
template <bool DEV>
GLOBAL void render(
size_t const* const RESTRICT
num_balls, /** Number of balls relevant for this pass. */
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
DrawInfo const* const RESTRICT di_d, /** Draw information. */
float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */
int const* const RESTRICT id_d, /** IDs. */
float const* const RESTRICT op_d, /** Opacity. */
const CamInfo cam_norm, /** Camera normalized with all vectors to be in the
* camera coordinate system.
*/
const float gamma, /** Transparency parameter. **/
const float percent_allowed_difference, /** Maximum allowed
error in color. */
const uint max_n_hits,
const float* bg_col_d,
const uint mode,
const int x_min,
const int y_min,
const int x_step,
const int y_step,
// Out variables.
float* const RESTRICT result_d, /** The result image. */
float* const RESTRICT forw_info_d, /** Additional information needed for the
grad computation. */
// Infrastructure.
const int n_track /** The number of spheres to track. */
);
/**
* Makes sure to paint background information.
*
* This is required as a separate post-processing step because certain
* pixels may not be processed during the forward pass if there is no
* possibility for a sphere to be present at their location.
*/
template <bool DEV>
GLOBAL void fill_bg(
Renderer renderer,
const CamInfo norm,
float const* const bg_col_d,
const float gamma,
const uint mode);
/**
* Rendering forward pass.
*
* Takes a renderer and sphere data as inputs and creates a rendering.
*/
template <bool DEV>
void forward(
Renderer* self,
const float* vert_pos,
const float* vert_col,
const float* vert_rad,
const CamInfo& cam,
const float& gamma,
float percent_allowed_difference,
const uint& max_n_hits,
const float* bg_col_d,
const float* opacity_d,
const size_t& num_balls,
const uint& mode,
cudaStream_t stream);
/**
* Normalize the camera gradients by the number of spheres that contributed.
*/
template <bool DEV>
GLOBAL void norm_cam_gradients(Renderer renderer);
/**
* Normalize the sphere gradients.
*
* We're assuming that the samples originate from a Monte Carlo
* sampling process and normalize by number and sphere area.
*/
template <bool DEV>
GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls);
#define GRAD_BLOCK_SIZE 16
/** Calculate the gradients.
*/
template <bool DEV>
GLOBAL void calc_gradients(
const CamInfo cam, /** Camera in world coordinates. */
float const* const RESTRICT grad_im, /** The gradient image. */
const float
gamma, /** The transparency parameter used in the forward pass. */
float3 const* const RESTRICT vert_poss, /** Vertex position vector. */
float const* const RESTRICT vert_cols, /** Vertex color vector. */
float const* const RESTRICT vert_rads, /** Vertex radius vector. */
float const* const RESTRICT opacity, /** Vertex opacity. */
const uint num_balls, /** Number of balls. */
float const* const RESTRICT result_d, /** Result image. */
float const* const RESTRICT forw_info_d, /** Forward pass info. */
DrawInfo const* const RESTRICT di_d, /** Draw information. */
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
// Mode switches.
const bool calc_grad_pos,
const bool calc_grad_col,
const bool calc_grad_rad,
const bool calc_grad_cam,
const bool calc_grad_opy,
// Out variables.
float* const RESTRICT grad_rad_d, /** Radius gradients. */
float* const RESTRICT grad_col_d, /** Color gradients. */
float3* const RESTRICT grad_pos_d, /** Position gradients. */
CamGradInfo* const RESTRICT grad_cam_buf_d, /** Camera gradient buffer. */
float* const RESTRICT grad_opy_d, /** Opacity gradient buffer. */
int* const RESTRICT
grad_contributed_d, /** Gradient contribution counter. */
// Infrastructure.
const int n_track,
const uint offs_x = 0,
const uint offs_y = 0);
/**
* A full backward pass.
*
* Creates the gradients for the given gradient_image and the spheres.
*/
template <bool DEV>
void backward(
Renderer* self,
const float* grad_im,
const float* image,
const float* forw_info,
const float* vert_pos,
const float* vert_col,
const float* vert_rad,
const CamInfo& cam,
const float& gamma,
float percent_allowed_difference,
const uint& max_n_hits,
const float* vert_opy,
const size_t& num_balls,
const uint& mode,
const bool& dif_pos,
const bool& dif_col,
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
cudaStream_t stream);
/**
* A debug backward pass.
*
* This is a function to debug the gradient calculation. It calculates the
* gradients for exactly one pixel (set with pos_x and pos_y) without averaging.
*
* *Uses only the first sphere for camera gradient calculation!*
*/
template <bool DEV>
void backward_dbg(
Renderer* self,
const float* grad_im,
const float* image,
const float* forw_info,
const float* vert_pos,
const float* vert_col,
const float* vert_rad,
const CamInfo& cam,
const float& gamma,
float percent_allowed_difference,
const uint& max_n_hits,
const float* vert_opy,
const size_t& num_balls,
const uint& mode,
const bool& dif_pos,
const bool& dif_col,
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const uint& pos_x,
const uint& pos_y,
cudaStream_t stream);
template <bool DEV>
void nn(
const float* ref_ptr,
const float* tar_ptr,
const uint& k,
const uint& d,
const uint& n,
float* dist_ptr,
int32_t* inds_ptr,
cudaStream_t stream);
} // namespace Renderer
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_NORM_CAM_GRADIENTS_DEVICE_H_
#define PULSAR_NATIVE_INCLUDE_RENDERER_NORM_CAM_GRADIENTS_DEVICE_H_
#include "../global.h"
#include "./camera.device.h"
#include "./commands.h"
#include "./math.h"
#include "./renderer.h"
namespace pulsar {
namespace Renderer {
/**
* Normalize the camera gradients by the number of spheres that contributed.
*/
template <bool DEV>
GLOBAL void norm_cam_gradients(Renderer renderer) {
GET_PARALLEL_IDX_1D(idx, 1);
CamGradInfo* cgi = reinterpret_cast<CamGradInfo*>(renderer.grad_cam_d);
*cgi = *cgi * FRCP(static_cast<float>(*renderer.n_grad_contributions_d));
END_PARALLEL_NORET();
};
} // namespace Renderer
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include "./renderer.norm_cam_gradients.device.h"
namespace pulsar {
namespace Renderer {
template GLOBAL void norm_cam_gradients<ISONDEVICE>(Renderer renderer);
} // namespace Renderer
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_NORM_SPHERE_GRADIENTS_H_
#define PULSAR_NATIVE_INCLUDE_RENDERER_NORM_SPHERE_GRADIENTS_H_
#include "../global.h"
#include "./commands.h"
#include "./math.h"
#include "./renderer.h"
namespace pulsar {
namespace Renderer {
/**
* Normalize the sphere gradients.
*
* We're assuming that the samples originate from a Monte Carlo
* sampling process and normalize by number and sphere area.
*/
template <bool DEV>
GLOBAL void norm_sphere_gradients(Renderer renderer, const int num_balls) {
GET_PARALLEL_IDX_1D(idx, num_balls);
float norm_fac = 0.f;
IntersectInfo ii;
if (renderer.ids_sorted_d[idx] > 0) {
ii = renderer.ii_d[idx];
// Normalize the sphere gradients as averages.
// This avoids the case that there are small spheres in a scene with still
// un-converged colors whereas the big spheres already converged, just
// because their integrated learning rate is 'higher'.
norm_fac = FRCP(static_cast<float>(renderer.ids_sorted_d[idx]));
}
PULSAR_LOG_DEV_NODE(
PULSAR_LOG_NORMALIZE,
"ids_sorted_d[idx]: %d, norm_fac: %.9f.\n",
renderer.ids_sorted_d[idx],
norm_fac);
renderer.grad_rad_d[idx] *= norm_fac;
for (uint c_idx = 0; c_idx < renderer.cam.n_channels; ++c_idx) {
renderer.grad_col_d[idx * renderer.cam.n_channels + c_idx] *= norm_fac;
}
renderer.grad_pos_d[idx] *= norm_fac;
renderer.grad_opy_d[idx] *= norm_fac;
if (renderer.ids_sorted_d[idx] > 0) {
// For the camera, we need to be more correct and have the gradients
// be proportional to the area they cover in the image.
// This leads to a formulation very much like in monte carlo integration:
norm_fac = FRCP(static_cast<float>(renderer.ids_sorted_d[idx])) *
(static_cast<float>(ii.max.x) - static_cast<float>(ii.min.x)) *
(static_cast<float>(ii.max.y) - static_cast<float>(ii.min.y)) *
1e-3f; // for better numerics.
}
renderer.grad_cam_buf_d[idx].cam_pos *= norm_fac;
renderer.grad_cam_buf_d[idx].pixel_0_0_center *= norm_fac;
renderer.grad_cam_buf_d[idx].pixel_dir_x *= norm_fac;
renderer.grad_cam_buf_d[idx].pixel_dir_y *= norm_fac;
// The sphere only contributes to the camera gradients if it is
// large enough in screen space.
if (renderer.ids_sorted_d[idx] > 0 && ii.max.x >= ii.min.x + 3 &&
ii.max.y >= ii.min.y + 3)
renderer.ids_sorted_d[idx] = 1;
END_PARALLEL_NORET();
};
} // namespace Renderer
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include "./renderer.norm_sphere_gradients.device.h"
namespace pulsar {
namespace Renderer {
template GLOBAL void norm_sphere_gradients<ISONDEVICE>(
Renderer renderer,
const int num_balls);
} // namespace Renderer
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_DEVICE_H_
#define PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_DEVICE_H_
#include "../global.h"
#include "./camera.device.h"
#include "./commands.h"
#include "./math.h"
#include "./renderer.h"
#include "./closest_sphere_tracker.device.h"
#include "./renderer.draw.device.h"
namespace pulsar {
namespace Renderer {
template <bool DEV>
GLOBAL void render(
size_t const* const RESTRICT
num_balls, /** Number of balls relevant for this pass. */
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
DrawInfo const* const RESTRICT di_d, /** Draw information. */
float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */
int const* const RESTRICT ids_d, /** IDs. */
float const* const RESTRICT op_d, /** Opacity. */
const CamInfo cam_norm, /** Camera normalized with all vectors to be in the
* camera coordinate system.
*/
const float gamma, /** Transparency parameter. **/
const float percent_allowed_difference, /** Maximum allowed
error in color. */
const uint max_n_hits,
const float* bg_col,
const uint mode,
const int x_min,
const int y_min,
const int x_step,
const int y_step,
// Out variables.
float* const RESTRICT result_d, /** The result image. */
float* const RESTRICT forw_info_d, /** Additional information needed for the
grad computation. */
const int n_track /** The number of spheres to track for backprop. */
) {
// Do not early stop threads in this block here. They can all contribute to
// the scanning process, we just have to prevent from writing their result.
GET_PARALLEL_IDS_2D(offs_x, offs_y, x_step, y_step);
// Variable declarations and const initializations.
const float ln_pad_over_1minuspad =
FLN(percent_allowed_difference / (1.f - percent_allowed_difference));
/** A facility to track the closest spheres to the camera
(in preparation for gradient calculation). */
ClosestSphereTracker tracker(n_track);
const uint coord_x = x_min + offs_x; /** Ray coordinate x. */
const uint coord_y = y_min + offs_y; /** Ray coordinate y. */
float3 ray_dir_norm; /** Ray cast through the pixel, normalized. */
float2 projected_ray; /** Ray intersection with the sensor. */
if (cam_norm.orthogonal_projection) {
ray_dir_norm = cam_norm.sensor_dir_z;
projected_ray.x = static_cast<float>(coord_x);
projected_ray.y = static_cast<float>(coord_y);
} else {
ray_dir_norm = normalize(
cam_norm.pixel_0_0_center + coord_x * cam_norm.pixel_dir_x +
coord_y * cam_norm.pixel_dir_y);
// This is a reasonable assumption for normal focal lengths and image sizes.
PASSERT(FABS(ray_dir_norm.z) > FEPS);
projected_ray.x = ray_dir_norm.x / ray_dir_norm.z * cam_norm.focal_length;
projected_ray.y = ray_dir_norm.y / ray_dir_norm.z * cam_norm.focal_length;
}
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|ray_dir_norm: %.9f, %.9f, %.9f. projected_ray: %.9f, %.9f.\n",
ray_dir_norm.x,
ray_dir_norm.y,
ray_dir_norm.z,
projected_ray.x,
projected_ray.y);
// Set up shared infrastructure.
/** This entire thread block. */
cg::thread_block thread_block = cg::this_thread_block();
/** The collaborators within a warp. */
cg::coalesced_group thread_warp = cg::coalesced_threads();
/** The number of loaded balls in the load buffer di_l. */
SHARED uint n_loaded;
/** Draw information buffer. */
SHARED DrawInfo di_l[RENDER_BUFFER_SIZE];
/** The original sphere id of each loaded sphere. */
SHARED uint sphere_id_l[RENDER_BUFFER_SIZE];
/** The number of pixels in this block that are done. */
SHARED int n_pixels_done;
/** Whether loading of balls is completed. */
SHARED bool loading_done;
/** The number of balls loaded overall (just for statistics). */
SHARED int n_balls_loaded;
/** The area this thread block covers. */
SHARED IntersectInfo block_area;
if (thread_block.thread_rank() == 0) {
// Initialize the shared variables.
n_loaded = 0;
block_area.min.x = static_cast<ushort>(coord_x);
block_area.max.x = static_cast<ushort>(IMIN(
coord_x + blockDim.x, cam_norm.film_border_left + cam_norm.film_width));
block_area.min.y = static_cast<ushort>(coord_y);
block_area.max.y = static_cast<ushort>(IMIN(
coord_y + blockDim.y, cam_norm.film_border_top + cam_norm.film_height));
n_pixels_done = 0;
loading_done = false;
n_balls_loaded = 0;
}
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|block_area.min: %d, %d. block_area.max: %d, %d.\n",
block_area.min.x,
block_area.min.y,
block_area.max.x,
block_area.max.y);
// Initialization of the pixel with the background color.
/**
* The result of this very pixel.
* the offset calculation might overflow if this thread is out of
* bounds of the film. However, in this case result is not
* accessed, so this is fine.
*/
float* result = result_d +
(coord_y - cam_norm.film_border_top) * cam_norm.film_width *
cam_norm.n_channels +
(coord_x - cam_norm.film_border_left) * cam_norm.n_channels;
if (coord_x >= cam_norm.film_border_left &&
coord_x < cam_norm.film_border_left + cam_norm.film_width &&
coord_y >= cam_norm.film_border_top &&
coord_y < cam_norm.film_border_top + cam_norm.film_height) {
// Initialize the result.
if (mode == 0u) {
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id)
result[c_id] = bg_col[c_id];
} else {
result[0] = 0.f;
}
}
/** Normalization denominator. */
float sm_d = 1.f;
/** Normalization tracker for stable softmax. The maximum observed value. */
float sm_m = cam_norm.background_normalization_depth / gamma;
/** Whether this pixel has had all information needed for drawing. */
bool done =
(coord_x < cam_norm.film_border_left ||
coord_x >= cam_norm.film_border_left + cam_norm.film_width ||
coord_y < cam_norm.film_border_top ||
coord_y >= cam_norm.film_border_top + cam_norm.film_height);
/** The depth threshold for a new point to have at least
* `percent_allowed_difference` influence on the result color. All points that
* are further away than this are ignored.
*/
float depth_threshold = done ? -1.f : MAX_FLOAT;
/** The closest intersection possible of a ball that was hit by this pixel
* ray. */
float max_closest_possible_intersection_hit = -1.f;
bool hit; /** Whether a sphere was hit. */
float intersection_depth; /** The intersection_depth for a sphere at this
pixel. */
float closest_possible_intersection; /** The closest possible intersection
for this sphere. */
float max_closest_possible_intersection;
// Sync up threads so that everyone is similarly initialized.
thread_block.sync();
//! Coalesced loading and intersection analysis of balls.
for (uint ball_idx = thread_block.thread_rank();
ball_idx < iDivCeil(static_cast<uint>(*num_balls), thread_block.size()) *
thread_block.size() &&
!loading_done && n_pixels_done < thread_block.size();
ball_idx += thread_block.size()) {
if (ball_idx < static_cast<uint>(*num_balls)) { // Account for overflow.
const IntersectInfo& ii = ii_d[ball_idx];
hit = (ii.min.x <= block_area.max.x) && (ii.max.x > block_area.min.x) &&
(ii.min.y <= block_area.max.y) && (ii.max.y > block_area.min.y);
if (hit) {
uint write_idx = ATOMICADD_B(&n_loaded, 1u);
di_l[write_idx] = di_d[ball_idx];
sphere_id_l[write_idx] = static_cast<uint>(ids_d[ball_idx]);
PULSAR_LOG_DEV_PIXB(
PULSAR_LOG_RENDER_PIX,
"render|found intersection with sphere %u.\n",
sphere_id_l[write_idx]);
}
if (ii.min.x == MAX_USHORT)
// This is an invalid sphere (out of image). These spheres have
// maximum depth. Since we ordered the spheres by earliest possible
// intersection depth we re certain that there will no other sphere
// that is relevant after this one.
loading_done = true;
}
// Reset n_pixels_done.
n_pixels_done = 0;
thread_block.sync(); // Make sure n_loaded is updated.
if (n_loaded > RENDER_BUFFER_LOAD_THRESH) {
// The load buffer is full enough. Draw.
if (thread_block.thread_rank() == 0)
n_balls_loaded += n_loaded;
max_closest_possible_intersection = 0.f;
// This excludes threads outside of the image boundary. Also, it reduces
// block artifacts.
if (!done) {
for (uint draw_idx = 0; draw_idx < n_loaded; ++draw_idx) {
intersection_depth = 0.f;
if (cam_norm.orthogonal_projection) {
// The closest possible intersection is the distance to the camera
// plane.
closest_possible_intersection = min_depth_d[sphere_id_l[draw_idx]];
} else {
closest_possible_intersection =
di_l[draw_idx].t_center - di_l[draw_idx].radius;
}
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|drawing sphere %u (depth: %f, "
"closest possible intersection: %f).\n",
sphere_id_l[draw_idx],
di_l[draw_idx].t_center,
closest_possible_intersection);
hit = draw(
di_l[draw_idx], // Sphere to draw.
op_d == NULL ? 1.f : op_d[sphere_id_l[draw_idx]], // Opacity.
cam_norm, // Cam.
gamma, // Gamma.
ray_dir_norm, // Ray direction.
projected_ray, // Ray intersection with the image.
// Mode switches.
true, // Draw.
false,
false,
false,
false,
false, // No gradients.
// Position info.
coord_x,
coord_y,
sphere_id_l[draw_idx],
// Optional in variables.
NULL, // intersect information.
NULL, // ray_dir.
NULL, // norm_ray_dir.
NULL, // grad_pix.
&ln_pad_over_1minuspad,
// in/out variables
&sm_d,
&sm_m,
result,
// Optional out.
&depth_threshold,
&intersection_depth,
NULL,
NULL,
NULL,
NULL,
NULL // gradients.
);
if (hit) {
max_closest_possible_intersection_hit = FMAX(
max_closest_possible_intersection_hit,
closest_possible_intersection);
tracker.track(
sphere_id_l[draw_idx], intersection_depth, coord_x, coord_y);
}
max_closest_possible_intersection = FMAX(
max_closest_possible_intersection, closest_possible_intersection);
}
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|max_closest_possible_intersection: %f, "
"depth_threshold: %f.\n",
max_closest_possible_intersection,
depth_threshold);
}
done = done ||
(percent_allowed_difference > 0.f &&
max_closest_possible_intersection > depth_threshold) ||
tracker.get_n_hits() >= max_n_hits;
uint warp_done = thread_warp.ballot(done);
if (thread_warp.thread_rank() == 0)
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
// This sync is necessary to keep n_loaded until all threads are done with
// painting.
thread_block.sync();
n_loaded = 0;
}
thread_block.sync();
}
if (thread_block.thread_rank() == 0)
n_balls_loaded += n_loaded;
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|loaded %d balls in total.\n",
n_balls_loaded);
if (!done) {
for (uint draw_idx = 0; draw_idx < n_loaded; ++draw_idx) {
intersection_depth = 0.f;
if (cam_norm.orthogonal_projection) {
// The closest possible intersection is the distance to the camera
// plane.
closest_possible_intersection = min_depth_d[sphere_id_l[draw_idx]];
} else {
closest_possible_intersection =
di_l[draw_idx].t_center - di_l[draw_idx].radius;
}
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|drawing sphere %u (depth: %f, "
"closest possible intersection: %f).\n",
sphere_id_l[draw_idx],
di_l[draw_idx].t_center,
closest_possible_intersection);
hit = draw(
di_l[draw_idx], // Sphere to draw.
op_d == NULL ? 1.f : op_d[sphere_id_l[draw_idx]], // Opacity.
cam_norm, // Cam.
gamma, // Gamma.
ray_dir_norm, // Ray direction.
projected_ray, // Ray intersection with the image.
// Mode switches.
true, // Draw.
false,
false,
false,
false,
false, // No gradients.
// Logging info.
coord_x,
coord_y,
sphere_id_l[draw_idx],
// Optional in variables.
NULL, // intersect information.
NULL, // ray_dir.
NULL, // norm_ray_dir.
NULL, // grad_pix.
&ln_pad_over_1minuspad,
// in/out variables
&sm_d,
&sm_m,
result,
// Optional out.
&depth_threshold,
&intersection_depth,
NULL,
NULL,
NULL,
NULL,
NULL // gradients.
);
if (hit) {
max_closest_possible_intersection_hit = FMAX(
max_closest_possible_intersection_hit,
closest_possible_intersection);
tracker.track(
sphere_id_l[draw_idx], intersection_depth, coord_x, coord_y);
}
}
}
if (coord_x < cam_norm.film_border_left ||
coord_y < cam_norm.film_border_top ||
coord_x >= cam_norm.film_border_left + cam_norm.film_width ||
coord_y >= cam_norm.film_border_top + cam_norm.film_height) {
RETURN_PARALLEL();
}
if (mode == 1u) {
// The subtractions, for example coord_y - cam_norm.film_border_left, are
// safe even though both components are uints. We checked their relation
// just above.
result_d
[(coord_y - cam_norm.film_border_top) * cam_norm.film_width *
cam_norm.n_channels +
(coord_x - cam_norm.film_border_left) * cam_norm.n_channels] =
static_cast<float>(tracker.get_n_hits());
} else {
float sm_d_normfac = FRCP(FMAX(sm_d, FEPS));
for (uint c_id = 0; c_id < cam_norm.n_channels; ++c_id)
result[c_id] *= sm_d_normfac;
int write_loc = (coord_y - cam_norm.film_border_top) * cam_norm.film_width *
(3 + 2 * n_track) +
(coord_x - cam_norm.film_border_left) * (3 + 2 * n_track);
forw_info_d[write_loc] = sm_m;
forw_info_d[write_loc + 1] = sm_d;
forw_info_d[write_loc + 2] = max_closest_possible_intersection_hit;
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|writing the %d most important ball infos.\n",
IMIN(n_track, tracker.get_n_hits()));
for (int i = 0; i < n_track; ++i) {
int sphere_id = tracker.get_closest_sphere_id(i);
IASF(sphere_id, forw_info_d[write_loc + 3 + i * 2]);
forw_info_d[write_loc + 3 + i * 2 + 1] =
tracker.get_closest_sphere_depth(i) == MAX_FLOAT
? -1.f
: tracker.get_closest_sphere_depth(i);
PULSAR_LOG_DEV_PIX(
PULSAR_LOG_RENDER_PIX,
"render|writing %d most important: id: %d, normalized depth: %f.\n",
i,
tracker.get_closest_sphere_id(i),
tracker.get_closest_sphere_depth(i));
}
}
END_PARALLEL_2D();
}
} // namespace Renderer
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_INSTANTIATE_H_
#define PULSAR_NATIVE_INCLUDE_RENDERER_RENDER_INSTANTIATE_H_
#include "./renderer.render.device.h"
namespace pulsar {
namespace Renderer {
template GLOBAL void render<ISONDEVICE>(
size_t const* const RESTRICT
num_balls, /** Number of balls relevant for this pass. */
IntersectInfo const* const RESTRICT ii_d, /** Intersect information. */
DrawInfo const* const RESTRICT di_d, /** Draw information. */
float const* const RESTRICT min_depth_d, /** Minimum depth per sphere. */
int const* const RESTRICT id_d, /** IDs. */
float const* const RESTRICT op_d, /** Opacity. */
const CamInfo cam_norm, /** Camera normalized with all vectors to be in the
* camera coordinate system.
*/
const float gamma, /** Transparency parameter. **/
const float percent_allowed_difference, /** Maximum allowed
error in color. */
const uint max_n_hits,
const float* bg_col_d,
const uint mode,
const int x_min,
const int y_min,
const int x_step,
const int y_step,
// Out variables.
float* const RESTRICT result_d, /** The result image. */
float* const RESTRICT forw_info_d, /** Additional information needed for the
grad computation. */
const int n_track /** The number of spheres to track for backprop. */
);
}
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_LOGGING_H_
#define PULSAR_LOGGING_H_
// #define PULSAR_LOGGING_ENABLED
/**
* Enable detailed per-operation timings.
*
* This timing scheme is not appropriate to measure batched calculations.
* Use `PULSAR_TIMINGS_BATCHED_ENABLED` for that.
*/
// #define PULSAR_TIMINGS_ENABLED
/**
* Time batched operations.
*/
// #define PULSAR_TIMINGS_BATCHED_ENABLED
#if defined(PULSAR_TIMINGS_BATCHED_ENABLED) && defined(PULSAR_TIMINGS_ENABLED)
#pragma message("Pulsar|batched and unbatched timings enabled. This will not")
#pragma message("Pulsar|create meaningful results.")
#endif
#ifdef PULSAR_LOGGING_ENABLED
// Control logging.
// 0: INFO, 1: WARNING, 2: ERROR, 3: FATAL (Abort after logging).
#define CAFFE2_LOG_THRESHOLD 0
#define PULSAR_LOG_INIT false
#define PULSAR_LOG_FORWARD false
#define PULSAR_LOG_CALC_SIGNATURE false
#define PULSAR_LOG_RENDER false
#define PULSAR_LOG_RENDER_PIX false
#define PULSAR_LOG_RENDER_PIX_X 428
#define PULSAR_LOG_RENDER_PIX_Y 669
#define PULSAR_LOG_RENDER_PIX_ALL false
#define PULSAR_LOG_TRACKER_PIX false
#define PULSAR_LOG_TRACKER_PIX_X 428
#define PULSAR_LOG_TRACKER_PIX_Y 669
#define PULSAR_LOG_TRACKER_PIX_ALL false
#define PULSAR_LOG_DRAW_PIX false
#define PULSAR_LOG_DRAW_PIX_X 428
#define PULSAR_LOG_DRAW_PIX_Y 669
#define PULSAR_LOG_DRAW_PIX_ALL false
#define PULSAR_LOG_BACKWARD false
#define PULSAR_LOG_GRAD false
#define PULSAR_LOG_GRAD_X 509
#define PULSAR_LOG_GRAD_Y 489
#define PULSAR_LOG_GRAD_ALL false
#define PULSAR_LOG_NORMALIZE false
#define PULSAR_LOG_NORMALIZE_X 0
#define PULSAR_LOG_NORMALIZE_ALL false
#define PULSAR_LOG_DEV(ID, ...) \
if ((ID)) { \
printf(__VA_ARGS__); \
}
#define PULSAR_LOG_DEV_APIX(ID, MSG, ...) \
if ((ID) && (film_coord_x == (ID##_X) && film_coord_y == (ID##_Y)) || \
ID##_ALL) { \
printf( \
"%u %u (ap %u %u)|" MSG, \
film_coord_x, \
film_coord_y, \
ap_coord_x, \
ap_coord_y, \
__VA_ARGS__); \
}
#define PULSAR_LOG_DEV_PIX(ID, MSG, ...) \
if ((ID) && (coord_x == (ID##_X) && coord_y == (ID##_Y)) || ID##_ALL) { \
printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \
}
#ifdef __CUDACC__
#define PULSAR_LOG_DEV_PIXB(ID, MSG, ...) \
if ((ID) && static_cast<int>(block_area.min.x) <= (ID##_X) && \
static_cast<int>(block_area.max.x) > (ID##_X) && \
static_cast<int>(block_area.min.y) <= (ID##_Y) && \
static_cast<int>(block_area.max.y) > (ID##_Y)) { \
printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \
}
#else
#define PULSAR_LOG_DEV_PIXB(ID, MSG, ...) \
if ((ID) && coord_x == (ID##_X) && coord_y == (ID##_Y)) { \
printf("%u %u|" MSG, coord_x, coord_y, __VA_ARGS__); \
}
#endif
#define PULSAR_LOG_DEV_NODE(ID, MSG, ...) \
if ((ID) && idx == (ID##_X) || (ID##_ALL)) { \
printf("%u|" MSG, idx, __VA_ARGS__); \
}
#else
#define CAFFE2_LOG_THRESHOLD 2
#define PULSAR_LOG_RENDER false
#define PULSAR_LOG_INIT false
#define PULSAR_LOG_FORWARD false
#define PULSAR_LOG_BACKWARD false
#define PULSAR_LOG_TRACKER_PIX false
#define PULSAR_LOG_DEV(...)
#define PULSAR_LOG_DEV_APIX(...)
#define PULSAR_LOG_DEV_PIX(...)
#define PULSAR_LOG_DEV_PIXB(...)
#define PULSAR_LOG_DEV_NODE(...)
#endif
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include "./camera.h"
#include "../include/math.h"
namespace pulsar {
namespace pytorch {
CamInfo cam_info_from_params(
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& principal_point_offset,
const float& focal_length,
const uint& width,
const uint& height,
const float& min_dist,
const float& max_dist,
const bool& right_handed) {
CamInfo res;
fill_cam_vecs(
cam_pos.detach().cpu(),
pixel_0_0_center.detach().cpu(),
pixel_vec_x.detach().cpu(),
pixel_vec_y.detach().cpu(),
principal_point_offset.detach().cpu(),
right_handed,
&res);
res.half_pixel_size = 0.5f * length(res.pixel_dir_x);
if (length(res.pixel_dir_y) * 0.5f - res.half_pixel_size > EPS) {
throw std::runtime_error("Pixel sizes must agree in x and y direction!");
}
res.focal_length = focal_length;
res.aperture_width =
width + 2u * static_cast<uint>(abs(res.principal_point_offset_x));
res.aperture_height =
height + 2u * static_cast<uint>(abs(res.principal_point_offset_y));
res.pixel_0_0_center -=
res.pixel_dir_x * static_cast<float>(abs(res.principal_point_offset_x));
res.pixel_0_0_center -=
res.pixel_dir_y * static_cast<float>(abs(res.principal_point_offset_y));
res.film_width = width;
res.film_height = height;
res.film_border_left =
static_cast<uint>(std::max(0, 2 * res.principal_point_offset_x));
res.film_border_top =
static_cast<uint>(std::max(0, 2 * res.principal_point_offset_y));
LOG_IF(INFO, PULSAR_LOG_INIT)
<< "Aperture width, height: " << res.aperture_width << ", "
<< res.aperture_height;
LOG_IF(INFO, PULSAR_LOG_INIT)
<< "Film width, height: " << res.film_width << ", " << res.film_height;
LOG_IF(INFO, PULSAR_LOG_INIT)
<< "Film border left, top: " << res.film_border_left << ", "
<< res.film_border_top;
res.min_dist = min_dist;
res.max_dist = max_dist;
res.norm_fac = 1.f / (max_dist - min_dist);
return res;
};
} // namespace pytorch
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_CAMERA_H_
#define PULSAR_NATIVE_CAMERA_H_
#include <tuple>
#include "../global.h"
#include "../include/camera.h"
namespace pulsar {
namespace pytorch {
inline void fill_cam_vecs(
const torch::Tensor& pos_vec,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_dir_x,
const torch::Tensor& pixel_dir_y,
const torch::Tensor& principal_point_offset,
const bool& right_handed,
CamInfo* res) {
res->eye.x = pos_vec.data_ptr<float>()[0];
res->eye.y = pos_vec.data_ptr<float>()[1];
res->eye.z = pos_vec.data_ptr<float>()[2];
res->pixel_0_0_center.x = pixel_0_0_center.data_ptr<float>()[0];
res->pixel_0_0_center.y = pixel_0_0_center.data_ptr<float>()[1];
res->pixel_0_0_center.z = pixel_0_0_center.data_ptr<float>()[2];
res->pixel_dir_x.x = pixel_dir_x.data_ptr<float>()[0];
res->pixel_dir_x.y = pixel_dir_x.data_ptr<float>()[1];
res->pixel_dir_x.z = pixel_dir_x.data_ptr<float>()[2];
res->pixel_dir_y.x = pixel_dir_y.data_ptr<float>()[0];
res->pixel_dir_y.y = pixel_dir_y.data_ptr<float>()[1];
res->pixel_dir_y.z = pixel_dir_y.data_ptr<float>()[2];
auto sensor_dir_z = pixel_dir_y.cross(pixel_dir_x);
sensor_dir_z /= sensor_dir_z.norm();
if (right_handed) {
sensor_dir_z *= -1.f;
}
res->sensor_dir_z.x = sensor_dir_z.data_ptr<float>()[0];
res->sensor_dir_z.y = sensor_dir_z.data_ptr<float>()[1];
res->sensor_dir_z.z = sensor_dir_z.data_ptr<float>()[2];
res->principal_point_offset_x = principal_point_offset.data_ptr<int32_t>()[0];
res->principal_point_offset_y = principal_point_offset.data_ptr<int32_t>()[1];
}
CamInfo cam_info_from_params(
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& principal_point_offset,
const float& focal_length,
const uint& width,
const uint& height,
const float& min_dist,
const float& max_dist,
const bool& right_handed);
} // namespace pytorch
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include "./renderer.h"
#include "../include/commands.h"
#include "./camera.h"
#include "./util.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace PRE = ::pulsar::Renderer;
namespace pulsar {
namespace pytorch {
Renderer::Renderer(
const unsigned int& width,
const unsigned int& height,
const unsigned int& max_n_balls,
const bool& orthogonal_projection,
const bool& right_handed_system,
const float& background_normalization_depth,
const uint& n_channels,
const uint& n_track) {
LOG_IF(INFO, PULSAR_LOG_INIT) << "Initializing renderer.";
THArgCheck(width > 0, 1, "image width must be > 0!");
THArgCheck(height > 0, 2, "image height must be > 0!");
THArgCheck(max_n_balls > 0, 3, "max_n_balls must be > 0!");
THArgCheck(
background_normalization_depth > 0.f &&
background_normalization_depth < 1.f,
5,
"background_normalization_depth must be in ]0., 1.[");
THArgCheck(n_channels > 0, 6, "n_channels must be > 0");
THArgCheck(
n_track > 0 && n_track <= MAX_GRAD_SPHERES,
7,
("n_track must be > 0 and <" + std::to_string(MAX_GRAD_SPHERES) +
". Is " + std::to_string(n_track) + ".")
.c_str());
LOG_IF(INFO, PULSAR_LOG_INIT)
<< "Image width: " << width << ", height: " << height;
this->renderer_vec.emplace_back();
this->device_type = c10::DeviceType::CPU;
this->device_index = -1;
PRE::construct<false>(
this->renderer_vec.data(),
max_n_balls,
width,
height,
orthogonal_projection,
right_handed_system,
background_normalization_depth,
n_channels,
n_track);
this->device_tracker = torch::zeros(1);
};
Renderer::~Renderer() {
if (this->device_type == c10::DeviceType::CUDA) {
at::cuda::CUDAGuard device_guard(this->device_tracker.device());
for (auto nrend : this->renderer_vec) {
PRE::destruct<true>(&nrend);
}
} else {
for (auto nrend : this->renderer_vec) {
PRE::destruct<false>(&nrend);
}
}
}
bool Renderer::operator==(const Renderer& rhs) const {
LOG_IF(INFO, PULSAR_LOG_INIT) << "Equality check.";
bool renderer_agrees = (this->renderer_vec[0] == rhs.renderer_vec[0]);
LOG_IF(INFO, PULSAR_LOG_INIT) << " Renderer agrees: " << renderer_agrees;
bool device_agrees =
(this->device_tracker.device() == rhs.device_tracker.device());
LOG_IF(INFO, PULSAR_LOG_INIT) << " Device agrees: " << device_agrees;
return (renderer_agrees && device_agrees);
};
void Renderer::ensure_on_device(torch::Device device, bool /*non_blocking*/) {
THArgCheck(
device.type() == c10::DeviceType::CUDA ||
device.type() == c10::DeviceType::CPU,
1,
"Only CPU and CUDA device types are supported.");
if (device.type() != this->device_type ||
device.index() != this->device_index) {
LOG_IF(INFO, PULSAR_LOG_INIT)
<< "Transferring render buffers between devices.";
int prev_active;
cudaGetDevice(&prev_active);
if (this->device_type == c10::DeviceType::CUDA) {
LOG_IF(INFO, PULSAR_LOG_INIT) << " Destructing on CUDA.";
cudaSetDevice(this->device_index);
for (auto& nrend : this->renderer_vec) {
PRE::destruct<true>(&nrend);
}
} else {
LOG_IF(INFO, PULSAR_LOG_INIT) << " Destructing on CPU.";
for (auto& nrend : this->renderer_vec) {
PRE::destruct<false>(&nrend);
}
}
if (device.type() == c10::DeviceType::CUDA) {
LOG_IF(INFO, PULSAR_LOG_INIT) << " Constructing on CUDA.";
cudaSetDevice(device.index());
for (auto& nrend : this->renderer_vec) {
PRE::construct<true>(
&nrend,
this->renderer_vec[0].max_num_balls,
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.film_height,
this->renderer_vec[0].cam.orthogonal_projection,
this->renderer_vec[0].cam.right_handed,
this->renderer_vec[0].cam.background_normalization_depth,
this->renderer_vec[0].cam.n_channels,
this->n_track());
}
} else {
LOG_IF(INFO, PULSAR_LOG_INIT) << " Constructing on CPU.";
for (auto& nrend : this->renderer_vec) {
PRE::construct<false>(
&nrend,
this->renderer_vec[0].max_num_balls,
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.film_height,
this->renderer_vec[0].cam.orthogonal_projection,
this->renderer_vec[0].cam.right_handed,
this->renderer_vec[0].cam.background_normalization_depth,
this->renderer_vec[0].cam.n_channels,
this->n_track());
}
}
cudaSetDevice(prev_active);
this->device_type = device.type();
this->device_index = device.index();
}
};
void Renderer::ensure_n_renderers_gte(const size_t& batch_size) {
if (this->renderer_vec.size() < batch_size) {
ptrdiff_t diff = batch_size - this->renderer_vec.size();
LOG_IF(INFO, PULSAR_LOG_INIT)
<< "Increasing render buffers by " << diff
<< " to account for batch size " << batch_size;
for (ptrdiff_t i = 0; i < diff; ++i) {
this->renderer_vec.emplace_back();
if (this->device_type == c10::DeviceType::CUDA) {
PRE::construct<true>(
&this->renderer_vec[this->renderer_vec.size() - 1],
this->max_num_balls(),
this->width(),
this->height(),
this->renderer_vec[0].cam.orthogonal_projection,
this->renderer_vec[0].cam.right_handed,
this->renderer_vec[0].cam.background_normalization_depth,
this->renderer_vec[0].cam.n_channels,
this->n_track());
} else {
PRE::construct<false>(
&this->renderer_vec[this->renderer_vec.size() - 1],
this->max_num_balls(),
this->width(),
this->height(),
this->renderer_vec[0].cam.orthogonal_projection,
this->renderer_vec[0].cam.right_handed,
this->renderer_vec[0].cam.background_normalization_depth,
this->renderer_vec[0].cam.n_channels,
this->n_track());
}
}
}
}
std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
const torch::Tensor& vert_pos,
const torch::Tensor& vert_col,
const torch::Tensor& vert_radii,
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& focal_length,
const torch::Tensor& principal_point_offsets,
const float& gamma,
const float& max_depth,
float& min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode) {
LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD) << "Arg check.";
size_t batch_size = 1;
size_t n_points;
bool batch_processing = false;
if (vert_pos.ndimension() == 3) {
// Check all parameters adhere batch size.
batch_processing = true;
batch_size = vert_pos.size(0);
THArgCheck(
vert_col.ndimension() == 3 && vert_col.size(0) == batch_size,
2,
"vert_col needs to have batch size.");
THArgCheck(
vert_radii.ndimension() == 2 && vert_radii.size(0) == batch_size,
3,
"vert_radii must be specified per batch.");
THArgCheck(
cam_pos.ndimension() == 2 && cam_pos.size(0) == batch_size,
4,
"cam_pos must be specified per batch and have the correct batch size.");
THArgCheck(
pixel_0_0_center.ndimension() == 2 &&
pixel_0_0_center.size(0) == batch_size,
5,
"pixel_0_0_center must be specified per batch.");
THArgCheck(
pixel_vec_x.ndimension() == 2 && pixel_vec_x.size(0) == batch_size,
6,
"pixel_vec_x must be specified per batch.");
THArgCheck(
pixel_vec_y.ndimension() == 2 && pixel_vec_y.size(0) == batch_size,
7,
"pixel_vec_y must be specified per batch.");
THArgCheck(
focal_length.ndimension() == 1 && focal_length.size(0) == batch_size,
8,
"focal_length must be specified per batch.");
THArgCheck(
principal_point_offsets.ndimension() == 2 &&
principal_point_offsets.size(0) == batch_size,
9,
"principal_point_offsets must be specified per batch.");
if (opacity.has_value()) {
THArgCheck(
opacity.value().ndimension() == 2 &&
opacity.value().size(0) == batch_size,
13,
"Opacity needs to be specified batch-wise.");
}
// Check all parameters are for a matching number of points.
n_points = vert_pos.size(1);
THArgCheck(
vert_col.size(1) == n_points,
2,
("The number of points for vertex positions (" +
std::to_string(n_points) + ") and vertex colors (" +
std::to_string(vert_col.size(1)) + ") doesn't agree.")
.c_str());
THArgCheck(
vert_radii.size(1) == n_points,
3,
("The number of points for vertex positions (" +
std::to_string(n_points) + ") and vertex radii (" +
std::to_string(vert_col.size(1)) + ") doesn't agree.")
.c_str());
if (opacity.has_value()) {
THArgCheck(
opacity.value().size(1) == n_points,
13,
"Opacity needs to be specified per point.");
}
// Check all parameters have the correct last dimension size.
THArgCheck(
vert_pos.size(2) == 3,
1,
("Vertex positions must be 3D (have shape " +
std::to_string(vert_pos.size(2)) + ")!")
.c_str());
THArgCheck(
vert_col.size(2) == this->renderer_vec[0].cam.n_channels,
2,
("Vertex colors must have the right number of channels (have shape " +
std::to_string(vert_col.size(2)) + ", need " +
std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
.c_str());
THArgCheck(
cam_pos.size(1) == 3,
4,
("Camera position must be 3D (has shape " +
std::to_string(cam_pos.size(1)) + ")!")
.c_str());
THArgCheck(
pixel_0_0_center.size(1) == 3,
5,
("pixel_0_0_center must be 3D (has shape " +
std::to_string(pixel_0_0_center.size(1)) + ")!")
.c_str());
THArgCheck(
pixel_vec_x.size(1) == 3,
6,
("pixel_vec_x must be 3D (has shape " +
std::to_string(pixel_vec_x.size(1)) + ")!")
.c_str());
THArgCheck(
pixel_vec_y.size(1) == 3,
7,
("pixel_vec_y must be 3D (has shape " +
std::to_string(pixel_vec_y.size(1)) + ")!")
.c_str());
THArgCheck(
principal_point_offsets.size(1) == 2,
9,
"principal_point_offsets must contain x and y offsets.");
// Ensure enough renderers are available for the batch.
ensure_n_renderers_gte(batch_size);
} else {
// Check all parameters are of correct dimension.
THArgCheck(
vert_col.ndimension() == 2, 2, "vert_col needs to have dimension 2.");
THArgCheck(
vert_radii.ndimension() == 1, 3, "vert_radii must have dimension 1.");
THArgCheck(cam_pos.ndimension() == 1, 4, "cam_pos must have dimension 1.");
THArgCheck(
pixel_0_0_center.ndimension() == 1,
5,
"pixel_0_0_center must have dimension 1.");
THArgCheck(
pixel_vec_x.ndimension() == 1, 6, "pixel_vec_x must have dimension 1.");
THArgCheck(
pixel_vec_y.ndimension() == 1, 7, "pixel_vec_y must have dimension 1.");
THArgCheck(
focal_length.ndimension() == 0,
8,
"focal_length must have dimension 0.");
THArgCheck(
principal_point_offsets.ndimension() == 1,
9,
"principal_point_offsets must have dimension 1.");
if (opacity.has_value()) {
THArgCheck(
opacity.value().ndimension() == 1,
13,
"Opacity needs to be specified per sample.");
}
// Check each.
n_points = vert_pos.size(0);
THArgCheck(
vert_col.size(0) == n_points,
2,
("The number of points for vertex positions (" +
std::to_string(n_points) + ") and vertex colors (" +
std::to_string(vert_col.size(0)) + ") doesn't agree.")
.c_str());
THArgCheck(
vert_radii.size(0) == n_points,
3,
("The number of points for vertex positions (" +
std::to_string(n_points) + ") and vertex radii (" +
std::to_string(vert_col.size(0)) + ") doesn't agree.")
.c_str());
if (opacity.has_value()) {
THArgCheck(
opacity.value().size(0) == n_points,
12,
"Opacity needs to be specified per point.");
}
// Check all parameters have the correct last dimension size.
THArgCheck(
vert_pos.size(1) == 3,
1,
("Vertex positions must be 3D (have shape " +
std::to_string(vert_pos.size(1)) + ")!")
.c_str());
THArgCheck(
vert_col.size(1) == this->renderer_vec[0].cam.n_channels,
2,
("Vertex colors must have the right number of channels (have shape " +
std::to_string(vert_col.size(1)) + ", need " +
std::to_string(this->renderer_vec[0].cam.n_channels) + ")!")
.c_str());
THArgCheck(
cam_pos.size(0) == 3,
4,
("Camera position must be 3D (has shape " +
std::to_string(cam_pos.size(0)) + ")!")
.c_str());
THArgCheck(
pixel_0_0_center.size(0) == 3,
5,
("pixel_0_0_center must be 3D (has shape " +
std::to_string(pixel_0_0_center.size(0)) + ")!")
.c_str());
THArgCheck(
pixel_vec_x.size(0) == 3,
6,
("pixel_vec_x must be 3D (has shape " +
std::to_string(pixel_vec_x.size(0)) + ")!")
.c_str());
THArgCheck(
pixel_vec_y.size(0) == 3,
7,
("pixel_vec_y must be 3D (has shape " +
std::to_string(pixel_vec_y.size(0)) + ")!")
.c_str());
THArgCheck(
principal_point_offsets.size(0) == 2,
9,
"principal_point_offsets must have x and y component.");
}
// Check device placement.
auto dev = torch::device_of(vert_pos).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
1,
("Vertex positions must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(vert_col).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
2,
("Vertex colors must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(vert_radii).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
3,
("Vertex radii must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(cam_pos).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
4,
("Camera position must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(pixel_0_0_center).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
5,
("pixel_0_0_center must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(pixel_vec_x).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
6,
("pixel_vec_x must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(pixel_vec_y).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
7,
("pixel_vec_y must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(principal_point_offsets).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
9,
("principal_point_offsets must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
if (opacity.has_value()) {
dev = torch::device_of(opacity.value()).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
13,
("opacity must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Is stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
}
// Type checks.
THArgCheck(
vert_pos.scalar_type() == c10::kFloat, 1, "pulsar requires float types.");
THArgCheck(
vert_col.scalar_type() == c10::kFloat, 2, "pulsar requires float types.");
THArgCheck(
vert_radii.scalar_type() == c10::kFloat,
3,
"pulsar requires float types.");
THArgCheck(
cam_pos.scalar_type() == c10::kFloat, 4, "pulsar requires float types.");
THArgCheck(
pixel_0_0_center.scalar_type() == c10::kFloat,
5,
"pulsar requires float types.");
THArgCheck(
pixel_vec_x.scalar_type() == c10::kFloat,
6,
"pulsar requires float types.");
THArgCheck(
pixel_vec_y.scalar_type() == c10::kFloat,
7,
"pulsar requires float types.");
THArgCheck(
focal_length.scalar_type() == c10::kFloat,
8,
"pulsar requires float types.");
THArgCheck(
// Unfortunately, the PyTorch interface is inconsistent for
// Int32: in Python, there exists an explicit int32 type, in
// C++ this is currently `c10::kInt`.
principal_point_offsets.scalar_type() == c10::kInt,
9,
"principal_point_offsets must be provided as int32.");
if (opacity.has_value()) {
THArgCheck(
opacity.value().scalar_type() == c10::kFloat,
13,
"opacity must be a float type.");
}
// Content checks.
THArgCheck(
(vert_radii > FEPS).all().item<bool>(),
3,
("Vertex radii must be > FEPS (min is " +
std::to_string(vert_radii.min().item<float>()) + ").")
.c_str());
if (this->orthogonal()) {
THArgCheck(
(focal_length == 0.f).all().item<bool>(),
8,
("for an orthogonal projection focal length must be zero (abs max: " +
std::to_string(focal_length.abs().max().item<float>()) + ").")
.c_str());
} else {
THArgCheck(
(focal_length > FEPS).all().item<bool>(),
8,
("for a perspective projection focal length must be > FEPS (min " +
std::to_string(focal_length.min().item<float>()) + ").")
.c_str());
}
THArgCheck(
gamma <= 1.f && gamma >= 1E-5f,
10,
("gamma must be in [1E-5, 1] (" + std::to_string(gamma) + ").").c_str());
if (min_depth == 0.f) {
min_depth = focal_length.max().item<float>() + 2.f * FEPS;
}
THArgCheck(
min_depth > focal_length.max().item<float>(),
12,
("min_depth must be > focal_length (" + std::to_string(min_depth) +
" vs. " + std::to_string(focal_length.max().item<float>()) + ").")
.c_str());
THArgCheck(
max_depth > min_depth + FEPS,
11,
("max_depth must be > min_depth + FEPS (" + std::to_string(max_depth) +
" vs. " + std::to_string(min_depth + FEPS) + ").")
.c_str());
THArgCheck(
percent_allowed_difference >= 0.f && percent_allowed_difference < 1.f,
14,
("percent_allowed_difference must be in [0., 1.[ (" +
std::to_string(percent_allowed_difference) + ").")
.c_str());
THArgCheck(max_n_hits > 0, 14, "max_n_hits must be > 0!");
THArgCheck(mode < 2, 15, "mode must be in {0, 1}.");
torch::Tensor real_bg_col;
if (bg_col.has_value()) {
THArgCheck(
bg_col.value().device().type() == this->device_type &&
bg_col.value().device().index() == this->device_index,
13,
"bg_col must be stored on the renderer device!");
THArgCheck(
bg_col.value().ndimension() == 1 &&
bg_col.value().size(0) == renderer_vec[0].cam.n_channels,
13,
"bg_col must have the same number of channels as the image,).");
real_bg_col = bg_col.value();
} else {
real_bg_col = torch::ones(
{renderer_vec[0].cam.n_channels},
c10::Device(this->device_type, this->device_index))
.to(c10::kFloat);
}
if (opacity.has_value()) {
THArgCheck(
(opacity.value() >= 0.f).all().item<bool>(),
13,
"opacity must be >= 0.");
THArgCheck(
(opacity.value() <= 1.f).all().item<bool>(),
13,
"opacity must be <= 1.");
}
LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD)
<< " batch_size: " << batch_size;
LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD)
<< " n_points: " << n_points;
LOG_IF(INFO, PULSAR_LOG_FORWARD || PULSAR_LOG_BACKWARD)
<< " batch_processing: " << batch_processing;
return std::tuple<size_t, size_t, bool, torch::Tensor>(
batch_size, n_points, batch_processing, real_bg_col);
}
std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
const torch::Tensor& vert_pos,
const torch::Tensor& vert_col,
const torch::Tensor& vert_radii,
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& focal_length,
const torch::Tensor& principal_point_offsets,
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode) {
// Parameter checks.
this->ensure_on_device(this->device_tracker.device());
size_t batch_size;
size_t n_points;
bool batch_processing;
torch::Tensor real_bg_col;
std::tie(batch_size, n_points, batch_processing, real_bg_col) =
this->arg_check(
vert_pos,
vert_col,
vert_radii,
cam_pos,
pixel_0_0_center,
pixel_vec_x,
pixel_vec_y,
focal_length,
principal_point_offsets,
gamma,
max_depth,
min_depth,
bg_col,
opacity,
percent_allowed_difference,
max_n_hits,
mode);
LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Extracting camera objects...";
// Create the camera information.
std::vector<CamInfo> cam_infos(batch_size);
if (batch_processing) {
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
cam_infos[batch_i] = cam_info_from_params(
cam_pos[batch_i],
pixel_0_0_center[batch_i],
pixel_vec_x[batch_i],
pixel_vec_y[batch_i],
principal_point_offsets[batch_i],
focal_length[batch_i].item<float>(),
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.film_height,
min_depth,
max_depth,
this->renderer_vec[0].cam.right_handed);
}
} else {
cam_infos[0] = cam_info_from_params(
cam_pos,
pixel_0_0_center,
pixel_vec_x,
pixel_vec_y,
principal_point_offsets,
focal_length.item<float>(),
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.film_height,
min_depth,
max_depth,
this->renderer_vec[0].cam.right_handed);
}
LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Processing...";
// Let's go!
// Contiguous version of opacity, if available. We need to create this object
// in scope to keep it alive.
torch::Tensor opacity_contiguous;
float const* opacity_ptr = nullptr;
if (opacity.has_value()) {
opacity_contiguous = opacity.value().contiguous();
opacity_ptr = opacity_contiguous.data_ptr<float>();
}
if (this->device_type == c10::DeviceType::CUDA) {
int prev_active;
cudaGetDevice(&prev_active);
cudaSetDevice(this->device_index);
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
START_TIME_CU(batch_forward);
#endif
if (batch_processing) {
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
// These calls are non-blocking and just kick off the computations.
PRE::forward<true>(
&this->renderer_vec[batch_i],
vert_pos[batch_i].contiguous().data_ptr<float>(),
vert_col[batch_i].contiguous().data_ptr<float>(),
vert_radii[batch_i].contiguous().data_ptr<float>(),
cam_infos[batch_i],
gamma,
percent_allowed_difference,
max_n_hits,
real_bg_col.contiguous().data_ptr<float>(),
opacity_ptr,
n_points,
mode,
at::cuda::getCurrentCUDAStream());
}
} else {
PRE::forward<true>(
this->renderer_vec.data(),
vert_pos.contiguous().data_ptr<float>(),
vert_col.contiguous().data_ptr<float>(),
vert_radii.contiguous().data_ptr<float>(),
cam_infos[0],
gamma,
percent_allowed_difference,
max_n_hits,
real_bg_col.contiguous().data_ptr<float>(),
opacity_ptr,
n_points,
mode,
at::cuda::getCurrentCUDAStream());
}
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
STOP_TIME_CU(batch_forward);
float time_ms;
GET_TIME_CU(batch_forward, &time_ms);
std::cout << "Forward render batched time per example: "
<< time_ms / static_cast<float>(batch_size) << "ms" << std::endl;
#endif
cudaSetDevice(prev_active);
} else {
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
START_TIME(batch_forward);
#endif
if (batch_processing) {
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
// These calls are non-blocking and just kick off the computations.
PRE::forward<false>(
&this->renderer_vec[batch_i],
vert_pos[batch_i].contiguous().data_ptr<float>(),
vert_col[batch_i].contiguous().data_ptr<float>(),
vert_radii[batch_i].contiguous().data_ptr<float>(),
cam_infos[batch_i],
gamma,
percent_allowed_difference,
max_n_hits,
real_bg_col.contiguous().data_ptr<float>(),
opacity_ptr,
n_points,
mode,
nullptr);
}
} else {
PRE::forward<false>(
this->renderer_vec.data(),
vert_pos.contiguous().data_ptr<float>(),
vert_col.contiguous().data_ptr<float>(),
vert_radii.contiguous().data_ptr<float>(),
cam_infos[0],
gamma,
percent_allowed_difference,
max_n_hits,
real_bg_col.contiguous().data_ptr<float>(),
opacity_ptr,
n_points,
mode,
nullptr);
}
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
STOP_TIME(batch_forward);
float time_ms;
GET_TIME(batch_forward, &time_ms);
std::cout << "Forward render batched time per example: "
<< time_ms / static_cast<float>(batch_size) << "ms" << std::endl;
#endif
}
LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Extracting results...";
// Create the results.
std::vector<torch::Tensor> results(batch_size);
std::vector<torch::Tensor> forw_infos(batch_size);
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
results[batch_i] = from_blob(
this->renderer_vec[batch_i].result_d,
{this->renderer_vec[0].cam.film_height,
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.n_channels},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
if (mode == 1)
results[batch_i] = results[batch_i].slice(2, 0, 1, 1);
forw_infos[batch_i] = from_blob(
this->renderer_vec[batch_i].forw_info_d,
{this->renderer_vec[0].cam.film_height,
this->renderer_vec[0].cam.film_width,
3 + 2 * this->n_track()},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
LOG_IF(INFO, PULSAR_LOG_FORWARD) << "Forward render complete.";
if (batch_processing) {
return std::tuple<torch::Tensor, torch::Tensor>(
torch::stack(results), torch::stack(forw_infos));
} else {
return std::tuple<torch::Tensor, torch::Tensor>(results[0], forw_infos[0]);
}
};
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
Renderer::backward(
const torch::Tensor& grad_im,
const torch::Tensor& image,
const torch::Tensor& forw_info,
const torch::Tensor& vert_pos,
const torch::Tensor& vert_col,
const torch::Tensor& vert_radii,
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& focal_length,
const torch::Tensor& principal_point_offsets,
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode,
const bool& dif_pos,
const bool& dif_col,
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const at::optional<std::pair<uint, uint>>& dbg_pos) {
this->ensure_on_device(this->device_tracker.device());
size_t batch_size;
size_t n_points;
bool batch_processing;
torch::Tensor real_bg_col;
std::tie(batch_size, n_points, batch_processing, real_bg_col) =
this->arg_check(
vert_pos,
vert_col,
vert_radii,
cam_pos,
pixel_0_0_center,
pixel_vec_x,
pixel_vec_y,
focal_length,
principal_point_offsets,
gamma,
max_depth,
min_depth,
bg_col,
opacity,
percent_allowed_difference,
max_n_hits,
mode);
// Additional checks for the gradient computation.
THArgCheck(
(grad_im.ndimension() == 3 + batch_processing &&
static_cast<uint>(grad_im.size(0 + batch_processing)) ==
this->height() &&
static_cast<uint>(grad_im.size(1 + batch_processing)) == this->width() &&
static_cast<uint>(grad_im.size(2 + batch_processing)) ==
this->renderer_vec[0].cam.n_channels),
1,
"The gradient image size is not correct.");
THArgCheck(
(image.ndimension() == 3 + batch_processing &&
static_cast<uint>(image.size(0 + batch_processing)) == this->height() &&
static_cast<uint>(image.size(1 + batch_processing)) == this->width() &&
static_cast<uint>(image.size(2 + batch_processing)) ==
this->renderer_vec[0].cam.n_channels),
2,
"The result image size is not correct.");
THArgCheck(
grad_im.scalar_type() == c10::kFloat,
1,
"The gradient image must be of float type.");
THArgCheck(
image.scalar_type() == c10::kFloat,
2,
"The image must be of float type.");
if (dif_opy) {
THArgCheck(opacity.has_value(), 13, "dif_opy set requires opacity values.");
}
if (batch_processing) {
THArgCheck(
grad_im.size(0) == batch_size,
1,
"Gradient image batch size must agree.");
THArgCheck(image.size(0) == batch_size, 2, "Image batch size must agree.");
THArgCheck(
forw_info.size(0) == batch_size,
3,
"forward info must have batch size.");
}
THArgCheck(
(forw_info.ndimension() == 3 + batch_processing &&
static_cast<uint>(forw_info.size(0 + batch_processing)) ==
this->height() &&
static_cast<uint>(forw_info.size(1 + batch_processing)) ==
this->width() &&
static_cast<uint>(forw_info.size(2 + batch_processing)) ==
3 + 2 * this->n_track()),
3,
"The forward info image size is not correct.");
THArgCheck(
forw_info.scalar_type() == c10::kFloat,
3,
"The forward info must be of float type.");
// Check device.
auto dev = torch::device_of(grad_im).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
1,
("grad_im must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(image).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
2,
("image must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
dev = torch::device_of(forw_info).value();
THArgCheck(
dev.type() == this->device_type && dev.index() == this->device_index,
3,
("forw_info must be stored on device " +
c10::DeviceTypeName(this->device_type) + ", index " +
std::to_string(this->device_index) + "! Are stored on " +
c10::DeviceTypeName(dev.type()) + ", index " +
std::to_string(dev.index()) + ".")
.c_str());
if (dbg_pos.has_value()) {
THArgCheck(
dbg_pos.value().first < this->width() &&
dbg_pos.value().second < this->height(),
23,
"The debug position must be within image bounds.");
}
// Prepare the return value.
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
ret;
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam)) {
return ret;
}
// Create the camera information.
std::vector<CamInfo> cam_infos(batch_size);
if (batch_processing) {
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
cam_infos[batch_i] = cam_info_from_params(
cam_pos[batch_i],
pixel_0_0_center[batch_i],
pixel_vec_x[batch_i],
pixel_vec_y[batch_i],
principal_point_offsets[batch_i],
focal_length[batch_i].item<float>(),
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.film_height,
min_depth,
max_depth,
this->renderer_vec[0].cam.right_handed);
}
} else {
cam_infos[0] = cam_info_from_params(
cam_pos,
pixel_0_0_center,
pixel_vec_x,
pixel_vec_y,
principal_point_offsets,
focal_length.item<float>(),
this->renderer_vec[0].cam.film_width,
this->renderer_vec[0].cam.film_height,
min_depth,
max_depth,
this->renderer_vec[0].cam.right_handed);
}
// Let's go!
// Contiguous version of opacity, if available. We need to create this object
// in scope to keep it alive.
torch::Tensor opacity_contiguous;
float const* opacity_ptr = nullptr;
if (opacity.has_value()) {
opacity_contiguous = opacity.value().contiguous();
opacity_ptr = opacity_contiguous.data_ptr<float>();
}
if (this->device_type == c10::DeviceType::CUDA) {
int prev_active;
cudaGetDevice(&prev_active);
cudaSetDevice(this->device_index);
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
START_TIME_CU(batch_backward);
#endif
if (batch_processing) {
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
// These calls are non-blocking and just kick off the computations.
if (dbg_pos.has_value()) {
PRE::backward_dbg<true>(
&this->renderer_vec[batch_i],
grad_im[batch_i].contiguous().data_ptr<float>(),
image[batch_i].contiguous().data_ptr<float>(),
forw_info[batch_i].contiguous().data_ptr<float>(),
vert_pos[batch_i].contiguous().data_ptr<float>(),
vert_col[batch_i].contiguous().data_ptr<float>(),
vert_radii[batch_i].contiguous().data_ptr<float>(),
cam_infos[batch_i],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
dbg_pos.value().first,
dbg_pos.value().second,
at::cuda::getCurrentCUDAStream());
} else {
PRE::backward<true>(
&this->renderer_vec[batch_i],
grad_im[batch_i].contiguous().data_ptr<float>(),
image[batch_i].contiguous().data_ptr<float>(),
forw_info[batch_i].contiguous().data_ptr<float>(),
vert_pos[batch_i].contiguous().data_ptr<float>(),
vert_col[batch_i].contiguous().data_ptr<float>(),
vert_radii[batch_i].contiguous().data_ptr<float>(),
cam_infos[batch_i],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
at::cuda::getCurrentCUDAStream());
}
}
} else {
if (dbg_pos.has_value()) {
PRE::backward_dbg<true>(
this->renderer_vec.data(),
grad_im.contiguous().data_ptr<float>(),
image.contiguous().data_ptr<float>(),
forw_info.contiguous().data_ptr<float>(),
vert_pos.contiguous().data_ptr<float>(),
vert_col.contiguous().data_ptr<float>(),
vert_radii.contiguous().data_ptr<float>(),
cam_infos[0],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
dbg_pos.value().first,
dbg_pos.value().second,
at::cuda::getCurrentCUDAStream());
} else {
PRE::backward<true>(
this->renderer_vec.data(),
grad_im.contiguous().data_ptr<float>(),
image.contiguous().data_ptr<float>(),
forw_info.contiguous().data_ptr<float>(),
vert_pos.contiguous().data_ptr<float>(),
vert_col.contiguous().data_ptr<float>(),
vert_radii.contiguous().data_ptr<float>(),
cam_infos[0],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
at::cuda::getCurrentCUDAStream());
}
}
cudaSetDevice(prev_active);
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
STOP_TIME_CU(batch_backward);
float time_ms;
GET_TIME_CU(batch_backward, &time_ms);
std::cout << "Backward render batched time per example: "
<< time_ms / static_cast<float>(batch_size) << "ms" << std::endl;
#endif
} else {
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
START_TIME(batch_backward);
#endif
if (batch_processing) {
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
// These calls are non-blocking and just kick off the computations.
if (dbg_pos.has_value()) {
PRE::backward_dbg<false>(
&this->renderer_vec[batch_i],
grad_im[batch_i].contiguous().data_ptr<float>(),
image[batch_i].contiguous().data_ptr<float>(),
forw_info[batch_i].contiguous().data_ptr<float>(),
vert_pos[batch_i].contiguous().data_ptr<float>(),
vert_col[batch_i].contiguous().data_ptr<float>(),
vert_radii[batch_i].contiguous().data_ptr<float>(),
cam_infos[batch_i],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
dbg_pos.value().first,
dbg_pos.value().second,
nullptr);
} else {
PRE::backward<false>(
&this->renderer_vec[batch_i],
grad_im[batch_i].contiguous().data_ptr<float>(),
image[batch_i].contiguous().data_ptr<float>(),
forw_info[batch_i].contiguous().data_ptr<float>(),
vert_pos[batch_i].contiguous().data_ptr<float>(),
vert_col[batch_i].contiguous().data_ptr<float>(),
vert_radii[batch_i].contiguous().data_ptr<float>(),
cam_infos[batch_i],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
nullptr);
}
}
} else {
if (dbg_pos.has_value()) {
PRE::backward_dbg<false>(
this->renderer_vec.data(),
grad_im.contiguous().data_ptr<float>(),
image.contiguous().data_ptr<float>(),
forw_info.contiguous().data_ptr<float>(),
vert_pos.contiguous().data_ptr<float>(),
vert_col.contiguous().data_ptr<float>(),
vert_radii.contiguous().data_ptr<float>(),
cam_infos[0],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
dbg_pos.value().first,
dbg_pos.value().second,
nullptr);
} else {
PRE::backward<false>(
this->renderer_vec.data(),
grad_im.contiguous().data_ptr<float>(),
image.contiguous().data_ptr<float>(),
forw_info.contiguous().data_ptr<float>(),
vert_pos.contiguous().data_ptr<float>(),
vert_col.contiguous().data_ptr<float>(),
vert_radii.contiguous().data_ptr<float>(),
cam_infos[0],
gamma,
percent_allowed_difference,
max_n_hits,
opacity_ptr,
n_points,
mode,
dif_pos,
dif_col,
dif_rad,
dif_cam,
dif_opy,
nullptr);
}
}
#ifdef PULSAR_TIMINGS_BATCHED_ENABLED
STOP_TIME(batch_backward);
float time_ms;
GET_TIME(batch_backward, &time_ms);
std::cout << "Backward render batched time per example: "
<< time_ms / static_cast<float>(batch_size) << "ms" << std::endl;
#endif
}
if (dif_pos) {
if (batch_processing) {
std::vector<torch::Tensor> results(batch_size);
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
results[batch_i] = from_blob(
reinterpret_cast<float*>(this->renderer_vec[batch_i].grad_pos_d),
{static_cast<ptrdiff_t>(n_points), 3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
std::get<0>(ret) = torch::stack(results);
} else {
std::get<0>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_pos_d),
{static_cast<ptrdiff_t>(n_points), 3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
}
if (dif_col) {
if (batch_processing) {
std::vector<torch::Tensor> results(batch_size);
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
results[batch_i] = from_blob(
reinterpret_cast<float*>(this->renderer_vec[batch_i].grad_col_d),
{static_cast<ptrdiff_t>(n_points),
this->renderer_vec[0].cam.n_channels},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
std::get<1>(ret) = torch::stack(results);
} else {
std::get<1>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_col_d),
{static_cast<ptrdiff_t>(n_points),
this->renderer_vec[0].cam.n_channels},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
}
if (dif_rad) {
if (batch_processing) {
std::vector<torch::Tensor> results(batch_size);
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
results[batch_i] = from_blob(
reinterpret_cast<float*>(this->renderer_vec[batch_i].grad_rad_d),
{static_cast<ptrdiff_t>(n_points)},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
std::get<2>(ret) = torch::stack(results);
} else {
std::get<2>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_rad_d),
{static_cast<ptrdiff_t>(n_points)},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
}
if (dif_cam) {
if (batch_processing) {
std::vector<torch::Tensor> res_p1(batch_size);
std::vector<torch::Tensor> res_p2(batch_size);
std::vector<torch::Tensor> res_p3(batch_size);
std::vector<torch::Tensor> res_p4(batch_size);
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
res_p1[batch_i] = from_blob(
reinterpret_cast<float*>(this->renderer_vec[batch_i].grad_cam_d),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
res_p2[batch_i] = from_blob(
reinterpret_cast<float*>(
this->renderer_vec[batch_i].grad_cam_d + 3),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
res_p3[batch_i] = from_blob(
reinterpret_cast<float*>(
this->renderer_vec[batch_i].grad_cam_d + 6),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
res_p4[batch_i] = from_blob(
reinterpret_cast<float*>(
this->renderer_vec[batch_i].grad_cam_d + 9),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
std::get<3>(ret) = torch::stack(res_p1);
std::get<4>(ret) = torch::stack(res_p2);
std::get<5>(ret) = torch::stack(res_p3);
std::get<6>(ret) = torch::stack(res_p4);
} else {
std::get<3>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_cam_d),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
std::get<4>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_cam_d + 3),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
std::get<5>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_cam_d + 6),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
std::get<6>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_cam_d + 9),
{3},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
}
if (dif_opy) {
if (batch_processing) {
std::vector<torch::Tensor> results(batch_size);
for (size_t batch_i = 0; batch_i < batch_size; ++batch_i) {
results[batch_i] = from_blob(
reinterpret_cast<float*>(this->renderer_vec[batch_i].grad_opy_d),
{static_cast<ptrdiff_t>(n_points)},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
std::get<7>(ret) = torch::stack(results);
} else {
std::get<7>(ret) = from_blob(
reinterpret_cast<float*>(this->renderer_vec[0].grad_opy_d),
{static_cast<ptrdiff_t>(n_points)},
this->device_type,
this->device_index,
torch::kFloat,
this->device_type == c10::DeviceType::CUDA
? at::cuda::getCurrentCUDAStream()
: (cudaStream_t) nullptr);
}
}
return ret;
};
} // namespace pytorch
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_PYTORCH_RENDERER_H_
#define PULSAR_NATIVE_PYTORCH_RENDERER_H_
#include "../global.h"
#include "../include/renderer.h"
namespace pulsar {
namespace pytorch {
struct Renderer {
public:
/**
* Pytorch Pulsar differentiable rendering module.
*/
explicit Renderer(
const unsigned int& width,
const unsigned int& height,
const uint& max_n_balls,
const bool& orthogonal_projection,
const bool& right_handed_system,
const float& background_normalization_depth,
const uint& n_channels,
const uint& n_track);
~Renderer();
std::tuple<torch::Tensor, torch::Tensor> forward(
const torch::Tensor& vert_pos,
const torch::Tensor& vert_col,
const torch::Tensor& vert_radii,
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& focal_length,
const torch::Tensor& principal_point_offsets,
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode);
std::tuple<
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>,
at::optional<torch::Tensor>>
backward(
const torch::Tensor& grad_im,
const torch::Tensor& image,
const torch::Tensor& forw_info,
const torch::Tensor& vert_pos,
const torch::Tensor& vert_col,
const torch::Tensor& vert_radii,
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& focal_length,
const torch::Tensor& principal_point_offsets,
const float& gamma,
const float& max_depth,
float min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode,
const bool& dif_pos,
const bool& dif_col,
const bool& dif_rad,
const bool& dif_cam,
const bool& dif_opy,
const at::optional<std::pair<uint, uint>>& dbg_pos);
// Infrastructure.
/**
* Ensure that the renderer is placed on this device.
* Is nearly a no-op if the device is correct.
*/
void ensure_on_device(torch::Device device, bool non_blocking = false);
/**
* Ensure that at least n renderers are available.
*/
void ensure_n_renderers_gte(const size_t& batch_size);
/**
* Check the parameters.
*/
std::tuple<size_t, size_t, bool, torch::Tensor> arg_check(
const torch::Tensor& vert_pos,
const torch::Tensor& vert_col,
const torch::Tensor& vert_radii,
const torch::Tensor& cam_pos,
const torch::Tensor& pixel_0_0_center,
const torch::Tensor& pixel_vec_x,
const torch::Tensor& pixel_vec_y,
const torch::Tensor& focal_length,
const torch::Tensor& principal_point_offsets,
const float& gamma,
const float& max_depth,
float& min_depth,
const c10::optional<torch::Tensor>& bg_col,
const c10::optional<torch::Tensor>& opacity,
const float& percent_allowed_difference,
const uint& max_n_hits,
const uint& mode);
bool operator==(const Renderer& rhs) const;
inline friend std::ostream& operator<<(
std::ostream& stream,
const Renderer& self) {
stream << "pulsar::Renderer[";
// Device info.
stream << self.device_type;
if (self.device_index != -1)
stream << ", ID " << self.device_index;
stream << "]";
return stream;
}
inline uint width() const {
return this->renderer_vec[0].cam.film_width;
}
inline uint height() const {
return this->renderer_vec[0].cam.film_height;
}
inline int max_num_balls() const {
return this->renderer_vec[0].max_num_balls;
}
inline bool orthogonal() const {
return this->renderer_vec[0].cam.orthogonal_projection;
}
inline bool right_handed() const {
return this->renderer_vec[0].cam.right_handed;
}
inline uint n_track() const {
return static_cast<uint>(this->renderer_vec[0].n_track);
}
/** A tensor that is registered as a buffer with this Module to track its
* device placement. Unfortunately, pytorch doesn't offer tracking Module
* device placement in a better way as of now.
*/
torch::Tensor device_tracker;
protected:
/** The device type for this renderer. */
c10::DeviceType device_type;
/** The device index for this renderer. */
c10::DeviceIndex device_index;
/** Pointer to the underlying pulsar renderers. */
std::vector<pulsar::Renderer::Renderer> renderer_vec;
};
} // namespace pytorch
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime_api.h>
#include <torch/extension.h>
#include "./tensor_util.h"
namespace pulsar {
namespace pytorch {
torch::Tensor sphere_ids_from_result_info_nograd(
const torch::Tensor& forw_info) {
torch::Tensor result = torch::zeros(
{forw_info.size(0),
forw_info.size(1),
forw_info.size(2),
(forw_info.size(3) - 3) / 2},
torch::TensorOptions().device(forw_info.device()).dtype(torch::kInt32));
// Get the relevant slice, contiguous.
torch::Tensor tmp =
forw_info
.slice(
/*dim=*/3, /*start=*/3, /*end=*/forw_info.size(3), /*step=*/2)
.contiguous();
if (forw_info.device().type() == c10::DeviceType::CUDA) {
cudaMemcpyAsync(
result.data_ptr(),
tmp.data_ptr(),
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
tmp.size(3),
cudaMemcpyDeviceToDevice,
at::cuda::getCurrentCUDAStream());
} else {
memcpy(
result.data_ptr(),
tmp.data_ptr(),
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
tmp.size(3));
}
// `tmp` is freed after this, the memory might get reallocated. However,
// only kernels in the same stream should ever be able to write to this
// memory, which are executed only after the memcpy is complete. That's
// why we can just continue.
return result;
}
} // namespace pytorch
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_
#define PULSAR_NATIVE_PYTORCH_TENSOR_UTIL_H_
#include <ATen/ATen.h>
namespace pulsar {
namespace pytorch {
torch::Tensor sphere_ids_from_result_info_nograd(
const torch::Tensor& forw_info);
}
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <cuda_runtime_api.h>
namespace pulsar {
namespace pytorch {
void cudaDevToDev(
void* trg,
const void* src,
const int& size,
const cudaStream_t& stream) {
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
}
void cudaDevToHost(
void* trg,
const void* src,
const int& size,
const cudaStream_t& stream) {
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
}
} // namespace pytorch
} // namespace pulsar
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#ifndef PULSAR_NATIVE_PYTORCH_UTIL_H_
#define PULSAR_NATIVE_PYTORCH_UTIL_H_
#include <ATen/ATen.h>
#include "../global.h"
namespace pulsar {
namespace pytorch {
void cudaDevToDev(
void* trg,
const void* src,
const int& size,
const cudaStream_t& stream);
void cudaDevToHost(
void* trg,
const void* src,
const int& size,
const cudaStream_t& stream);
/**
* This method takes a memory pointer and wraps it into a pytorch tensor.
*
* This is preferred over `torch::from_blob`, since that requires a CUDA
* managed pointer. However, working with these for high performance
* operations is slower. Most of the rendering operations should stay
* local to the respective GPU anyways, so unmanaged pointers are
* preferred.
*/
template <typename T>
torch::Tensor from_blob(
const T* ptr,
const torch::IntArrayRef& shape,
const c10::DeviceType& device_type,
const c10::DeviceIndex& device_index,
const torch::Dtype& dtype,
const cudaStream_t& stream) {
torch::Tensor ret = torch::zeros(
shape, torch::device({device_type, device_index}).dtype(dtype));
const int num_elements =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>{});
if (device_type == c10::DeviceType::CUDA) {
cudaDevToDev(
ret.data_ptr(),
static_cast<const void*>(ptr),
sizeof(T) * num_elements,
stream);
// TODO: check for synchronization.
} else {
memcpy(ret.data_ptr(), ptr, sizeof(T) * num_elements);
}
return ret;
};
} // namespace pytorch
} // namespace pulsar
#endif
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include "./global.h"
#include "./logging.h"
/**
* A compilation unit to provide warnings about the code and avoid
* repeated messages.
*/
#ifdef PULSAR_ASSERTIONS
#pragma message("WARNING: assertions are enabled in Pulsar.")
#endif
#ifdef PULSAR_LOGGING_ENABLED
#pragma message("WARNING: logging is enabled in Pulsar.")
#endif
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .renderer import Renderer # noqa: F401
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment