"docs/vscode:/vscode.git/clone" did not exist on "9d8ec2e67e36117ac6da0c82e597d6dbf587d578"
Commit ecc1df99 authored by facebook-github-bot's avatar facebook-github-bot
Browse files

Initial commit

fbshipit-source-id: afc575e8e7d8e2796a3f77d8b1c6c4fcb999558d
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <torch/torch.h>
torch::Tensor msi_forward_cuda(
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh);
torch::Tensor msi_backward_cuda(
const torch::Tensor& rgba_img,
const torch::Tensor& rgba_img_grad,
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <ATen/autocast_mode.h>
#include <torch/script.h>
#include "msi_kernel.h"
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
/*
 * Renders a Multi-Sphere Image (MSI), similar to the one described in "NeRF++: Analyzing and
 * Improving Neural Radiance Fields".
 * This file provides the Python/Torch bindings and the autograd function implementation.
 * The main implementation is in `src/msi/msi_kernel.cu`; see the functions `msi_forward_cuda`
 * and `msi_backward_cuda`.
 * For more details see the docstring in drtk/msi.py.
 */
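// Example (illustrative): once the extension is loaded, the op registered below is reachable
// from Python as torch.ops.msi_ext.msi(ray_o, ray_d, texture, sub_step_count, min_inv_r,
// max_inv_r, stop_thresh), matching the schema in TORCH_LIBRARY; the wrapper in drtk/msi.py is
// the intended user-facing entry point.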
// Dispatch function
torch::Tensor msi(
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
static auto op =
torch::Dispatcher::singleton().findSchemaOrThrow("msi_ext::msi", "").typed<decltype(msi)>();
return op.call(ray_o, ray_d, texture, sub_step_count, min_inv_r, max_inv_r, stop_thresh);
}
// Ideally we would turn off autograd handling and re-dispatch here, but instead we call the
// CUDA kernels directly.
class MSIFunction : public torch::autograd::Function<MSIFunction> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
ctx->set_materialize_grads(false);
std::vector<torch::Tensor> save_list;
save_list.push_back(ray_o);
save_list.push_back(ray_d);
save_list.push_back(texture);
bool requires_grad = texture.requires_grad();
ctx->saved_data["data"] =
std::make_tuple(requires_grad, sub_step_count, min_inv_r, max_inv_r, stop_thresh);
torch::Tensor rgba_img =
msi_forward_cuda(ray_o, ray_d, texture, sub_step_count, min_inv_r, max_inv_r, stop_thresh);
save_list.push_back(rgba_img);
ctx->save_for_backward(save_list);
return rgba_img;
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
// grad_outputs[0] is the gradient with respect to rgba_img
torch::autograd::tensor_list grad_outputs) {
bool requires_grad;
int64_t sub_step_count;
double stop_thresh;
double min_inv_r;
double max_inv_r;
std::tie(requires_grad, sub_step_count, min_inv_r, max_inv_r, stop_thresh) =
ctx->saved_data["data"].to<std::tuple<bool, int64_t, double, double, double>>();
torch::autograd::tensor_list grads;
if (!requires_grad) {
grads.resize(7); // 7 - number of arguments of the forward function, see comment below.
return grads;
}
const auto saved = ctx->get_saved_variables();
const auto& ray_o = saved[0];
const auto& ray_d = saved[1];
const auto& texture = saved[2];
const auto& rgba_img = saved[3];
auto rgba_img_grad = grad_outputs[0];
auto texture_grad = msi_backward_cuda(
rgba_img,
rgba_img_grad,
ray_o,
ray_d,
texture,
sub_step_count,
min_inv_r,
max_inv_r,
stop_thresh);
// The output has to be a vector of tensors whose length matches the number of arguments of
// the forward function, even when an argument is not a tensor and cannot have a gradient.
// We do not compute gradients with respect to the ray origin or direction, and the remaining
// inputs except `texture` are not tensors (so they cannot have gradients).
// Thus we only provide a gradient for `texture`, which is the third argument.
auto output_has_no_grad = torch::Tensor();
grads.push_back(output_has_no_grad);
grads.push_back(output_has_no_grad);
grads.push_back(texture_grad);
grads.push_back(output_has_no_grad);
grads.push_back(output_has_no_grad);
grads.push_back(output_has_no_grad);
grads.push_back(output_has_no_grad);
return grads;
}
};
torch::Tensor msi_autograd(
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
return MSIFunction::apply(
ray_o, ray_d, texture, sub_step_count, min_inv_r, max_inv_r, stop_thresh);
}
torch::Tensor msi_autocast(
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return msi(
at::autocast::cached_cast(torch::kFloat32, ray_o),
at::autocast::cached_cast(torch::kFloat32, ray_d),
at::autocast::cached_cast(torch::kFloat32, texture),
sub_step_count,
min_inv_r,
max_inv_r,
stop_thresh);
}
#ifndef NO_PYBIND
PYBIND11_MODULE(msi_ext, m) {}
#endif
TORCH_LIBRARY(msi_ext, m) {
m.def(
"msi(Tensor ray_o, Tensor ray_d, Tensor texture, "
"int sub_step_count, float min_inv_r, float max_inv_r, float stop_thresh) -> "
"Tensor");
}
TORCH_LIBRARY_IMPL(msi_ext, Autograd, m) {
m.impl("msi", &msi_autograd);
}
TORCH_LIBRARY_IMPL(msi_ext, Autocast, m) {
m.impl("msi", msi_autocast);
}
TORCH_LIBRARY_IMPL(msi_ext, CUDA, m) {
m.impl("msi", &msi_forward_cuda);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <grid_utils.h>
#include <torch/types.h>
#include <limits>
#include "rasterize_kernel.h"
#include <kernel_utils.h>
using namespace math;
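// Each loop iteration handles one (batch, triangle) pair: every covered pixel inside the
// triangle's bounding box gets a packed (depth, primitive id) value written with a 64-bit
// atomicMin, which doubles as the depth test. Edge pixels follow the top-left fill rule.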
template <typename scalar_t, typename index_t>
__global__ void rasterize_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> v,
TensorInfo<int32_t, index_t> vi,
TensorInfo<int64_t, index_t> packed_index_depth_img) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
typedef typename math::TVec4<scalar_t> scalar4_t;
const index_t H = packed_index_depth_img.sizes[1];
const index_t W = packed_index_depth_img.sizes[2];
const index_t V = v.sizes[1];
const index_t n_prim = vi.sizes[0];
const index_t index_sN = packed_index_depth_img.strides[0];
const index_t index_sH = packed_index_depth_img.strides[1];
const index_t index_sW = packed_index_depth_img.strides[2];
const index_t v_sN = v.strides[0];
const index_t v_sV = v.strides[1];
const index_t v_sC = v.strides[2];
const index_t vi_sF = vi.strides[0];
const index_t vi_sI = vi.strides[1];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t n = index / n_prim;
const index_t id = index % n_prim;
const int32_t* __restrict vi_ptr = vi.data + vi_sF * id;
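// The low 28 bits hold the vertex index; the top 4 bits may carry wireframe edge-visibility
// flags (see rasterize_lines_kernel) and are masked off here.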
const int32_t vi_0 = (int32_t)(((uint32_t)vi_ptr[vi_sI * 0]) & 0x0FFFFFFFU);
const int32_t vi_1 = vi_ptr[vi_sI * 1];
const int32_t vi_2 = vi_ptr[vi_sI * 2];
assert(vi_0 < V && vi_1 < V && vi_2 < V);
const scalar_t* __restrict v_ptr = v.data + n * v_sN;
const scalar2_t p_0 = {v_ptr[v_sV * vi_0 + v_sC * 0], v_ptr[v_sV * vi_0 + v_sC * 1]};
const scalar2_t p_1 = {v_ptr[v_sV * vi_1 + v_sC * 0], v_ptr[v_sV * vi_1 + v_sC * 1]};
const scalar2_t p_2 = {v_ptr[v_sV * vi_2 + v_sC * 0], v_ptr[v_sV * vi_2 + v_sC * 1]};
const scalar3_t p_012_z = {
v_ptr[v_sV * vi_0 + v_sC * 2],
v_ptr[v_sV * vi_1 + v_sC * 2],
v_ptr[v_sV * vi_2 + v_sC * 2]};
const scalar2_t min_p = math::min(math::min(p_0, p_1), p_2);
const scalar2_t max_p = math::max(math::max(p_0, p_1), p_2);
const bool all_z_greater_0 = math::all_greater(p_012_z, {1e-8f, 1e-8f, 1e-8f});
const bool in_canvas = math::all_less_or_eq(min_p, {(scalar_t)(W - 1), (scalar_t)(H - 1)}) &&
math::all_greater(max_p, {0.f, 0.f});
if (all_z_greater_0 && in_canvas) {
const scalar2_t v_01 = p_1 - p_0;
const scalar2_t v_02 = p_2 - p_0;
const scalar2_t v_12 = p_2 - p_1;
const scalar_t denominator = v_01.x * v_02.y - v_01.y * v_02.x;
if (denominator != 0.f) {
// Compute triangle bounds with extra border.
int min_x = max(0, int(min_p.x));
int min_y = max(0, int(min_p.y));
int max_x = min((int)W - 1, int(max_p.x) + 1);
int max_y = min((int)H - 1, int(max_p.y) + 1);
// Loop over pixels inside triangle bbox.
for (int y = min_y; y <= max_y; ++y) {
for (int x = min_x; x <= max_x; ++x) {
const scalar2_t p = {(scalar_t)x, (scalar_t)y};
const scalar2_t vp0p = p - p_0;
const scalar2_t vp1p = p - p_1;
scalar3_t bary = scalar3_t({
vp1p.y * v_12.x - vp1p.x * v_12.y,
vp0p.x * v_02.y - vp0p.y * v_02.x,
vp0p.y * v_01.x - vp0p.x * v_01.y,
});
bary *= sign(denominator);
const bool on_edge_or_inside = (bary.x >= 0.f) && (bary.y >= 0.f) && (bary.z >= 0.f);
bool on_edge_0 = bary.x == 0.f;
bool on_edge_1 = bary.y == 0.f;
bool on_edge_2 = bary.z == 0.f;
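// Top-left fill rule: a pixel lying exactly on a shared edge is claimed by exactly one of the
// adjacent triangles, avoiding both double coverage and holes along the edge.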
const bool is_top_left_0 = (denominator > 0)
? (v_12.y < 0.f || v_12.y == 0.0f && v_12.x > 0.f)
: (v_12.y > 0.f || v_12.y == 0.0f && v_12.x < 0.f);
const bool is_top_left_1 = (denominator > 0)
? (v_02.y > 0.f || v_02.y == 0.0f && v_02.x < 0.f)
: (v_02.y < 0.f || v_02.y == 0.0f && v_02.x > 0.f);
const bool is_top_left_2 = (denominator > 0)
? (v_01.y < 0.f || v_01.y == 0.0f && v_01.x > 0.f)
: (v_01.y > 0.f || v_01.y == 0.0f && v_01.x < 0.f);
const bool is_top_left_or_inside = on_edge_or_inside &&
!(on_edge_0 && !is_top_left_0 || on_edge_1 && !is_top_left_1 ||
on_edge_2 && !is_top_left_2);
if (is_top_left_or_inside) {
bary /= abs(denominator);
// interpolate inverse depth linearly
const scalar3_t d_inv = 1.0 / epsclamp(p_012_z);
const scalar_t depth_inverse = dot(d_inv, bary);
const scalar_t depth = 1.0f / epsclamp(depth_inverse);
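// Pack the depth into the high 32 bits and the primitive id into the low 32 bits. Because the
// bit patterns of positive floats order the same way as the floats themselves, a single 64-bit
// atomicMin performs the depth test, with the smaller primitive id breaking ties.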
const unsigned long long packed_val =
(static_cast<unsigned long long>(__float_as_uint(depth)) << 32u) |
static_cast<unsigned long long>(id);
atomicMin(
reinterpret_cast<unsigned long long*>(packed_index_depth_img.data) +
index_sN * n + index_sH * y + index_sW * x,
packed_val);
}
}
}
}
}
}
}
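// Computes the coefficients (a, b, c) of the implicit line equation a*x + b*y + c = 0 passing
// through p1 and p2.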
template <typename scalar_t>
__device__ inline void get_line(
const math::TVec2<scalar_t>& p1,
const math::TVec2<scalar_t>& p2,
scalar_t& a,
scalar_t& b,
scalar_t& c) {
a = p1.y - p2.y;
b = p2.x - p1.x;
c = p1.x * p2.y - p2.x * p1.y;
}
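// Returns true if point c lies within the axis-aligned bounding box of the segment (p1, p2);
// for a point already known to be on the supporting line this is equivalent to lying on the
// segment itself.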
template <typename scalar_t>
__device__ inline bool is_point_in_segment(
const math::TVec2<scalar_t>& p1,
const math::TVec2<scalar_t>& p2,
const math::TVec2<scalar_t>& c) {
return (
(((p2.x >= c.x) && (c.x >= p1.x)) || ((p2.x <= c.x) && (c.x <= p1.x))) &&
(((p2.y >= c.y) && (c.y >= p1.y)) || ((p2.y <= c.y) && (c.y <= p1.y))));
}
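// Intersects two lines given in implicit form (a1, b1, c1) and (a2, b2, c2); returns a
// sentinel point at the numeric maximum when the lines are parallel.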
template <typename scalar_t>
__device__ inline math::TVec2<scalar_t>
get_cross_point(scalar_t a1, scalar_t b1, scalar_t c1, scalar_t a2, scalar_t b2, scalar_t c2) {
scalar_t d = a1 * b2 - a2 * b1;
if (d == scalar_t(0)) {
return math::TVec2<scalar_t>{std::numeric_limits<scalar_t>().max()};
}
return math::TVec2<scalar_t>{(b1 * c2 - b2 * c1) / d, (a2 * c1 - a1 * c2) / d};
}
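// Same as above, but the second line is given by the two points p1 and p2.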
template <typename scalar_t>
__device__ inline math::TVec2<scalar_t> get_cross_point(
scalar_t a1,
scalar_t b1,
scalar_t c1,
const math::TVec2<scalar_t>& p1,
const math::TVec2<scalar_t>& p2) {
scalar_t a2 = 1e16;
scalar_t b2 = 1e16;
scalar_t c2 = 1e16;
get_line(p1, p2, a2, b2, c2);
scalar_t d = a1 * b2 - a2 * b1;
if (d == scalar_t(0)) {
return math::TVec2<scalar_t>{std::numeric_limits<scalar_t>().max()};
}
return math::TVec2<scalar_t>{(b1 * c2 - b2 * c1) / d, (a2 * c1 - a1 * c2) / d};
}
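// Tests whether the segment (p1, p2) crosses the diamond-shaped region of half-width 0.5
// centered at pixel p, by intersecting the segment with each of the four diamond edges.
// This is similar in spirit to the diamond-exit rule used in hardware line rasterization.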
template <typename scalar_t>
__device__ inline bool is_crossing_dimond(
const math::TVec2<scalar_t>& p1,
const math::TVec2<scalar_t>& p2,
const math::TVec2<scalar_t>& p) {
scalar_t a0 = 1e16;
scalar_t b0 = 1e16;
scalar_t c0 = 1e16;
get_line(p1, p2, a0, b0, c0);
bool intersecting = false;
{
math::TVec2<scalar_t> s0 = {p.x, p.y - scalar_t(0.5)};
math::TVec2<scalar_t> s1 = {p.x + scalar_t(0.5), p.y};
auto c = get_cross_point(a0, b0, c0, s0, s1);
intersecting |=
is_point_in_segment<scalar_t>(s0, s1, c) && is_point_in_segment<scalar_t>(p1, p2, c);
}
{
math::TVec2<scalar_t> s0 = {p.x + scalar_t(0.5), p.y};
math::TVec2<scalar_t> s1 = {p.x, p.y + scalar_t(0.5)};
auto c = get_cross_point(a0, b0, c0, s0, s1);
intersecting |=
is_point_in_segment<scalar_t>(s0, s1, c) && is_point_in_segment<scalar_t>(p1, p2, c);
}
{
math::TVec2<scalar_t> s0 = {p.x, p.y + scalar_t(0.5)};
math::TVec2<scalar_t> s1 = {p.x - scalar_t(0.5), p.y};
auto c = get_cross_point(a0, b0, c0, s0, s1);
intersecting |=
is_point_in_segment<scalar_t>(s0, s1, c) && is_point_in_segment<scalar_t>(p1, p2, c);
}
{
math::TVec2<scalar_t> s0 = {p.x - scalar_t(0.5), p.y};
math::TVec2<scalar_t> s1 = {p.x, p.y - scalar_t(0.5)};
auto c = get_cross_point(a0, b0, c0, s0, s1);
intersecting |=
is_point_in_segment<scalar_t>(s0, s1, c) && is_point_in_segment<scalar_t>(p1, p2, c);
}
return intersecting;
}
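// Wireframe variant: in addition to the interior coverage test, a pixel is covered when one of
// the triangle's visible edges crosses the pixel's diamond. Interior-only pixels still take
// part in the depth test but store the sentinel id -1, so closer triangles occlude the
// wireframe without producing lines themselves.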
template <typename scalar_t, typename index_t>
__global__ void rasterize_lines_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> v,
TensorInfo<int32_t, index_t> vi,
TensorInfo<int64_t, index_t> packed_index_depth_img) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
typedef typename math::TVec4<scalar_t> scalar4_t;
const index_t H = packed_index_depth_img.sizes[1];
const index_t W = packed_index_depth_img.sizes[2];
const index_t V = v.sizes[1];
const index_t n_prim = vi.sizes[0];
const index_t index_sN = packed_index_depth_img.strides[0];
const index_t index_sH = packed_index_depth_img.strides[1];
const index_t index_sW = packed_index_depth_img.strides[2];
const index_t v_sN = v.strides[0];
const index_t v_sV = v.strides[1];
const index_t v_sC = v.strides[2];
const index_t vi_sF = vi.strides[0];
const index_t vi_sI = vi.strides[1];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t n = index / n_prim;
const index_t id = index % n_prim;
const int32_t* __restrict vi_ptr = vi.data + vi_sF * id;
const int32_t flag = (int32_t)((((uint32_t)vi_ptr[vi_sI * 0] & 0xF0000000U)) >> 28U);
const int32_t vi_0 = (int32_t)(((uint32_t)vi_ptr[vi_sI * 0]) & 0x0FFFFFFFU);
const int32_t vi_1 = vi_ptr[vi_sI * 1];
const int32_t vi_2 = vi_ptr[vi_sI * 2];
const bool edge_0_visible = (flag & 0b00000001) != 0;
const bool edge_1_visible = (flag & 0b00000010) != 0;
const bool edge_2_visible = (flag & 0b00000100) != 0;
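// The top 4 bits of the first vertex index encode per-edge visibility flags for the wireframe;
// the remaining 28 bits are the actual vertex index (hence the < 2^28 limit on v.size(1)
// checked in rasterize_cuda).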
assert(vi_0 < V && vi_1 < V && vi_2 < V);
const scalar_t* __restrict v_ptr = v.data + n * v_sN;
const scalar2_t p_0 = {v_ptr[v_sV * vi_0 + v_sC * 0], v_ptr[v_sV * vi_0 + v_sC * 1]};
const scalar2_t p_1 = {v_ptr[v_sV * vi_1 + v_sC * 0], v_ptr[v_sV * vi_1 + v_sC * 1]};
const scalar2_t p_2 = {v_ptr[v_sV * vi_2 + v_sC * 0], v_ptr[v_sV * vi_2 + v_sC * 1]};
const scalar3_t p_012_z = {
v_ptr[v_sV * vi_0 + v_sC * 2],
v_ptr[v_sV * vi_1 + v_sC * 2],
v_ptr[v_sV * vi_2 + v_sC * 2]};
const scalar2_t min_p = math::min(math::min(p_0, p_1), p_2);
const scalar2_t max_p = math::max(math::max(p_0, p_1), p_2);
const bool all_z_greater_0 = math::all_greater(p_012_z, {1e-8f, 1e-8f, 1e-8f});
const bool in_canvas = math::all_less_or_eq(min_p, {(scalar_t)(W - 1), (scalar_t)(H - 1)}) &&
math::all_greater(max_p, {0.f, 0.f});
if (all_z_greater_0 && in_canvas) {
const scalar2_t v_01 = p_1 - p_0;
const scalar2_t v_02 = p_2 - p_0;
const scalar2_t v_12 = p_2 - p_1;
const scalar_t denominator = v_01.x * v_02.y - v_01.y * v_02.x;
if (denominator != 0.f) {
// Compute triangle bounds with extra border.
int min_x = max(1, int(min_p.x) - 2);
int min_y = max(1, int(min_p.y) - 2);
int max_x = min((int)W - 2, int(max_p.x) + 2);
int max_y = min((int)H - 2, int(max_p.y) + 2);
// Loop over pixels inside triangle bbox.
for (int y = min_y; y <= max_y; ++y) {
for (int x = min_x; x <= max_x; ++x) {
const scalar2_t p = {(scalar_t)x, (scalar_t)y};
const scalar2_t vp0p = p - p_0;
const scalar2_t vp1p = p - p_1;
bool intersecting = false;
intersecting |= is_crossing_dimond<scalar_t>(p_0, p_1, p) && edge_0_visible;
intersecting |= is_crossing_dimond<scalar_t>(p_1, p_2, p) && edge_1_visible;
intersecting |= is_crossing_dimond<scalar_t>(p_0, p_2, p) && edge_2_visible;
scalar3_t bary = scalar3_t({
vp1p.y * v_12.x - vp1p.x * v_12.y,
vp0p.x * v_02.y - vp0p.y * v_02.x,
vp0p.y * v_01.x - vp0p.x * v_01.y,
});
bary *= sign(denominator);
const bool on_edge_or_inside = (bary.x >= 0.f) && (bary.y >= 0.f) && (bary.z >= 0.f);
bool on_edge_0 = bary.x == 0.f;
bool on_edge_1 = bary.y == 0.f;
bool on_edge_2 = bary.z == 0.f;
const bool is_top_left_0 = (denominator > 0)
? (v_12.y < 0.f || v_12.y == 0.0f && v_12.x > 0.f)
: (v_12.y > 0.f || v_12.y == 0.0f && v_12.x < 0.f);
const bool is_top_left_1 = (denominator > 0)
? (v_02.y > 0.f || v_02.y == 0.0f && v_02.x < 0.f)
: (v_02.y < 0.f || v_02.y == 0.0f && v_02.x > 0.f);
const bool is_top_left_2 = (denominator > 0)
? (v_01.y < 0.f || v_01.y == 0.0f && v_01.x > 0.f)
: (v_01.y > 0.f || v_01.y == 0.0f && v_01.x < 0.f);
const bool is_top_left_or_inside = on_edge_or_inside &&
!(on_edge_0 && !is_top_left_0 || on_edge_1 && !is_top_left_1 ||
on_edge_2 && !is_top_left_2);
if (is_top_left_or_inside || intersecting) {
bary /= abs(denominator);
bary = math::max(bary, scalar3_t{0, 0, 0});
bary = math::min(bary, scalar3_t{1, 1, 1});
bary = bary / math::sum(bary);
// interpolate inverse depth linearly
const scalar3_t d_inv = 1.0 / epsclamp(p_012_z);
const scalar_t depth_inverse = dot(d_inv, bary);
const scalar_t depth = 1.0f / epsclamp(depth_inverse);
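// Pack depth and id as in rasterize_kernel, but pixels that are merely inside the triangle
// (not on a visible edge) store the sentinel id 0xFFFFFFFF (-1): they still depth-test and
// occlude, yet draw no line.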
const unsigned long long packed_val =
(static_cast<unsigned long long>(__float_as_uint(depth)) << 32u) |
(intersecting ? static_cast<unsigned long long>(id) : 0xFFFFFFFFULL);
atomicMin(
reinterpret_cast<unsigned long long*>(packed_index_depth_img.data) +
index_sN * n + index_sH * y + index_sW * x,
packed_val);
}
}
}
}
}
}
}
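// Unpacks the 64-bit (depth, id) values produced by the rasterization kernels into a float32
// depth image and an int32 index image. Pixels left at the initial 0xFF pattern decode to
// depth 0 and index -1.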
template <typename index_t>
__global__ void unpack_kernel(
const index_t nthreads,
TensorInfo<int64_t, index_t> packed_index_depth_img,
TensorInfo<float, index_t> depth_img,
TensorInfo<int32_t, index_t> index_img) {
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const unsigned long long int pv =
reinterpret_cast<unsigned long long int*>(packed_index_depth_img.data)[index];
const auto depth_uint = static_cast<uint32_t>(pv >> 32);
depth_img.data[index] = depth_uint == 0xFFFFFFFF ? 0.0f : __uint_as_float(depth_uint);
reinterpret_cast<uint32_t*>(index_img.data)[index] = static_cast<uint32_t>(pv & 0xFFFFFFFF);
}
}
std::vector<torch::Tensor> rasterize_cuda(
const torch::Tensor& v,
const torch::Tensor& vi,
int64_t height,
int64_t width,
bool wireframe) {
TORCH_CHECK(v.defined() && vi.defined(), "rasterize(): expected all inputs to be defined");
auto v_opt = v.options();
auto vi_opt = vi.options();
TORCH_CHECK(
(v.device() == vi.device()) && (v.is_cuda()),
"rasterize(): expected all inputs to be on same cuda device");
TORCH_CHECK(
v.is_floating_point(),
"rasterize(): expected v to have floating point type, but v has ",
v.dtype());
TORCH_CHECK(
vi.dtype() == torch::kInt32,
"rasterize(): expected vi to have int32 type, but vi has ",
vi.dtype());
TORCH_CHECK(
v.layout() == torch::kStrided && vi.layout() == torch::kStrided,
"rasterize(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(v.dim() == 3) && (vi.dim() == 2),
"rasterize(): expected v.ndim == 3, vi.ndim == 2, "
"but got v with sizes ",
v.sizes(),
" and vi with sizes ",
vi.sizes());
TORCH_CHECK(
v.size(2) == 3 && vi.size(1) == 3,
"rasterize(): expected third dim of v to be of size 3, and second dim of vi to be of size 3, but got ",
v.size(2),
" in the third dim of v, and ",
vi.size(1),
" in the second dim of vi");
TORCH_CHECK(
v.size(1) < 0x10000000U,
"rasterize(): expected second dim of v to be less or eual to 268435456, but got ",
v.size(1));
TORCH_CHECK(
height > 0 && width > 0,
"rasterize(): both height and width have to be greater than zero, but got height: ",
height,
", and width: ",
width);
const at::cuda::OptionalCUDAGuard device_guard(device_of(v));
auto stream = at::cuda::getCurrentCUDAStream();
auto N = v.size(0);
auto T = vi.size(0);
auto H = height;
auto W = width;
const auto count_rasterize = N * T;
const auto count_unpack = N * H * W;
auto packed_index_depth_img = at::empty({N, H, W}, v.options().dtype(torch::kInt64));
auto depth_img = at::empty({N, H, W}, v.options().dtype(torch::kFloat32));
auto index_img = at::empty({N, H, W}, v.options().dtype(torch::kInt32));
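// Fill the packed buffer with 0xFF bytes (i.e. the maximum unsigned 64-bit value) so that any
// rasterized (depth, id) pair wins the subsequent atomicMin and untouched pixels later decode
// to depth 0 and index -1 in unpack_kernel.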
cudaMemsetAsync(
packed_index_depth_img.data_ptr(),
0xFF,
N * H * W * torch::elementSize(torch::kInt64),
stream);
// rasterize
if (count_rasterize > 0) {
AT_DISPATCH_FLOATING_TYPES(v.scalar_type(), "rasterize_kernel", [&] {
if (at::native::canUse32BitIndexMath(v) && at::native::canUse32BitIndexMath(vi) &&
at::native::canUse32BitIndexMath(packed_index_depth_img)) {
typedef int index_type;
if (wireframe) {
rasterize_lines_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count_rasterize, 256), 256, 0, stream>>>(
static_cast<index_type>(count_rasterize),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int64_t, index_type>(packed_index_depth_img));
} else {
rasterize_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count_rasterize, 256), 256, 0, stream>>>(
static_cast<index_type>(count_rasterize),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int64_t, index_type>(packed_index_depth_img));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
if (wireframe) {
rasterize_lines_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count_rasterize, 256), 256, 0, stream>>>(
static_cast<index_type>(count_rasterize),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int64_t, index_type>(packed_index_depth_img));
} else {
rasterize_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count_rasterize, 256), 256, 0, stream>>>(
static_cast<index_type>(count_rasterize),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int64_t, index_type>(packed_index_depth_img));
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
// unpack
if (count_unpack > 0) {
if (at::native::canUse32BitIndexMath(packed_index_depth_img) &&
at::native::canUse32BitIndexMath(depth_img) &&
at::native::canUse32BitIndexMath(index_img)) {
typedef int index_type;
unpack_kernel<index_type><<<GET_BLOCKS(count_unpack, 256), 256, 0, stream>>>(
static_cast<index_type>(count_unpack),
getTensorInfo<int64_t, index_type>(packed_index_depth_img),
getTensorInfo<float, index_type>(depth_img),
getTensorInfo<int32_t, index_type>(index_img));
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
unpack_kernel<index_type><<<GET_BLOCKS(count_unpack, 256), 256, 0, stream>>>(
static_cast<index_type>(count_unpack),
getTensorInfo<int64_t, index_type>(packed_index_depth_img),
getTensorInfo<float, index_type>(depth_img),
getTensorInfo<int32_t, index_type>(index_img));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
return {depth_img, index_img};
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
std::vector<torch::Tensor> rasterize_cuda(
const torch::Tensor& v,
const torch::Tensor& vi,
int64_t height,
int64_t width,
bool wireframe);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "rasterize_kernel.h"
// Dispatch function
torch::autograd::tensor_list rasterize(
const torch::Tensor& v,
const torch::Tensor& vi,
int64_t height,
int64_t width,
bool wireframe) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("rasterize_ext::rasterize", "")
.typed<decltype(rasterize)>();
return op.call(v, vi, height, width, wireframe);
}
// Ideally we would turn off autograd handling and re-dispatch here, but instead we call the
// CUDA kernels directly.
class RasterizeFunction : public torch::autograd::Function<RasterizeFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& v,
const torch::Tensor& vi,
int64_t height,
int64_t width,
bool wireframe) {
ctx->set_materialize_grads(false);
auto outputs = rasterize_cuda(v, vi, height, width, wireframe);
ctx->mark_non_differentiable(outputs);
return outputs;
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::tensor_list& grad_outputs) {
return {torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor(), torch::Tensor()};
}
};
torch::autograd::tensor_list rasterize_autograd(
const torch::Tensor& v,
const torch::Tensor& vi,
int64_t height,
int64_t width,
bool wireframe) {
return RasterizeFunction::apply(v, vi, height, width, wireframe);
}
torch::autograd::tensor_list rasterize_autocast(
const torch::Tensor& v,
const torch::Tensor& vi,
int64_t height,
int64_t width,
bool wireframe) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return rasterize(at::autocast::cached_cast(torch::kFloat32, v), vi, height, width, wireframe);
}
#ifndef NO_PYBIND
PYBIND11_MODULE(rasterize_ext, m) {}
#endif
TORCH_LIBRARY(rasterize_ext, m) {
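// Returns [depth_img, index_img]; see rasterize_cuda in rasterize_kernel.cu.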
m.def("rasterize(Tensor v, Tensor vi, int height, int width, bool wireframe) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(rasterize_ext, Autograd, m) {
m.impl("rasterize", &rasterize_autograd);
}
TORCH_LIBRARY_IMPL(rasterize_ext, Autocast, m) {
m.impl("rasterize", rasterize_autocast);
}
TORCH_LIBRARY_IMPL(rasterize_ext, CUDA, m) {
m.impl("rasterize", &rasterize_cuda);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <torch/types.h>
#include <ATen/native/cuda/KernelUtils.cuh>
#include "render_kernel.h"
#include <kernel_utils.h>
using namespace math;
using at::native::fastAtomicAdd;
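// Forward pass: for every pixel with a valid triangle index, compute perspective-correct
// barycentric coordinates and the interpolated depth; background pixels (index -1) are zeroed.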
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void render_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> v,
TensorInfo<int32_t, index_t> vi,
TensorInfo<int32_t, index_t> index_img,
TensorInfo<scalar_t, index_t> depth_img,
TensorInfo<scalar_t, index_t> bary_img) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
const index_t H = bary_img.sizes[2];
const index_t W = bary_img.sizes[3];
const index_t V = v.sizes[1];
const index_t v_sN = v.strides[0];
const index_t v_sV = v.strides[1];
const index_t v_sC = v.strides[2];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
const index_t index_img_sW = index_img.strides[2];
const index_t depth_img_sN = depth_img.strides[0];
const index_t depth_img_sH = depth_img.strides[1];
const index_t depth_img_sW = depth_img.strides[2];
const index_t bary_img_sN = bary_img.strides[0];
const index_t bary_img_sB = bary_img.strides[1];
const index_t bary_img_sH = bary_img.strides[2];
const index_t bary_img_sW = bary_img.strides[3];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % W;
const index_t h = (index / W) % H;
const index_t n = index / (H * W);
const int32_t tr_index = index_img.data[n * index_img_sN + h * index_img_sH + w * index_img_sW];
scalar_t* __restrict bary_img_ptr =
bary_img.data + bary_img_sN * n + bary_img_sH * h + bary_img_sW * w;
scalar_t* __restrict depth_img_ptr =
depth_img.data + depth_img_sN * n + depth_img_sH * h + depth_img_sW * w;
if (tr_index != -1) {
const int32_t* __restrict vi_ptr = vi.data + tr_index * vi_sV;
const int32_t vi_0 = vi_ptr[0 * vi_sF];
const int32_t vi_1 = vi_ptr[1 * vi_sF];
const int32_t vi_2 = vi_ptr[2 * vi_sF];
assert(vi_0 < V && vi_1 < V && vi_2 < V);
const scalar_t* __restrict v_ptr = v.data + n * v_sN;
const scalar2_t p_0 = {v_ptr[v_sV * vi_0 + v_sC * 0], v_ptr[v_sV * vi_0 + v_sC * 1]};
const scalar2_t p_1 = {v_ptr[v_sV * vi_1 + v_sC * 0], v_ptr[v_sV * vi_1 + v_sC * 1]};
const scalar2_t p_2 = {v_ptr[v_sV * vi_2 + v_sC * 0], v_ptr[v_sV * vi_2 + v_sC * 1]};
const scalar3_t p_012_z = {
v_ptr[v_sV * vi_0 + v_sC * 2],
v_ptr[v_sV * vi_1 + v_sC * 2],
v_ptr[v_sV * vi_2 + v_sC * 2]};
const scalar2_t v_01 = p_1 - p_0;
const scalar2_t v_02 = p_2 - p_0;
const scalar_t denominator = epsclamp((v_01.x * v_02.y - v_01.y * v_02.x));
const scalar2_t vp0p = {w - p_0.x, h - p_0.y};
const scalar2_t bary_12_pre = scalar2_t{
(vp0p.x * v_02.y - vp0p.y * v_02.x),
(vp0p.y * v_01.x - vp0p.x * v_01.y),
};
const scalar2_t bary_12 = bary_12_pre / denominator;
scalar3_t bary = {scalar_t(1.0) - bary_12.x - bary_12.y, bary_12.x, bary_12.y};
const scalar3_t p_012_z_eps = epsclamp(p_012_z);
const scalar3_t d_inv = 1.0 / p_012_z_eps;
const scalar_t depth_inverse = dot(d_inv, bary);
const scalar_t depth = 1.0f / epsclamp(depth_inverse);
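// Convert screen-space barycentrics to perspective-correct ones: weight each coordinate by the
// inverse vertex depth and rescale by the interpolated depth so the weights sum to one.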
const scalar3_t bary_3D = d_inv * bary * depth;
bary_img_ptr[bary_img_sB * 0] = bary_3D.x;
bary_img_ptr[bary_img_sB * 1] = bary_3D.y;
bary_img_ptr[bary_img_sB * 2] = bary_3D.z;
*depth_img_ptr = depth;
} else {
bary_img_ptr[bary_img_sB * 0] = scalar_t(0);
bary_img_ptr[bary_img_sB * 1] = scalar_t(0);
bary_img_ptr[bary_img_sB * 2] = scalar_t(0);
*depth_img_ptr = scalar_t(0);
}
}
}
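// Backward pass of render_kernel: recompute the forward quantities for each covered pixel and
// accumulate dL/dv for the three vertices of the covering triangle with fastAtomicAdd.
// Gradient paths through epsclamp are zeroed wherever the clamp was active in the forward pass.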
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void render_backward_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> v,
TensorInfo<int32_t, index_t> vi,
TensorInfo<int32_t, index_t> index_img,
TensorInfo<scalar_t, index_t> grad_depth_img,
TensorInfo<scalar_t, index_t> grad_bary_img,
TensorInfo<scalar_t, index_t> grad_v,
const index_t memory_span) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
const index_t H = grad_bary_img.sizes[2];
const index_t W = grad_bary_img.sizes[3];
const index_t V = v.sizes[1];
const index_t v_sN = v.strides[0];
const index_t v_sV = v.strides[1];
const index_t v_sC = v.strides[2];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
const index_t index_img_sW = index_img.strides[2];
const index_t grad_depth_img_sN = grad_depth_img.strides[0];
const index_t grad_depth_img_sH = grad_depth_img.strides[1];
const index_t grad_depth_img_sW = grad_depth_img.strides[2];
const index_t grad_bary_img_sN = grad_bary_img.strides[0];
const index_t grad_bary_img_sB = grad_bary_img.strides[1];
const index_t grad_bary_img_sH = grad_bary_img.strides[2];
const index_t grad_bary_img_sW = grad_bary_img.strides[3];
const index_t grad_v_sN = grad_v.strides[0];
const index_t grad_v_sV = grad_v.strides[1];
const index_t grad_v_sC = grad_v.strides[2];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % W;
const index_t h = (index / W) % H;
const index_t n = index / (H * W);
const int32_t tr_index = index_img.data[n * index_img_sN + h * index_img_sH + w * index_img_sW];
const scalar_t* __restrict grad_bary_img_ptr =
grad_bary_img.data + grad_bary_img_sN * n + grad_bary_img_sH * h + grad_bary_img_sW * w;
const scalar_t* __restrict grad_depth_img_ptr =
grad_depth_img.data + grad_depth_img_sN * n + grad_depth_img_sH * h + grad_depth_img_sW * w;
scalar_t* __restrict grad_v_ptr = grad_v.data + grad_v_sN * n;
if (tr_index != -1) {
const int32_t* __restrict vi_ptr = vi.data + tr_index * vi_sV;
const int32_t vi_0 = vi_ptr[0 * vi_sF];
const int32_t vi_1 = vi_ptr[1 * vi_sF];
const int32_t vi_2 = vi_ptr[2 * vi_sF];
assert(vi_0 < V && vi_1 < V && vi_2 < V);
const scalar_t* __restrict v_ptr = v.data + n * v_sN;
const scalar2_t p_0 = {v_ptr[v_sV * vi_0 + v_sC * 0], v_ptr[v_sV * vi_0 + v_sC * 1]};
const scalar2_t p_1 = {v_ptr[v_sV * vi_1 + v_sC * 0], v_ptr[v_sV * vi_1 + v_sC * 1]};
const scalar2_t p_2 = {v_ptr[v_sV * vi_2 + v_sC * 0], v_ptr[v_sV * vi_2 + v_sC * 1]};
const scalar3_t p_012_z = {
v_ptr[v_sV * vi_0 + v_sC * 2],
v_ptr[v_sV * vi_1 + v_sC * 2],
v_ptr[v_sV * vi_2 + v_sC * 2]};
const scalar2_t v_01 = p_1 - p_0;
const scalar2_t v_02 = p_2 - p_0;
const scalar_t _denominator = v_01.x * v_02.y - v_01.y * v_02.x;
const scalar_t denominator = epsclamp(_denominator);
const bool denominator_clamped = denominator != _denominator;
const scalar2_t vp0p = {w - p_0.x, h - p_0.y};
const scalar2_t bary_12_pre = scalar2_t{
vp0p.x * v_02.y - vp0p.y * v_02.x,
vp0p.y * v_01.x - vp0p.x * v_01.y,
};
const scalar2_t bary_12 = bary_12_pre / denominator;
scalar3_t bary = {scalar_t(1.0) - bary_12.x - bary_12.y, bary_12.x, bary_12.y};
const scalar3_t p_012_z_eps = epsclamp(p_012_z);
const bool z0_clamped = p_012_z_eps.x != p_012_z.x;
const bool z1_clamped = p_012_z_eps.y != p_012_z.y;
const bool z2_clamped = p_012_z_eps.z != p_012_z.z;
const scalar3_t d_inv = 1.0 / p_012_z_eps;
const scalar_t depth_inverse = dot(d_inv, bary);
const scalar_t depth_inverse_eps = epsclamp(depth_inverse);
const bool depth_inverse_clamped = depth_inverse_eps != depth_inverse;
const scalar_t depth = 1.0f / depth_inverse_eps;
const scalar3_t dL_bary_3D = {
grad_bary_img_ptr[grad_bary_img_sB * 0],
grad_bary_img_ptr[grad_bary_img_sB * 1],
grad_bary_img_ptr[grad_bary_img_sB * 2]};
const scalar_t dL_depth = *grad_depth_img_ptr + dot(dL_bary_3D * d_inv, bary);
const scalar_t dL_depth_inverse =
depth_inverse_clamped ? 0.f : (-dL_depth / (depth_inverse * depth_inverse));
const scalar3_t dL_d_inv = dL_bary_3D * bary * depth + dL_depth_inverse * bary;
const scalar3_t dL_p_012_z = -dL_d_inv / (p_012_z_eps * p_012_z_eps);
fastAtomicAdd(
grad_v_ptr,
grad_v_sV * vi_0 + grad_v_sC * 2,
memory_span,
z0_clamped ? 0.f : dL_p_012_z.x,
true);
fastAtomicAdd(
grad_v_ptr,
grad_v_sV * vi_1 + grad_v_sC * 2,
memory_span,
z1_clamped ? 0.f : dL_p_012_z.y,
true);
fastAtomicAdd(
grad_v_ptr,
grad_v_sV * vi_2 + grad_v_sC * 2,
memory_span,
z2_clamped ? 0.f : dL_p_012_z.z,
true);
const scalar3_t dL_bary = dL_bary_3D * d_inv * depth + dL_depth_inverse * d_inv;
const scalar2_t dL_bary_12 = {-dL_bary.x + dL_bary.y, -dL_bary.x + dL_bary.z};
const scalar2_t dL_bary_pre = dL_bary_12 / denominator;
const scalar_t dL_denominator = denominator_clamped ? 0.f : -dot(dL_bary_pre, bary_12);
const scalar2_t dL_vp0p = {
dL_bary_pre.x * v_02.y - dL_bary_pre.y * v_01.y,
-dL_bary_pre.x * v_02.x + dL_bary_pre.y * v_01.x};
const scalar2_t dL_v_02 = {
-dL_bary_pre.x * vp0p.y - dL_denominator * v_01.y,
dL_bary_pre.x * vp0p.x + dL_denominator * v_01.x};
const scalar2_t dL_v_01 = {
dL_bary_pre.y * vp0p.y + dL_denominator * v_02.y,
-dL_bary_pre.y * vp0p.x - dL_denominator * v_02.x};
const scalar2_t dL_p0 = -dL_v_02 - dL_v_01 - dL_vp0p;
const scalar2_t dL_p1 = dL_v_01;
const scalar2_t dL_p2 = dL_v_02;
fastAtomicAdd(grad_v_ptr, grad_v_sV * vi_0 + grad_v_sC * 0, memory_span, dL_p0.x, true);
fastAtomicAdd(grad_v_ptr, grad_v_sV * vi_0 + grad_v_sC * 1, memory_span, dL_p0.y, true);
fastAtomicAdd(grad_v_ptr, grad_v_sV * vi_1 + grad_v_sC * 0, memory_span, dL_p1.x, true);
fastAtomicAdd(grad_v_ptr, grad_v_sV * vi_1 + grad_v_sC * 1, memory_span, dL_p1.y, true);
fastAtomicAdd(grad_v_ptr, grad_v_sV * vi_2 + grad_v_sC * 0, memory_span, dL_p2.x, true);
fastAtomicAdd(grad_v_ptr, grad_v_sV * vi_2 + grad_v_sC * 1, memory_span, dL_p2.y, true);
}
}
}
std::vector<torch::Tensor>
render_cuda(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor& index_img) {
TORCH_CHECK(
v.defined() && vi.defined() && index_img.defined(),
"render(): expected all inputs to be defined");
auto v_opt = v.options();
auto vi_opt = vi.options();
auto index_img_opt = index_img.options();
TORCH_CHECK(
(v.device() == vi.device()) && (v.device() == index_img.device()) && (v.is_cuda()),
"render(): expected all inputs to be on same cuda device");
TORCH_CHECK(
v.is_floating_point(),
"render(): expected v to have floating point type, but v has ",
v.dtype());
TORCH_CHECK(
vi.dtype() == torch::kInt32,
"render(): expected vi to have int32 type, but vi has ",
vi.dtype());
TORCH_CHECK(
index_img.dtype() == torch::kInt32,
"render(): expected index_img to have int32 type, but index_img has ",
index_img.dtype());
TORCH_CHECK(
v.layout() == torch::kStrided && vi.layout() == torch::kStrided &&
index_img.layout() == torch::kStrided,
"render(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(v.dim() == 3) && (vi.dim() == 2) && (index_img.dim() == 3),
"render(): expected v.ndim == 3, vi.ndim == 2, index_img.ndim == 3, "
"but got v with sizes ",
v.sizes(),
" and vi with sizes ",
vi.sizes(),
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v.size(0) == index_img.size(0),
"render(): expected v and index_img to have same batch size, "
"but got v with sizes ",
v.sizes(),
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v.size(2) == 3 && vi.size(1) == 3,
"render(): expected third dim of v to be of size 3, and second dim of vi to be of size 3, but got ",
v.size(2),
" in the third dim of v, and ",
vi.size(1),
" in the second dim of vi");
const at::cuda::OptionalCUDAGuard device_guard(device_of(v));
auto N = v.size(0);
auto H = index_img.size(1);
auto W = index_img.size(2);
int64_t count = N * H * W;
auto depth_img = at::empty({N, H, W}, v.options());
auto bary_img = at::empty({N, 3, H, W}, v.options());
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES(v.scalar_type(), "render_kernel", [&] {
if (at::native::canUse32BitIndexMath(v) && at::native::canUse32BitIndexMath(bary_img) &&
at::native::canUse32BitIndexMath(depth_img) &&
at::native::canUse32BitIndexMath(index_img) && at::native::canUse32BitIndexMath(vi)) {
typedef int index_type;
render_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<scalar_t, index_type>(depth_img),
getTensorInfo<scalar_t, index_type>(bary_img));
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
render_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<scalar_t, index_type>(depth_img),
getTensorInfo<scalar_t, index_type>(bary_img));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return {depth_img, bary_img};
}
torch::Tensor render_cuda_backward(
const torch::Tensor& v,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& grad_depth_img,
const torch::Tensor& grad_bary_img) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(v));
auto N = v.size(0);
auto V = v.size(1);
auto C = v.size(2);
auto H = index_img.size(1);
auto W = index_img.size(2);
int64_t count = N * H * W;
auto grad_v = at::zeros({N, V, C}, v.options());
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES(v.scalar_type(), "interpolate_kernel", [&] {
if (at::native::canUse32BitIndexMath(v) && at::native::canUse32BitIndexMath(grad_bary_img) &&
at::native::canUse32BitIndexMath(grad_v) && at::native::canUse32BitIndexMath(index_img) &&
at::native::canUse32BitIndexMath(grad_depth_img) &&
at::native::canUse32BitIndexMath(vi)) {
typedef int index_type;
render_backward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<scalar_t, index_type>(grad_depth_img),
getTensorInfo<scalar_t, index_type>(grad_bary_img),
getTensorInfo<scalar_t, index_type>(grad_v),
grad_v.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
render_backward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(v),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<scalar_t, index_type>(grad_depth_img),
getTensorInfo<scalar_t, index_type>(grad_bary_img),
getTensorInfo<scalar_t, index_type>(grad_v),
grad_v.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return grad_v;
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
std::vector<torch::Tensor>
render_cuda(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor& index_img);
torch::Tensor render_cuda_backward(
const torch::Tensor& v,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& grad_depth_img,
const torch::Tensor& grad_bary_img);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "render_kernel.h"
// Dispatch function
torch::autograd::tensor_list
render(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor& index_img) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("render_ext::render", "")
.typed<decltype(render)>();
return op.call(v, vi, index_img);
}
// Ideally we would turn off autograd handling and re-dispatch here, but instead we call the
// CUDA kernels directly.
class RenderFunction : public torch::autograd::Function<RenderFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& v,
const torch::Tensor& vi,
const torch::Tensor& index_img) {
// ctx->set_materialize_grads(false);
std::vector<torch::Tensor> save_list;
save_list.push_back(v);
save_list.push_back(vi);
save_list.push_back(index_img);
ctx->save_for_backward(save_list);
ctx->saved_data["data"] = std::make_tuple((bool)v.requires_grad());
auto outputs = render_cuda(v, vi, index_img);
return outputs;
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
const torch::Tensor& v = saved[0];
const torch::Tensor& vi = saved[1];
const torch::Tensor& index_img = saved[2];
bool requires_grad;
std::tie(requires_grad) = ctx->saved_data["data"].to<std::tuple<bool>>();
torch::autograd::tensor_list out;
if (!requires_grad) {
out.resize(3);
return out;
}
auto grad_v = render_cuda_backward(v, vi, index_img, grad_outputs[0], grad_outputs[1]);
out.push_back(grad_v);
out.emplace_back();
out.emplace_back();
return out;
}
};
torch::autograd::tensor_list
render_autograd(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor& index_img) {
return RenderFunction::apply(v, vi, index_img);
}
torch::autograd::tensor_list
render_autocast(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor& index_img) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return render(at::autocast::cached_cast(torch::kFloat32, v), vi, index_img);
}
#ifndef NO_PYBIND
PYBIND11_MODULE(render_ext, m) {}
#endif
TORCH_LIBRARY(render_ext, m) {
m.def("render(Tensor v, Tensor vi, Tensor index_img) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(render_ext, Autograd, m) {
m.impl("render", &render_autograd);
}
TORCH_LIBRARY_IMPL(render_ext, Autocast, m) {
m.impl("render", render_autocast);
}
TORCH_LIBRARY_IMPL(render_ext, CUDA, m) {
m.impl("render", &render_cuda);
}
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import cv2
import torch as th
import torch.nn.functional as thf
from drtk import edge_grad_estimator, interpolate, rasterize, render
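# Example: recover the vertex positions of two textured triangles by rendering them with DRTK,
# comparing against a fixed target image, and backpropagating an L2 image loss through
# rasterize/render/interpolate and the edge gradient estimator.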
def main(write_images, xy_only=False, z_only=False):
    assert not (xy_only and z_only), "You need to optimize at least some axes."
    v = th.tensor(
        [
            [10, 200, 100],
            [300, 50, 100],
            [400, 500, 100],
            [50, 400, 200],
            [400, 50, 50],
            [300, 500, 200],
        ],
        dtype=th.float32,
        device="cuda",
    )
    vt = th.zeros(1, 6, 2, device="cuda")
    vt[:, 3:6, 0] = 1
    vi = th.arange(6, device="cuda").int().view(2, 3)
    w = 512
    h = 512
    tex = th.ones(1, 3, 16, 16, device="cuda")
    tex[:, :, :, 8:] = 0.5
    v = th.nn.Parameter(v[None, ...])

    # Render the GT (target) image.
    with th.no_grad():
        v_gt = v.clone()
        th.cuda.manual_seed(10)
        v += th.randn_like(v) * 20.0
        index_img = rasterize(v_gt, vi, h, w)
        _, bary_img = render(v_gt, vi, index_img)
        vt_img = interpolate(vt, vi, index_img, bary_img).permute(0, 2, 3, 1)
        img_gt = (
            thf.grid_sample(tex, vt_img, padding_mode="border", align_corners=False)
            * (index_img != -1)[:, None]
        )
        img = (255 * img_gt[0]).clamp(0, 255).byte().data.cpu().numpy()
        cv2.imwrite("two_triangle_imgs/target.png", img.transpose(1, 2, 0)[..., ::-1])

    optim = th.optim.Adam([v], lr=1e-1, betas=(0.9, 0.999))

    # Optimize geometry to match target.
    for it in range(2000):
        index_img = rasterize(v, vi, h, w)
        _, bary_img = render(v, vi, index_img)
        vt_img = interpolate(vt, vi, index_img, bary_img).permute(0, 2, 3, 1)
        img = (
            thf.grid_sample(tex, vt_img, padding_mode="border", align_corners=False)
            * (index_img != -1)[:, None]
        )
        img = edge_grad_estimator(
            v_pix=v,
            vi=vi,
            bary_img=bary_img,
            img=img,
            index_img=index_img,
        )
        loss = ((img - img_gt) ** 2).mean()
        optim.zero_grad()
        loss.backward()
        if xy_only:
            v.grad[..., 2] = 0
        if z_only:
            v.grad[..., :2] = 0
        optim.step()
        if it % 20 == 0:
            print(it, f"{loss.item():0.3e}")
            if write_images:
                img = (255 * img[0]).clamp(0, 255).byte().data.cpu().numpy()
                cv2.imwrite(
                    f"two_triangle_imgs/{it:06d}.png", img.transpose(1, 2, 0)[..., ::-1]
                )


if __name__ == "__main__":
    write_images = True
    if write_images and not os.path.exists("two_triangle_imgs"):
        os.mkdir("two_triangle_imgs")
    main(write_images)