Commit b0810efa authored by Stanislav Pidhorskyi's avatar Stanislav Pidhorskyi Committed by Facebook GitHub Bot
Browse files

grid_scatter

Summary:
Adds `grid_scatter` op that is similar to `grid_sample` but the grid points to the destination location instead of the source.

`grid_scatter` is indeed dual to `grid_sample`. Forward of `grid_scatter` is backward of `grid_sample` and backward of  `grid_scatter` is forward of `grid_sample` (with the exception for the gradient with respect to grid) which is reflected in the reference implementation in `drtk/grid_scatter.py`.

```python
def grid_scatter(
    input: th.Tensor,
    grid: th.Tensor,
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: Optional[bool] = None,
) -> th.Tensor:
```

Where :
* `input` [N x C x H x W]: is the input tensor values from which will be transferred to the result.
* `grid` [N x H x W x 2]: is the grid tensor that points to the location where the values from the  input tensor should be copied to. The `W`, `H` sizes of grid should match the corresponding sizes of the `input` tensor.
*  `output_height`, `output_width`: size of the output, where output will be: [N x C x `output_height` x `output_width`]. In contrast to `grid_sample`, we can no longer rely on the sizes of the `grid` for this information.
* `mode`, `padding_mode`, `align_corners` same as for the `grid_sample`, but now for the reverse operation - splatting (or scattering).

At the moment does not support "nearest" mode, which is rarely needed. Maybe will add later.

Ideally, we would also want to support autocast mode where the `input` and output tensors are float16 while the `grid` is float32. This is not the case at the moment, but I'll add that later.

## Example usage

Let's assume that we loaded mesh into `v, vi, vt, vti`, have defined `image_width, image_height`, `cam_pos`, `cam_rot`, `focal`, `princpt`, and computed normals for the mesh `normals`. We also define a shading function, e.g.:

```lang=python
def shade(
    vn_img: th.Tensor,
    light_dir: th.Tensor,
    ambient_intensity: float = 1.0,
    direct_intensity: float = 1.0,
    shadow_img: Optional[th.Tensor] = None,
):
    ambient = (vn_img[:, 1:2] * 0.5 + 0.5) * th.as_tensor([0.45, 0.5, 0.7]).cuda()[
        None, :, None, None
    ]
    direct = (
        th.sum(vn_img.mul(thf.normalize(light_dir, dim=1)), dim=1, keepdim=True).clamp(
            min=0.0
        )
        * th.as_tensor([0.65, 0.6, 0.5]).cuda()[None, :, None, None]
    )
    if shadow_img is not None:
        direct = direct * shadow_img
    return th.pow(ambient * ambient_intensity + direct * direct_intensity, 1 / 2.2)
```

And we can render the image as:

```lang=python
v_pix = transform(v, cam_pos, cam_rot, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)

# mask image
mask: th.Tensor = (index_img != -1)[:, None]

# compute vt image
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)

# compute normals
vn_img = interpolate(normals, vi, index_img, bary_img)

diffuse = (
    shade(vn_img, th.as_tensor([0.5, 0.5, 0.0]).cuda()[None, :, None, None]) * mask
)
```

 {F1801805545}

## Shadow mapping

We can use  `grid_scatter` to compute mesh visibility from the camera view:

```lang=python
texel_weight = grid_scatter(
    mask.float(),
    vt_img.permute(0, 2, 3, 1),
    output_width=512,
    output_height=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)
threshold = 0.1  # texel_weight is proportional to how much pixel are the texel covers. We can specify a threshold of how much covered pixel area counts as visible.
visibility = (texel_weight > threshold).float()
```
 {F1801810094}

Now we can render the scene from different angle and use the visibility mask for shadows:

```lang=python
v_pix = transform(v, cam_pos_new, cam_rot_new, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)

# mask image
mask: th.Tensor = (index_img != -1)[:, None]

# compute vt image
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)

# compute v image (for near-field)
v_img = interpolate(v, vi, index_img, bary_img)

# shadow
shadow_img = thf.grid_sample(visibility, vt_img.permute(0, 2, 3, 1), mode="bilinear", padding_mode="border", align_corners=False)

# compute normals
vn_img = interpolate(normals, vi, index_img, bary_img)

diffuse = shade(vn_img, cam_pos[:, :, None, None] - v_img, 0.05, 0.4, shadow_img) * mask
```
 {F1801811232}

## Texture projection

Let's load a test image:

```lang=python
import skimage
test_image = (
    th.as_tensor(skimage.data.coffee(), dtype=th.float32).permute(2, 0, 1)[None, ...].mul(1 / 255).contiguous().cuda()
)

test_image = thf.interpolate(test_image, scale_factor=2.0, mode="bilinear", align_corners=False)
```

{F1801814094}

We can use `grid_scatter` to project the image onto the uv space:

```lang=python
camera_image_extended = (
    th.cat([test_image, th.ones_like(test_image[:, :1])], dim=1) * mask
)

texture_weight = grid_scatter(
    camera_image_extended,
    vt_img.permute(0, 2, 3, 1),
    output_width=512,
    output_height=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)

texture = texture_weight[:, :3] / texture_weight[:, 3:4].clamp(min=1e-4)
```

{F1801816367}

And if we render the scene from a different angle using the projected texture:

 {F1801817130}

Reviewed By: HapeMask

Differential Revision: D61006613

fbshipit-source-id: 98c83ba4eda531e9d73cb9e533176286dc699f63
parent 3c37fcb2
......@@ -6,6 +6,7 @@
from . import utils # noqa # noqa
from .edge_grad_estimator import edge_grad_estimator, edge_grad_estimator_ref # noqa
from .grid_scatter import grid_scatter, grid_scatter_ref # noqa
from .interpolate import interpolate, interpolate_ref # noqa
from .mipmap_grid_sample import mipmap_grid_sample, mipmap_grid_sample_ref # noqa
from .msi import msi # noqa
......
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch as th
import torch.nn.functional as thf
from drtk import grid_scatter_ext
th.ops.load_library(grid_scatter_ext.__file__)
@th.compiler.disable
def grid_scatter(
input: th.Tensor,
grid: th.Tensor,
output_height: int,
output_width: int,
mode: str = "bilinear",
padding_mode: str = "border",
align_corners: Optional[bool] = None,
) -> th.Tensor:
if mode != "bilinear" and mode != "bicubic":
raise ValueError(
"grid_scatter(): only 'bilinear' and 'bicubic' modes are supported "
"but got: '{}'".format(mode)
)
if (
padding_mode != "zeros"
and padding_mode != "border"
and padding_mode != "reflection"
):
raise ValueError(
"grid_scatter(): expected padding_mode "
"to be 'zeros', 'border', or 'reflection', "
"but got: '{}'".format(padding_mode)
)
if mode == "bilinear":
mode_enum = 0
elif mode == "nearest": # not supported
mode_enum = 1
else: # mode == 'bicubic'
mode_enum = 2
if padding_mode == "zeros":
padding_mode_enum = 0
elif padding_mode == "border":
padding_mode_enum = 1
else: # padding_mode == 'reflection'
padding_mode_enum = 2
if align_corners is None:
align_corners = False
return th.ops.grid_scatter_ext.grid_scatter_2d(
input,
grid,
output_height,
output_width,
padding_mode_enum,
mode_enum,
align_corners,
)
class GridScatterRef(th.autograd.Function):
@staticmethod
def forward(
ctx,
input: th.Tensor,
grid: th.Tensor,
output_height: int,
output_width: int,
mode: str = "bilinear",
padding_mode: str = "border",
align_corners: Optional[bool] = None,
):
with th.enable_grad():
tex = th.ones(
input.shape[0],
input.shape[1],
output_height,
output_width,
dtype=input.dtype,
device=input.device,
)
tex.requires_grad_(True)
out = thf.grid_sample(
tex,
grid,
mode=mode,
padding_mode=padding_mode,
align_corners=align_corners,
)
out.backward(input)
ctx.save_for_backward(input, grid, out)
ctx.mode = mode
ctx.padding_mode = padding_mode
ctx.align_corners = align_corners
return tex.grad
@staticmethod
def backward(ctx, grad_output: th.Tensor):
input, grid, out = ctx.saved_tensors
grid = grid.clone().detach()
grid.requires_grad_(True)
with th.enable_grad():
input_grad = thf.grid_sample(
grad_output,
grid,
mode=ctx.mode,
padding_mode=ctx.padding_mode,
align_corners=ctx.align_corners,
)
input_grad.backward(input)
return input_grad, grid.grad, None, None, None, None, None
_grid_scatter_ref = GridScatterRef.apply
@th.compiler.disable
def grid_scatter_ref(
input: th.Tensor,
grid: th.Tensor,
output_height: int,
output_width: int,
mode: str = "bilinear",
padding_mode: str = "border",
align_corners: Optional[bool] = None,
) -> th.Tensor:
return _grid_scatter_ref(
input,
grid,
output_height,
output_width,
mode,
padding_mode,
align_corners,
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
from torch import Tensor
def grid_scatter_2d(
input: List[Tensor],
grid: Tensor,
output_height: int,
output_width: int,
padding_mode: int,
interpolation_mode: int,
align_corners: bool,
) -> Tensor: ...
......@@ -151,6 +151,15 @@ def main(debug: bool) -> None:
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.grid_scatter_ext",
sources=[
"src/grid_scatter/grid_scatter_module.cpp",
"src/grid_scatter/grid_scatter_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
],
cmdclass={"build_ext": BuildExtension},
packages=["drtk", "drtk.utils"],
......
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <torch/types.h>
#include <grid_utils.h>
#include <kernel_utils.h>
#include <tensor_list.h>
using namespace math;
constexpr int tex_ndim = 4;
template <typename scalar_t, typename index_t>
__device__ void scatter_bilinear(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& output,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
const GridSamplerPadding padding_mode,
bool align_corners,
index_t output_memory_span) {
index_t input_sN = input.strides[0];
index_t input_sC = input.strides[1];
index_t input_sH = input.strides[2];
index_t input_sW = input.strides[3];
index_t output_sN = output.strides[0];
index_t output_sC = output.strides[1];
index_t output_sH = output.strides[2];
index_t output_sW = output.strides[3];
index_t output_H = output.sizes[2];
index_t output_W = output.sizes[3];
scalar_t ix = grid_sampler_compute_source_index(x, output_W, padding_mode, align_corners);
scalar_t iy = grid_sampler_compute_source_index(y, output_H, padding_mode, align_corners);
// get NE, NW, SE, SW pixel values from (x, y)
index_t ix_nw = static_cast<index_t>(::floor(ix));
index_t iy_nw = static_cast<index_t>(::floor(iy));
index_t ix_ne = ix_nw + 1;
index_t iy_ne = iy_nw;
index_t ix_sw = ix_nw;
index_t iy_sw = iy_nw + 1;
index_t ix_se = ix_nw + 1;
index_t iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy);
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
const scalar_t* input_ptr_NCHW = input.data + n * input_sN + h * input_sH + w * input_sW;
index_t NC_offset = n * output_sN;
for (index_t c = 0; c < C; ++c, NC_offset += output_sC) {
scalar_t input_value = *(input_ptr_NCHW + c * input_sC);
// calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
safe_add_2d(
output.data,
iy_nw,
ix_nw,
output_sH,
output_sW,
output_H,
output_W,
nw * input_value,
NC_offset,
output_memory_span);
safe_add_2d(
output.data,
iy_ne,
ix_ne,
output_sH,
output_sW,
output_H,
output_W,
ne * input_value,
NC_offset,
output_memory_span);
safe_add_2d(
output.data,
iy_sw,
ix_sw,
output_sH,
output_sW,
output_H,
output_W,
sw * input_value,
NC_offset,
output_memory_span);
safe_add_2d(
output.data,
iy_se,
ix_se,
output_sH,
output_sW,
output_H,
output_W,
se * input_value,
NC_offset,
output_memory_span);
}
}
template <typename scalar_t, typename index_t>
__device__ void scatter_bicubic(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& output,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
const GridSamplerPadding padding_mode,
bool align_corners,
index_t output_memory_span) {
index_t input_sN = input.strides[0];
index_t input_sC = input.strides[1];
index_t input_sH = input.strides[2];
index_t input_sW = input.strides[3];
index_t output_sN = output.strides[0];
index_t output_sC = output.strides[1];
index_t output_sH = output.strides[2];
index_t output_sW = output.strides[3];
index_t output_H = output.sizes[2];
index_t output_W = output.sizes[3];
scalar_t ix = grid_sampler_compute_source_index(x, output_W, padding_mode, align_corners);
scalar_t iy = grid_sampler_compute_source_index(y, output_H, padding_mode, align_corners);
scalar_t ix_nw = static_cast<index_t>(::floor(ix));
scalar_t iy_nw = static_cast<index_t>(::floor(iy));
const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;
scalar_t x_coeffs[4];
scalar_t y_coeffs[4];
get_cubic_upsampling_coefficients<scalar_t>(x_coeffs, tx);
get_cubic_upsampling_coefficients<scalar_t>(y_coeffs, ty);
const scalar_t* input_ptr_NCHW = input.data + n * input_sN + h * input_sH + w * input_sW;
index_t NC_offset = n * output_sN;
for (index_t c = 0; c < C; ++c, NC_offset += output_sC) {
scalar_t input_value = *(input_ptr_NCHW + c * input_sC);
#pragma unroll 4
for (index_t i = 0; i < 4; ++i) {
#pragma unroll 4
for (index_t j = 0; j < 4; ++j) {
// set input gradient. See Note [Passing pointer and offset to fastAtomicAdd].
add_value_bounded<scalar_t>(
output.data,
ix_nw - 1 + i,
iy_nw - 1 + j,
output_W,
output_H,
output_sW,
output_sH,
input_value * x_coeffs[i] * y_coeffs[j],
padding_mode,
align_corners,
NC_offset,
output_memory_span);
}
}
}
}
template <typename scalar_t, typename index_t, bool grid_requires_grad, bool input_requires_grad>
__device__ TVec2<scalar_t> scatter_bilinear_backward(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_output,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
const GridSamplerPadding padding_mode,
bool align_corners) {
index_t input_sN = input.strides[0];
index_t input_sC = input.strides[1];
index_t input_sH = input.strides[2];
index_t input_sW = input.strides[3];
index_t grad_input_sN = grad_input.strides[0];
index_t grad_input_sC = grad_input.strides[1];
index_t grad_input_sH = grad_input.strides[2];
index_t grad_input_sW = grad_input.strides[3];
index_t out_H = grad_output.sizes[2];
index_t out_W = grad_output.sizes[3];
index_t grad_output_sN = grad_output.strides[0];
index_t grad_output_sC = grad_output.strides[1];
index_t grad_output_sH = grad_output.strides[2];
index_t grad_output_sW = grad_output.strides[3];
// multipliers for gradients on ix and iy
TVec2<scalar_t> gi_mult;
scalar_t ix =
grid_sampler_compute_source_index_set_grad(x, out_W, padding_mode, align_corners, &gi_mult.x);
scalar_t iy =
grid_sampler_compute_source_index_set_grad(y, out_H, padding_mode, align_corners, &gi_mult.y);
// get NE, NW, SE, SW pixel values from (x, y)
index_t ix_nw = static_cast<index_t>(::floor(ix));
index_t iy_nw = static_cast<index_t>(::floor(iy));
index_t ix_ne = ix_nw + 1;
index_t iy_ne = iy_nw;
index_t ix_sw = ix_nw;
index_t iy_sw = iy_nw + 1;
index_t ix_se = ix_nw + 1;
index_t iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy);
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
const scalar_t* input_ptr_NCHW = input.data + n * input_sN + h * input_sH + w * input_sW;
TVec2<scalar_t> gi = {scalar_t(0), scalar_t(0)};
auto grad_output_ptr_NC = grad_output.data + n * grad_output_sN;
auto grad_input_ptr_NCHW =
grad_input.data + n * grad_input_sN + h * grad_input_sH + w * grad_input_sW;
for (index_t c = 0; c < C;
++c, grad_output_ptr_NC += grad_output_sC, grad_input_ptr_NCHW += grad_input_sC) {
if (input_requires_grad) {
auto g_input = scalar_t(0.0);
if (within_bounds_2d(iy_nw, ix_nw, out_H, out_W)) {
g_input += grad_output_ptr_NC[iy_nw * grad_output_sH + ix_nw * grad_output_sW] * nw;
}
if (within_bounds_2d(iy_ne, ix_ne, out_H, out_W)) {
g_input += grad_output_ptr_NC[iy_ne * grad_output_sH + ix_ne * grad_output_sW] * ne;
}
if (within_bounds_2d(iy_sw, ix_sw, out_H, out_W)) {
g_input += grad_output_ptr_NC[iy_sw * grad_output_sH + ix_sw * grad_output_sW] * sw;
}
if (within_bounds_2d(iy_se, ix_se, out_H, out_W)) {
g_input += grad_output_ptr_NC[iy_se * grad_output_sH + ix_se * grad_output_sW] * se;
}
*grad_input_ptr_NCHW = g_input;
}
if (grid_requires_grad) {
// calculate grad_grid
scalar_t input_value = *(input_ptr_NCHW + c * input_sC);
if (within_bounds_2d(iy_nw, ix_nw, out_H, out_W)) {
scalar_t gOut = grad_output_ptr_NC[iy_nw * grad_output_sH + ix_nw * grad_output_sW];
gi.x -= input_value * (iy_se - iy) * gOut;
gi.y -= input_value * (ix_se - ix) * gOut;
}
if (within_bounds_2d(iy_ne, ix_ne, out_H, out_W)) {
scalar_t gOut = grad_output_ptr_NC[iy_ne * grad_output_sH + ix_ne * grad_output_sW];
gi.x += input_value * (iy_sw - iy) * gOut;
gi.y -= input_value * (ix - ix_sw) * gOut;
}
if (within_bounds_2d(iy_sw, ix_sw, out_H, out_W)) {
scalar_t gOut = grad_output_ptr_NC[iy_sw * grad_output_sH + ix_sw * grad_output_sW];
gi.x -= input_value * (iy - iy_ne) * gOut;
gi.y += input_value * (ix_ne - ix) * gOut;
}
if (within_bounds_2d(iy_se, ix_se, out_H, out_W)) {
scalar_t gOut = grad_output_ptr_NC[iy_se * grad_output_sH + ix_se * grad_output_sW];
gi.x += input_value * (iy - iy_nw) * gOut;
gi.y += input_value * (ix - ix_nw) * gOut;
}
}
}
return gi_mult * gi;
}
template <typename scalar_t, typename index_t, bool grid_requires_grad, bool input_requires_grad>
__device__ TVec2<scalar_t> scatter_bicubic_backward(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_output,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
const GridSamplerPadding padding_mode,
bool align_corners) {
index_t input_sN = input.strides[0];
index_t input_sC = input.strides[1];
index_t input_sH = input.strides[2];
index_t input_sW = input.strides[3];
index_t grad_input_sN = grad_input.strides[0];
index_t grad_input_sC = grad_input.strides[1];
index_t grad_input_sH = grad_input.strides[2];
index_t grad_input_sW = grad_input.strides[3];
index_t out_H = grad_output.sizes[2];
index_t out_W = grad_output.sizes[3];
index_t grad_output_sN = grad_output.strides[0];
index_t grad_output_sC = grad_output.strides[1];
index_t grad_output_sH = grad_output.strides[2];
index_t grad_output_sW = grad_output.strides[3];
// multipliers for gradients on ix and iy
TVec2<scalar_t> gi_mult;
scalar_t ix =
grid_sampler_compute_source_index_set_grad(x, out_W, padding_mode, align_corners, &gi_mult.x);
scalar_t iy =
grid_sampler_compute_source_index_set_grad(y, out_H, padding_mode, align_corners, &gi_mult.y);
// get NE, NW, SE, SW pixel values from (x, y)
scalar_t ix_nw = ::floor(ix);
scalar_t iy_nw = ::floor(iy);
const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;
scalar_t x_coeffs[4];
scalar_t y_coeffs[4];
scalar_t x_coeffs_grad[4];
scalar_t y_coeffs_grad[4];
get_cubic_upsampling_coefficients<scalar_t>(x_coeffs, tx);
get_cubic_upsampling_coefficients<scalar_t>(y_coeffs, ty);
get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);
const scalar_t* input_ptr_NCHW = input.data + n * input_sN + h * input_sH + w * input_sW;
TVec2<scalar_t> gi = {scalar_t(0), scalar_t(0)};
auto grad_output_ptr_NC = grad_output.data + n * grad_output_sN;
auto grad_input_ptr_NCHW =
grad_input.data + n * grad_input_sN + h * grad_input_sH + w * grad_input_sW;
for (index_t c = 0; c < C;
++c, grad_output_ptr_NC += grad_output_sC, grad_input_ptr_NCHW += grad_input_sC) {
scalar_t coefficients[4];
scalar_t input_value = *(input_ptr_NCHW + c * input_sC);
#pragma unroll 4
for (index_t i = 0; i < 4; ++i) {
if (input_requires_grad) {
coefficients[i] = cubic_interp1d(
get_value_bounded<scalar_t>(
grad_output_ptr_NC,
ix_nw - 1,
iy_nw - 1 + i,
out_W,
out_H,
grad_output_sW,
grad_output_sH,
padding_mode,
align_corners),
get_value_bounded<scalar_t>(
grad_output_ptr_NC,
ix_nw + 0,
iy_nw - 1 + i,
out_W,
out_H,
grad_output_sW,
grad_output_sH,
padding_mode,
align_corners),
get_value_bounded<scalar_t>(
grad_output_ptr_NC,
ix_nw + 1,
iy_nw - 1 + i,
out_W,
out_H,
grad_output_sW,
grad_output_sH,
padding_mode,
align_corners),
get_value_bounded<scalar_t>(
grad_output_ptr_NC,
ix_nw + 2,
iy_nw - 1 + i,
out_W,
out_H,
grad_output_sW,
grad_output_sH,
padding_mode,
align_corners),
tx);
*grad_input_ptr_NCHW =
cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty);
}
if (grid_requires_grad) {
#pragma unroll 4
for (index_t j = 0; j < 4; ++j) {
// set grid gradient
scalar_t gOut = get_value_bounded<scalar_t>(
grad_output_ptr_NC,
ix_nw - 1 + i,
iy_nw - 1 + j,
out_W,
out_H,
grad_output_sW,
grad_output_sH,
padding_mode,
align_corners);
gi -= gOut * input_value *
TVec2<scalar_t>({x_coeffs_grad[i] * y_coeffs[j], y_coeffs_grad[j] * x_coeffs[i]});
}
}
}
}
return gi_mult * gi;
}
template <typename scalar_t, typename index_t, GridSamplerInterpolation interpolation_mode>
C10_LAUNCH_BOUNDS_1(256)
__global__ void grid_scatter_2d_kernel(
const index_t nthreads,
TensorInfoCompact<scalar_t, index_t, tex_ndim> input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grid,
TensorInfoCompact<scalar_t, index_t, tex_ndim> output,
const GridSamplerPadding padding_mode,
bool align_corners,
index_t output_memory_span) {
index_t C = output.sizes[1];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t grid_sN = grid.strides[0];
index_t grid_sH = grid.strides[1];
index_t grid_sW = grid.strides[2];
index_t grid_sCoor = grid.strides[3];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % inp_W;
const index_t h = (index / inp_W) % inp_H;
const index_t n = index / (inp_H * inp_W);
const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
// get the corresponding input x, y co-ordinates from grid
scalar_t u = grid.data[grid_offset];
scalar_t v = grid.data[grid_offset + grid_sCoor];
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
scatter_bilinear(
input, output, u, v, w, h, n, C, padding_mode, align_corners, output_memory_span);
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
scatter_bicubic(
input, output, u, v, w, h, n, C, padding_mode, align_corners, output_memory_span);
}
}
}
template <
typename scalar_t,
typename index_t,
GridSamplerInterpolation interpolation_mode,
bool grid_requires_grad,
bool input_requires_grad>
C10_LAUNCH_BOUNDS_1(256)
__global__ void grid_scatter_2d_backward_kernel(
const index_t nthreads,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_output,
TensorInfoCompact<scalar_t, index_t, tex_ndim> input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grid,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_grid, // initialized to empty
const GridSamplerPadding padding_mode,
bool align_corners) {
index_t C = input.sizes[1];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t grid_sN = grid.strides[0];
index_t grid_sH = grid.strides[1];
index_t grid_sW = grid.strides[2];
index_t grid_sCoor = grid.strides[3];
index_t gGrid_sW = grad_grid.strides[2];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % inp_W;
const index_t h = (index / inp_W) % inp_H;
const index_t n = index / (inp_H * inp_W);
const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
// get the corresponding input x, y co-ordinates from grid
scalar_t u = grid.data[grid_offset];
scalar_t v = grid.data[grid_offset + grid_sCoor];
scalar_t* gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
auto ggrad =
scatter_bilinear_backward<scalar_t, index_t, grid_requires_grad, input_requires_grad>(
input, grad_input, grad_output, u, v, w, h, n, C, padding_mode, align_corners);
if (grid_requires_grad) {
gGrid_ptr_NHW[0] = ggrad.x;
gGrid_ptr_NHW[1] = ggrad.y;
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
auto ggrad =
scatter_bicubic_backward<scalar_t, index_t, grid_requires_grad, input_requires_grad>(
input, grad_input, grad_output, u, v, w, h, n, C, padding_mode, align_corners);
if (grid_requires_grad) {
gGrid_ptr_NHW[0] = ggrad.x;
gGrid_ptr_NHW[1] = ggrad.y;
}
}
}
}
template <typename scalar_t, typename index_t>
__host__ void grid_scatter_2d_dispatch_interpolation_type(
const index_t nthreads,
TensorInfoCompact<scalar_t, index_t, tex_ndim> input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grid,
TensorInfoCompact<scalar_t, index_t, tex_ndim> output,
const GridSamplerPadding padding_mode,
bool align_corners,
GridSamplerInterpolation interpolation_mode,
index_t output_memory_span) {
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
grid_scatter_2d_kernel<scalar_t, index_t, GridSamplerInterpolation::Bilinear>
<<<GET_BLOCKS(nthreads, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
nthreads, input, grid, output, padding_mode, align_corners, output_memory_span);
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
grid_scatter_2d_kernel<scalar_t, index_t, GridSamplerInterpolation::Bicubic>
<<<GET_BLOCKS(nthreads, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
nthreads, input, grid, output, padding_mode, align_corners, output_memory_span);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, typename index_t, bool grid_requires_grad, bool input_requires_grad>
__host__ void grid_scatter_2d_backward_dispatch_interpolation_type(
const index_t nthreads,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_output,
TensorInfoCompact<scalar_t, index_t, tex_ndim> input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grid,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_grid, // initialized to empty
const GridSamplerPadding padding_mode,
bool align_corners,
GridSamplerInterpolation interpolation_mode) {
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
grid_scatter_2d_backward_kernel<
scalar_t,
index_t,
GridSamplerInterpolation::Bilinear,
grid_requires_grad,
input_requires_grad>
<<<GET_BLOCKS(nthreads, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
nthreads, grad_output, input, grid, grad_input, grad_grid, padding_mode, align_corners);
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
grid_scatter_2d_backward_kernel<
scalar_t,
index_t,
GridSamplerInterpolation::Bicubic,
grid_requires_grad,
input_requires_grad>
<<<GET_BLOCKS(nthreads, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
nthreads, grad_output, input, grid, grad_input, grad_grid, padding_mode, align_corners);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, typename index_t>
__host__ void grid_scatter_2d_backward_dispatch_requires_grad(
const index_t nthreads,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_output,
TensorInfoCompact<scalar_t, index_t, tex_ndim> input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grid,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_input,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_grid, // initialized to empty
const GridSamplerPadding padding_mode,
bool align_corners,
GridSamplerInterpolation interpolation_mode,
bool grid_requires_grad,
bool input_requires_grad) {
if (grid_requires_grad && input_requires_grad) {
grid_scatter_2d_backward_dispatch_interpolation_type<scalar_t, index_t, true, true>(
nthreads,
grad_output,
input,
grid,
grad_input,
grad_grid,
padding_mode,
align_corners,
interpolation_mode);
} else if (!grid_requires_grad && input_requires_grad) {
grid_scatter_2d_backward_dispatch_interpolation_type<scalar_t, index_t, false, true>(
nthreads,
grad_output,
input,
grid,
grad_input,
grad_grid,
padding_mode,
align_corners,
interpolation_mode);
} else if (grid_requires_grad && !input_requires_grad) {
grid_scatter_2d_backward_dispatch_interpolation_type<scalar_t, index_t, true, false>(
nthreads,
grad_output,
input,
grid,
grad_input,
grad_grid,
padding_mode,
align_corners,
interpolation_mode);
}
}
__host__ torch::Tensor grid_scatter_2d_cuda(
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t output_height,
int64_t output_width,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners) {
TORCH_CHECK(
input.defined() && grid.defined(),
"grid_scatter(): expected input and grid to not be undefined, but input is ",
input,
" and grid is ",
grid);
auto input_opt = input.options();
auto grid_opt = grid.options();
TORCH_CHECK(
output_height > 0 && output_width > 0,
"grid_scatter(): expected output_height and output_width to be greater than 0, but output_height is ",
output_height,
" and output_width is ",
output_width);
TORCH_CHECK(
input_opt.device() == grid_opt.device() && grid_opt.device().is_cuda(),
"grid_scatter(): expected input and grid to be on same CUDA device, but input is on ",
input_opt.device(),
" and grid is on ",
grid_opt.device());
TORCH_CHECK(
input.is_floating_point() && grid.is_floating_point(),
"grid_scatter(): expected input and grid to have floating point dtype, but input has ",
input_opt.dtype(),
" and grid has ",
grid_opt.dtype());
TORCH_CHECK(
input_opt.layout() == torch::kStrided && grid_opt.layout() == torch::kStrided,
"grid_scatter(): expected input and grid to have torch.strided layout, but "
"input has ",
input_opt.layout(),
" and grid has ",
grid_opt.layout());
TORCH_CHECK(
(input.dim() == 4) && input.dim() == grid.dim(),
"grid_scatter(): expected 4D input and grid with same number of "
"dimensions, but got input with sizes ",
input.sizes(),
" and grid with sizes ",
grid.sizes());
TORCH_CHECK(
input.size(0) == grid.size(0) && input.size(2) == grid.size(1) &&
input.size(3) == grid.size(2),
"grid_scatter(): expected grid and input to have same batch size, width and height"
"but got input with sizes ",
input.sizes(),
" and grid with sizes ",
grid.sizes());
TORCH_CHECK(
grid.size(-1) == input.dim() - 2,
"grid_scatter(): expected grid to have size ",
input[0].dim() - 2,
" in last dimension, but got grid with sizes ",
grid.sizes());
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto N = input.size(0);
auto C = input.size(1);
auto H = grid.size(1);
auto W = grid.size(2);
auto output = at::zeros({N, C, output_height, output_width}, input.options());
int64_t count = N * H * W;
if (count > 0) {
// Should be AT_DISPATCH_FLOATING_TYPES_AND_HALF, but half is broken on prod
AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "grid_scatter_2d_kernel", [&] {
if (at::native::canUse32BitIndexMath(input) && at::native::canUse32BitIndexMath(grid) &&
at::native::canUse32BitIndexMath(output)) {
typedef int index_type;
grid_scatter_2d_dispatch_interpolation_type<scalar_t, index_type>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(input),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(output),
static_cast<GridSamplerPadding>(padding_mode),
align_corners,
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<index_type>(output.numel()));
} else {
typedef int64_t index_type;
grid_scatter_2d_dispatch_interpolation_type<scalar_t, index_type>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(input),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(output),
static_cast<GridSamplerPadding>(padding_mode),
align_corners,
static_cast<GridSamplerInterpolation>(interpolation_mode),
static_cast<index_type>(output.numel()));
}
});
}
return output;
}
__host__ std::tuple<torch::Tensor, torch::Tensor> grid_scatter_2d_cuda_backward(
const torch::Tensor& grad_output,
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool grid_requires_grad,
bool input_requires_grad) {
auto N = input.size(0);
auto H = grid.size(1);
auto W = grid.size(2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
auto grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
int64_t count = N * H * W;
if (count > 0) {
// Should be AT_DISPATCH_FLOATING_TYPES_AND_HALF, but half is broken on prod
AT_DISPATCH_FLOATING_TYPES(input[0].scalar_type(), "grid_scatter_2d_backward_kernel", [&] {
if (at::native::canUse32BitIndexMath(input) && at::native::canUse32BitIndexMath(grid) &&
at::native::canUse32BitIndexMath(grad_output)) {
typedef int index_type;
grid_scatter_2d_backward_dispatch_requires_grad<scalar_t, index_type>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_output),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(input),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_input),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_grid),
static_cast<GridSamplerPadding>(padding_mode),
align_corners,
static_cast<GridSamplerInterpolation>(interpolation_mode),
grid_requires_grad,
input_requires_grad);
} else {
typedef int64_t index_type;
grid_scatter_2d_backward_dispatch_requires_grad<scalar_t, index_type>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_output),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(input),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_input),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_grid),
static_cast<GridSamplerPadding>(padding_mode),
align_corners,
static_cast<GridSamplerInterpolation>(interpolation_mode),
grid_requires_grad,
input_requires_grad);
}
});
}
return std::make_tuple(grad_input, grad_grid);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <torch/torch.h>
torch::Tensor grid_scatter_2d_cuda(
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t output_height,
int64_t output_width,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners);
std::tuple<torch::Tensor, torch::Tensor> grid_scatter_2d_cuda_backward(
const torch::Tensor& grad_output,
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool grid_requires_grad,
bool input_requires_grad);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "grid_scatter_kernel.h"
// Dispatch function
torch::Tensor grid_scatter_2d(
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t output_height,
int64_t output_width,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("grid_scatter_ext::grid_scatter_2d", "")
.typed<decltype(grid_scatter_2d)>();
return op.call(
input, grid, output_height, output_width, padding_mode, interpolation_mode, align_corners);
}
// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
class GridScatter2DFunction : public torch::autograd::Function<GridScatter2DFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t output_height,
int64_t output_width,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners) {
ctx->set_materialize_grads(false);
std::vector<torch::Tensor> save_list;
save_list.push_back(input);
save_list.push_back(grid);
ctx->save_for_backward(save_list);
bool grid_requires_grad = grid.requires_grad();
bool input_requires_grad = input.requires_grad();
ctx->saved_data["data"] = std::make_tuple(
grid_requires_grad, input_requires_grad, padding_mode, interpolation_mode, align_corners);
auto out = grid_scatter_2d_cuda(
input, grid, output_height, output_width, padding_mode, interpolation_mode, align_corners);
return {out};
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
bool grid_requires_grad;
bool input_requires_grad;
int64_t padding_mode;
int64_t interpolation_mode;
bool align_corners;
std::tie(
grid_requires_grad, input_requires_grad, padding_mode, interpolation_mode, align_corners) =
ctx->saved_data["data"].to<std::tuple<bool, bool, int64_t, int64_t, bool>>();
torch::autograd::tensor_list out;
if (!grid_requires_grad && !input_requires_grad) {
out.resize(7);
return out;
}
const auto saved = ctx->get_saved_variables();
const torch::Tensor& input = saved[0];
const torch::Tensor& grid = saved[1];
auto grad_out = grid_scatter_2d_cuda_backward(
grad_outputs[0],
input,
grid,
padding_mode,
interpolation_mode,
align_corners,
grid_requires_grad,
input_requires_grad);
out.push_back(std::get<0>(grad_out));
out.push_back(std::get<1>(grad_out));
out.emplace_back();
out.emplace_back();
out.emplace_back();
out.emplace_back();
out.emplace_back();
return out;
}
};
torch::Tensor grid_scatter_2d_autograd(
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t output_height,
int64_t output_width,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners) {
return GridScatter2DFunction::apply(
input, grid, output_height, output_width, padding_mode, interpolation_mode, align_corners)[0];
}
torch::Tensor grid_scatter_2d_autocast(
const torch::Tensor& input,
const torch::Tensor& grid,
int64_t output_height,
int64_t output_width,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return grid_scatter_2d(
at::autocast::cached_cast(torch::kFloat32, input),
at::autocast::cached_cast(torch::kFloat32, grid),
output_height,
output_width,
padding_mode,
interpolation_mode,
align_corners);
}
#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(grid_scatter_ext, m) {}
#endif
TORCH_LIBRARY(grid_scatter_ext, m) {
m.def(
"grid_scatter_2d(Tensor input, Tensor grid, int output_height, int output_width, int padding_mode, int interpolation_mode, bool align_corners) -> Tensor");
}
TORCH_LIBRARY_IMPL(grid_scatter_ext, Autograd, m) {
m.impl("grid_scatter_2d", &grid_scatter_2d_autograd);
}
TORCH_LIBRARY_IMPL(grid_scatter_ext, Autocast, m) {
m.impl("grid_scatter_2d", grid_scatter_2d_autocast);
}
TORCH_LIBRARY_IMPL(grid_scatter_ext, CUDA, m) {
m.impl("grid_scatter_2d", &grid_scatter_2d_cuda);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment