Commit bdb1bfdb authored by Kishore Venkateshan's avatar Kishore Venkateshan Committed by Facebook GitHub Bot
Browse files

3/N Batchify Interpolate Kernel

Summary:
# Problem
In CT / State Encoding, we expect a scenario where we would like to render a batch of topologies where each of them would have different number of vertices and triangles. Currently the only way to support this with DRTK is to iterate over the batch in a for loop for each topology and render it.
In a series of diffs we would like to solve this issue by making drtk consume a batch of triangles as opposed to just 1 set of triangles. However, we would like to achieve this behavior without affecting the most common single topology case by a lot.

# How do we pass in multiple topologies in a single batch?
We will provide a TopologyBatch structure in xrcia/lib/graphics/structures where we will provide functionality to create a Batch x MaxTriangles x 3 and Batch x MaxVertices x 3.
Padded vertices will be 0s and padded triangles will have MaxVertices - 1 as their value. But these will discarded as degenerate in rasterization / rendering.

# In this diff
- Extend `interpolate` kernel and `interpolate_backward` kernel to support a batch dimension as default.
- `interpolate` will now unsqueeze the batch dimension when using a single topo
- We access the vertex indices of triangles by walking an additional `batch stride * n` in the triangles data pointer.
- Add an extra condition to check to see if the triangles are degenerate; this happens when padding the batch.
- We show that the we don't cause too much overhead in GPU by introducing these 3 extra operations (Same profiling as in D68529076)

Reviewed By: podgorskiy

Differential Revision: D68400728

fbshipit-source-id: d13dbde5cc379789132953c05f6f9289748d67c7
parent d4216dd3
......@@ -28,7 +28,7 @@ def interpolate(
vert_attributes (th.Tensor): vertex attribute tensor
N x V x C
vi (th.Tensor): face vertex index list tensor
V x 3
F x 3 or N x F x 3
index_img (th.Tensor): index image tensor
N x H x W
bary_img (th.Tensor): 3D barycentric coordinate image tensor
......@@ -42,6 +42,9 @@ def interpolate(
For all other pixels, which had index ``-1`` in ``index_img``, the returned tensor will have non-zero
values which should be ignored.
"""
if vi.ndim == 2:
vi = vi[None].expand(vert_attributes.shape[0], -1, -1)
return th.ops.interpolate_ext.interpolate(vert_attributes, vi, index_img, bary_img)
......
......@@ -30,8 +30,9 @@ __global__ void interpolate_kernel(
const index_t vert_attributes_sV = vert_attributes.strides[1];
const index_t vert_attributes_sC = vert_attributes.strides[2];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
const index_t vi_sN = vi.strides[0];
const index_t vi_sV = vi.strides[1];
const index_t vi_sF = vi.strides[2];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
......@@ -56,7 +57,7 @@ __global__ void interpolate_kernel(
scalar_t* __restrict out_ptr = out_img.data + out_img_sN * n + out_img_sH * h + out_img_sW * w;
if (tr_index != -1) {
const int32_t* __restrict vi_ptr = vi.data + tr_index * vi_sV;
const int32_t* __restrict vi_ptr = vi.data + n * vi_sN + tr_index * vi_sV;
const int32_t vi_0 = vi_ptr[0 * vi_sF];
const int32_t vi_1 = vi_ptr[1 * vi_sF];
const int32_t vi_2 = vi_ptr[2 * vi_sF];
......@@ -111,8 +112,9 @@ __global__ void interpolate_backward_kernel(
index_t vert_attributes_grad_sV = vert_attributes_grad.strides[1];
index_t vert_attributes_grad_sC = vert_attributes_grad.strides[2];
index_t vi_sV = vi.strides[0];
index_t vi_sF = vi.strides[1];
index_t vi_sN = vi.strides[0];
index_t vi_sV = vi.strides[1];
index_t vi_sF = vi.strides[2];
index_t index_img_sN = index_img.strides[0];
index_t index_img_sH = index_img.strides[1];
......@@ -164,9 +166,10 @@ __global__ void interpolate_backward_kernel(
if (warp_is_used) {
int32_t vi_0 = -1, vi_1 = -1, vi_2 = -1;
if (thread_is_used) {
vi_0 = vi.data[tr_index * vi_sV + 0 * vi_sF];
vi_1 = vi.data[tr_index * vi_sV + 1 * vi_sF];
vi_2 = vi.data[tr_index * vi_sV + 2 * vi_sF];
const int32_t* __restrict vi_ptr = vi.data + n * vi_sN + tr_index * vi_sV;
vi_0 = vi_ptr[0 * vi_sF];
vi_1 = vi_ptr[1 * vi_sF];
vi_2 = vi_ptr[2 * vi_sF];
}
unsigned m = 0xFFFFFFFFU;
int vi_0_head = (__shfl_up_sync(m, vi_0, 1) != vi_0) || (lane == 0);
......@@ -292,9 +295,9 @@ torch::Tensor interpolate_cuda(
index_img.layout() == torch::kStrided && bary_img.layout() == torch::kStrided,
"interpolate(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(vert_attributes.dim() == 3) && (vi.dim() == 2) && (index_img.dim() == 3) &&
(vert_attributes.dim() == 3) && (vi.dim() == 3) && (index_img.dim() == 3) &&
(bary_img.dim() == 4),
"interpolate(): expected vert_attributes.ndim == 3, vi.ndim == 2, index_img.ndim == 3, bary_img.ndim == 4, "
"interpolate(): expected vert_attributes.ndim == 3, vi.ndim == 3, index_img.ndim == 3, bary_img.ndim == 4, "
"but got vert_attributes with sizes ",
vert_attributes.sizes(),
" and vi with sizes ",
......@@ -313,12 +316,19 @@ torch::Tensor interpolate_cuda(
" and bary_img with sizes ",
bary_img.sizes());
TORCH_CHECK(
vi.size(1) == 3 && bary_img.size(1) == 3,
"interpolate(): expected second dim of vi to be of size 3, and second dim of bary_img to be of size 3, but got ",
vi.size(1),
" in the second dim of vi, and ",
vi.size(2) == 3 && bary_img.size(1) == 3,
"interpolate(): expected last dim of vi to be of size 3, and second dim of bary_img to be of size 3, but got ",
vi.size(2),
" in the last dim of vi, and ",
bary_img.size(1),
" in the second dim of bary_img");
TORCH_CHECK(
vi.size(0) == vert_attributes.size(0),
"interpolate(): expected vi to have same first dimension as vert_atrributes, but got ",
vi.size(0),
" in the first dim of vi, and ",
vert_attributes.size(0),
" in the first dim of vert_attributes");
TORCH_CHECK(
index_img.size(1) == bary_img.size(2) && index_img.size(2) == bary_img.size(3),
"interpolate(): expected H and W dims of index_img and bary_img to match");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment