Commit d4216dd3 authored by Kishore Venkateshan's avatar Kishore Venkateshan Committed by Facebook GitHub Bot
Browse files

2/N batchify render kernel

Summary:
# Problem
In CT / State Encoding, we expect a scenario where we would like to render a batch of topologies where each of them would have different number of vertices and triangles. Currently the only way to support this with DRTK is to iterate over the batch in a for loop for each topology and render it.
In a series of diffs we would like to solve this issue by making drtk consume a batch of triangles as opposed to just 1 set of triangles. However, we would like to achieve this behavior without affecting the most common single topology case by a lot.

# How do we pass in multiple topologies in a single batch?
We will provide a TopologyBatch structure in xrcia/lib/graphics/structures where we will provide functionality to create a Batch x MaxTriangles x 3 and Batch x MaxVertices x 3.
Padded vertices will be 0s and padded triangles will have MaxVertices - 1 as their value. But these will discarded as degenerate in rasterization / rendering.

# In this diff
- Extend render kernel and render backward kernel to support a batch dimension as default.
- `render` will now unsqueeze the batch dimension when using a single topo
- We access the vertex indices of triangles by walking an additional `batch stride * n` in the triangles data pointer.
- Add an extra condition to check to see if the triangles are degenerate; this happens when padding the batch.
- We show that the we don't cause too much overhead in GPU by introducing these 3 extra operations (Same profiling as in D68423813)

Reviewed By: podgorskiy

Differential Revision: D68423409

fbshipit-source-id: e1007b9844658ef6e1bb2267b6a94804f3b6d13b
parent e0274716
......@@ -18,6 +18,21 @@ def render(
vi: th.Tensor,
index_img: th.Tensor,
) -> Tuple[th.Tensor, th.Tensor]:
"""
Render depth and barycentric coordinates from a mesh.
Args:
v: [N, V, 3] tensor of vertex positions.
vi: [N, F, 3] or [F, 3] tensor of triangle indices.
Returns:
depth_img: [N, H, W] tensor of depth values.
bary_img: [N, H, W, 3] tensor of barycentric coordinates.
"""
if vi.ndim == 2:
vi = vi[None].expand(v.shape[0], -1, -1)
depth_img, bary_img = th.ops.render_ext.render(v, vi, index_img)
return depth_img, bary_img
......
......@@ -36,8 +36,9 @@ __global__ void render_kernel(
const index_t v_sV = v.strides[1];
const index_t v_sC = v.strides[2];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
const index_t vi_sN = vi.strides[0];
const index_t vi_sV = vi.strides[1];
const index_t vi_sF = vi.strides[2];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
......@@ -64,7 +65,7 @@ __global__ void render_kernel(
depth_img.data + depth_img_sN * n + depth_img_sH * h + depth_img_sW * w;
if (tr_index != -1) {
const int32_t* __restrict vi_ptr = vi.data + tr_index * vi_sV;
const int32_t* __restrict vi_ptr = vi.data + n * vi_sN + tr_index * vi_sV;
const int32_t vi_0 = vi_ptr[0 * vi_sF];
const int32_t vi_1 = vi_ptr[1 * vi_sF];
const int32_t vi_2 = vi_ptr[2 * vi_sF];
......@@ -136,8 +137,9 @@ __global__ void render_backward_kernel(
const index_t v_sV = v.strides[1];
const index_t v_sC = v.strides[2];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
const index_t vi_sN = vi.strides[0];
const index_t vi_sV = vi.strides[1];
const index_t vi_sF = vi.strides[2];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
......@@ -170,7 +172,7 @@ __global__ void render_backward_kernel(
scalar_t* __restrict grad_v_ptr = grad_v.data + grad_v_sN * n;
if (tr_index != -1) {
const int32_t* __restrict vi_ptr = vi.data + tr_index * vi_sV;
const int32_t* __restrict vi_ptr = vi.data + n * vi_sN + tr_index * vi_sV;
const int32_t vi_0 = vi_ptr[0 * vi_sF];
const int32_t vi_1 = vi_ptr[1 * vi_sF];
const int32_t vi_2 = vi_ptr[2 * vi_sF];
......@@ -304,8 +306,8 @@ render_cuda(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor
index_img.layout() == torch::kStrided,
"render(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(v.dim() == 3) && (vi.dim() == 2) && (index_img.dim() == 3),
"render(): expected v.ndim == 3, vi.ndim == 2, index_img.ndim == 3, "
(v.dim() == 3) && (vi.dim() == 3) && (index_img.dim() == 3),
"render(): expected v.ndim == 3, vi.ndim == 3, index_img.ndim == 3, "
"but got v with sizes ",
v.sizes(),
" and vi with sizes ",
......@@ -320,12 +322,19 @@ render_cuda(const torch::Tensor& v, const torch::Tensor& vi, const torch::Tensor
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v.size(2) == 3 && vi.size(1) == 3,
"render(): expected third dim of v to be of size 3, and second dim of vi to be of size 3, but got ",
vi.size(0) == v.size(0),
"rasterize(): expected first dim of vi to match first dim of v but got ",
v.size(0),
" in first dim of v, and ",
vi.size(0),
" in the first dim of vi");
TORCH_CHECK(
v.size(2) == 3 && vi.size(2) == 3,
"render(): expected third dim of v to be of size 3, and third dim of vi to be of size 3, but got ",
v.size(2),
" in the third dim of v, and ",
vi.size(1),
" in the second dim of vi");
vi.size(2),
" in the third dim of vi");
const at::cuda::OptionalCUDAGuard device_guard(device_of(v));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment