Commit b0810efa authored by Stanislav Pidhorskyi's avatar Stanislav Pidhorskyi Committed by Facebook GitHub Bot

grid_scatter

Summary:
Adds a `grid_scatter` op that is similar to `grid_sample`, except that the grid points to the destination location instead of the source.

`grid_scatter` is indeed dual to `grid_sample`. The forward of `grid_scatter` is the backward of `grid_sample`, and the backward of `grid_scatter` is the forward of `grid_sample` (with the exception of the gradient with respect to the grid), which is reflected in the reference implementation in `drtk/grid_scatter.py`.
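This duality can be sketched with plain PyTorch autograd (an illustrative CPU sketch, not the CUDA implementation; `grid_scatter_sketch` is a hypothetical helper name):

```python
import torch as th
import torch.nn.functional as thf

def grid_scatter_sketch(inp, grid, out_h, out_w, mode="bilinear",
                        padding_mode="border", align_corners=False):
    # Forward of grid_scatter is backward of grid_sample w.r.t. the sampled
    # texture: splat `inp` into a zero texture of the requested output size.
    tex = th.zeros(inp.shape[0], inp.shape[1], out_h, out_w,
                   dtype=inp.dtype, requires_grad=True)
    sampled = thf.grid_sample(tex, grid, mode=mode, padding_mode=padding_mode,
                              align_corners=align_corners)
    (scattered,) = th.autograd.grad(sampled, tex, grad_outputs=inp)
    return scattered

# All four input values map to the normalized center (0, 0) and get splatted
# bilinearly over the four central pixels of a 4x4 output.
inp = th.ones(1, 1, 2, 2)
grid = th.zeros(1, 2, 2, 2)
out = grid_scatter_sketch(inp, grid, 4, 4)
# Bilinear splat weights sum to 1, so total mass is preserved: out.sum() == inp.sum()
```

With `align_corners=False`, (0, 0) unnormalizes to pixel coordinate 1.5, so each input value contributes 0.25 to each of the pixels (1, 1), (1, 2), (2, 1), (2, 2).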

```python
def grid_scatter(
    input: th.Tensor,
    grid: th.Tensor,
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: Optional[bool] = None,
) -> th.Tensor:
```

Where:
* `input` [N x C x H x W]: the input tensor whose values will be transferred to the output.
* `grid` [N x H x W x 2]: the grid tensor that points to the locations where the values from the input tensor should be copied. The `W` and `H` sizes of `grid` should match the corresponding sizes of the `input` tensor.
* `output_height`, `output_width`: size of the output, which will be [N x C x `output_height` x `output_width`]. In contrast to `grid_sample`, we can no longer rely on the sizes of `grid` for this information.
* `mode`, `padding_mode`, `align_corners`: same as for `grid_sample`, but now for the reverse operation - splatting (or scattering).

At the moment, "nearest" mode is not supported; it is rarely needed and may be added later.

Ideally, we would also want to support autocast mode, where the `input` and output tensors are float16 while the `grid` is float32. This is not supported at the moment, but I'll add it later.
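For intuition, here is how a single grid entry distributes its value in "bilinear" mode (a pure-Python sketch with a hypothetical helper name; it uses the same coordinate unnormalization convention as `grid_sample` and simply drops out-of-bounds corners rather than implementing the padding modes):

```python
import math

def splat_bilinear(value, x, y, out_w, out_h, align_corners=False):
    """Distribute `value` at normalized coords (x, y) in [-1, 1] over the
    four nearest output pixels with bilinear weights. Returns a dict of
    (col, row) -> accumulated value."""
    # Unnormalize using the grid_sample convention.
    if align_corners:
        fx = (x + 1) / 2 * (out_w - 1)
        fy = (y + 1) / 2 * (out_h - 1)
    else:
        fx = ((x + 1) * out_w - 1) / 2
        fy = ((y + 1) * out_h - 1) / 2
    x0, y0 = math.floor(fx), math.floor(fy)
    wx1, wy1 = fx - x0, fy - y0
    out = {}
    for xi, yi, w in [
        (x0, y0, (1 - wx1) * (1 - wy1)),
        (x0 + 1, y0, wx1 * (1 - wy1)),
        (x0, y0 + 1, (1 - wx1) * wy1),
        (x0 + 1, y0 + 1, wx1 * wy1),
    ]:
        if 0 <= xi < out_w and 0 <= yi < out_h:  # drop out-of-bounds corners
            out[(xi, yi)] = out.get((xi, yi), 0.0) + value * w
    return out

# One unit of value at the center of a 4x4 image: 0.25 to each central pixel.
weights = splat_bilinear(1.0, 0.0, 0.0, 4, 4)
```

The weights of the four corners always sum to 1 for an in-bounds point, which is why the accumulated output can be interpreted as coverage (as in the shadow-mapping example below).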

## Example usage

Let's assume that we have loaded a mesh into `v, vi, vt, vti`, defined `image_width, image_height`, `cam_pos`, `cam_rot`, `focal`, `princpt`, and computed the mesh normals `normals`. We also define a shading function, e.g.:

```lang=python
def shade(
    vn_img: th.Tensor,
    light_dir: th.Tensor,
    ambient_intensity: float = 1.0,
    direct_intensity: float = 1.0,
    shadow_img: Optional[th.Tensor] = None,
):
    ambient = (vn_img[:, 1:2] * 0.5 + 0.5) * th.as_tensor([0.45, 0.5, 0.7]).cuda()[
        None, :, None, None
    ]
    direct = (
        th.sum(vn_img.mul(thf.normalize(light_dir, dim=1)), dim=1, keepdim=True).clamp(
            min=0.0
        )
        * th.as_tensor([0.65, 0.6, 0.5]).cuda()[None, :, None, None]
    )
    if shadow_img is not None:
        direct = direct * shadow_img
    return th.pow(ambient * ambient_intensity + direct * direct_intensity, 1 / 2.2)
```

And we can render the image as:

```lang=python
v_pix = transform(v, cam_pos, cam_rot, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)

# mask image
mask: th.Tensor = (index_img != -1)[:, None]

# compute vt image
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)

# compute normals
vn_img = interpolate(normals, vi, index_img, bary_img)

diffuse = (
    shade(vn_img, th.as_tensor([0.5, 0.5, 0.0]).cuda()[None, :, None, None]) * mask
)
```

 {F1801805545}

## Shadow mapping

We can use `grid_scatter` to compute mesh visibility from the camera view:

```lang=python
texel_weight = grid_scatter(
    mask.float(),
    vt_img.permute(0, 2, 3, 1),
    output_width=512,
    output_height=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)
threshold = 0.1  # texel_weight is proportional to how much pixel area the texel covers. We can specify a threshold for how much covered pixel area counts as visible.
visibility = (texel_weight > threshold).float()
```
 {F1801810094}

Now we can render the scene from a different angle and use the visibility mask for shadows:

```lang=python
v_pix = transform(v, cam_pos_new, cam_rot_new, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)

# mask image
mask: th.Tensor = (index_img != -1)[:, None]

# compute vt image
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)

# compute v image (for near-field)
v_img = interpolate(v, vi, index_img, bary_img)

# shadow
shadow_img = thf.grid_sample(
    visibility,
    vt_img.permute(0, 2, 3, 1),
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)

# compute normals
vn_img = interpolate(normals, vi, index_img, bary_img)

diffuse = shade(vn_img, cam_pos[:, :, None, None] - v_img, 0.05, 0.4, shadow_img) * mask
```
 {F1801811232}

## Texture projection

Let's load a test image:

```lang=python
import skimage
test_image = (
    th.as_tensor(skimage.data.coffee(), dtype=th.float32)
    .permute(2, 0, 1)[None, ...]
    .mul(1 / 255)
    .contiguous()
    .cuda()
)

test_image = thf.interpolate(test_image, scale_factor=2.0, mode="bilinear", align_corners=False)
```

{F1801814094}

We can use `grid_scatter` to project the image onto the uv space:

```lang=python
camera_image_extended = (
    th.cat([test_image, th.ones_like(test_image[:, :1])], dim=1) * mask
)

texture_weight = grid_scatter(
    camera_image_extended,
    vt_img.permute(0, 2, 3, 1),
    output_width=512,
    output_height=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)

texture = texture_weight[:, :3] / texture_weight[:, 3:4].clamp(min=1e-4)
```
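The extra ones channel turns the scatter into a weighted average: the fourth channel accumulates the total splat weight landing on each texel, and dividing by it normalizes the accumulated color. The same idea in isolation, as a pure-Python 1-D sketch with nearest-bin scattering (hypothetical helper name):

```python
def scatter_weighted_average(samples, out_size):
    """Average all samples landing in each output bin by accumulating
    (value, weight) pairs and dividing, mirroring the ones-channel trick.
    `samples` is a list of (bin_index, value) pairs."""
    acc = [0.0] * out_size       # scattered values (the color channels)
    weight = [0.0] * out_size    # scattered ones (the extra channel)
    for idx, value in samples:
        acc[idx] += value
        weight[idx] += 1.0
    eps = 1e-4  # same clamp as above, avoids division by zero for empty bins
    return [a / max(w, eps) for a, w in zip(acc, weight)]

# Two samples land in bin 0, one in bin 1, none in bin 2.
avg = scatter_weighted_average([(0, 2.0), (0, 4.0), (1, 5.0)], 3)
```

Without the normalization, texels covered by many pixels would come out proportionally brighter than texels covered by one.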

{F1801816367}

And if we render the scene from a different angle using the projected texture:

 {F1801817130}

Reviewed By: HapeMask

Differential Revision: D61006613

fbshipit-source-id: 98c83ba4eda531e9d73cb9e533176286dc699f63
parent 3c37fcb2
@@ -6,6 +6,7 @@
from . import utils  # noqa
from .edge_grad_estimator import edge_grad_estimator, edge_grad_estimator_ref  # noqa
from .grid_scatter import grid_scatter, grid_scatter_ref  # noqa
from .interpolate import interpolate, interpolate_ref  # noqa
from .mipmap_grid_sample import mipmap_grid_sample, mipmap_grid_sample_ref  # noqa
from .msi import msi  # noqa
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch as th
import torch.nn.functional as thf
from drtk import grid_scatter_ext

th.ops.load_library(grid_scatter_ext.__file__)


@th.compiler.disable
def grid_scatter(
    input: th.Tensor,
    grid: th.Tensor,
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: Optional[bool] = None,
) -> th.Tensor:
    if mode != "bilinear" and mode != "bicubic":
        raise ValueError(
            "grid_scatter(): only 'bilinear' and 'bicubic' modes are supported "
            "but got: '{}'".format(mode)
        )
    if (
        padding_mode != "zeros"
        and padding_mode != "border"
        and padding_mode != "reflection"
    ):
        raise ValueError(
            "grid_scatter(): expected padding_mode "
            "to be 'zeros', 'border', or 'reflection', "
            "but got: '{}'".format(padding_mode)
        )
    if mode == "bilinear":
        mode_enum = 0
    elif mode == "nearest":  # not supported
        mode_enum = 1
    else:  # mode == 'bicubic'
        mode_enum = 2
    if padding_mode == "zeros":
        padding_mode_enum = 0
    elif padding_mode == "border":
        padding_mode_enum = 1
    else:  # padding_mode == 'reflection'
        padding_mode_enum = 2
    if align_corners is None:
        align_corners = False
    return th.ops.grid_scatter_ext.grid_scatter_2d(
        input,
        grid,
        output_height,
        output_width,
        padding_mode_enum,
        mode_enum,
        align_corners,
    )


class GridScatterRef(th.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        input: th.Tensor,
        grid: th.Tensor,
        output_height: int,
        output_width: int,
        mode: str = "bilinear",
        padding_mode: str = "border",
        align_corners: Optional[bool] = None,
    ):
        with th.enable_grad():
            tex = th.ones(
                input.shape[0],
                input.shape[1],
                output_height,
                output_width,
                dtype=input.dtype,
                device=input.device,
            )
            tex.requires_grad_(True)
            out = thf.grid_sample(
                tex,
                grid,
                mode=mode,
                padding_mode=padding_mode,
                align_corners=align_corners,
            )
            out.backward(input)
        ctx.save_for_backward(input, grid, out)
        ctx.mode = mode
        ctx.padding_mode = padding_mode
        ctx.align_corners = align_corners
        return tex.grad

    @staticmethod
    def backward(ctx, grad_output: th.Tensor):
        input, grid, out = ctx.saved_tensors
        grid = grid.clone().detach()
        grid.requires_grad_(True)
        with th.enable_grad():
            input_grad = thf.grid_sample(
                grad_output,
                grid,
                mode=ctx.mode,
                padding_mode=ctx.padding_mode,
                align_corners=ctx.align_corners,
            )
            input_grad.backward(input)
        return input_grad, grid.grad, None, None, None, None, None


_grid_scatter_ref = GridScatterRef.apply


@th.compiler.disable
def grid_scatter_ref(
    input: th.Tensor,
    grid: th.Tensor,
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: Optional[bool] = None,
) -> th.Tensor:
    return _grid_scatter_ref(
        input,
        grid,
        output_height,
        output_width,
        mode,
        padding_mode,
        align_corners,
    )
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch import Tensor

def grid_scatter_2d(
    input: Tensor,
    grid: Tensor,
    output_height: int,
    output_width: int,
    padding_mode: int,
    interpolation_mode: int,
    align_corners: bool,
) -> Tensor: ...
@@ -151,6 +151,15 @@ def main(debug: bool) -> None:
            extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
            include_dirs=include_dir,
        ),
        CUDAExtension(
            "drtk.grid_scatter_ext",
            sources=[
                "src/grid_scatter/grid_scatter_module.cpp",
                "src/grid_scatter/grid_scatter_kernel.cu",
            ],
            extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
            include_dirs=include_dir,
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
    packages=["drtk", "drtk.utils"],
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torch/torch.h>

torch::Tensor grid_scatter_2d_cuda(
    const torch::Tensor& input,
    const torch::Tensor& grid,
    int64_t output_height,
    int64_t output_width,
    int64_t padding_mode,
    int64_t interpolation_mode,
    bool align_corners);

std::tuple<torch::Tensor, torch::Tensor> grid_scatter_2d_cuda_backward(
    const torch::Tensor& grad_output,
    const torch::Tensor& input,
    const torch::Tensor& grid,
    int64_t padding_mode,
    int64_t interpolation_mode,
    bool align_corners,
    bool grid_requires_grad,
    bool input_requires_grad);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

#include <torch/script.h>

#include <ATen/autocast_mode.h>

#ifndef NO_PYBIND
#include <torch/extension.h>
#endif

#include "grid_scatter_kernel.h"

// Dispatch function
torch::Tensor grid_scatter_2d(
    const torch::Tensor& input,
    const torch::Tensor& grid,
    int64_t output_height,
    int64_t output_width,
    int64_t padding_mode,
    int64_t interpolation_mode,
    bool align_corners) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("grid_scatter_ext::grid_scatter_2d", "")
                       .typed<decltype(grid_scatter_2d)>();
  return op.call(
      input, grid, output_height, output_width, padding_mode, interpolation_mode, align_corners);
}

// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
class GridScatter2DFunction : public torch::autograd::Function<GridScatter2DFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& input,
      const torch::Tensor& grid,
      int64_t output_height,
      int64_t output_width,
      int64_t padding_mode,
      int64_t interpolation_mode,
      bool align_corners) {
    ctx->set_materialize_grads(false);
    std::vector<torch::Tensor> save_list;
    save_list.push_back(input);
    save_list.push_back(grid);
    ctx->save_for_backward(save_list);
    bool grid_requires_grad = grid.requires_grad();
    bool input_requires_grad = input.requires_grad();
    ctx->saved_data["data"] = std::make_tuple(
        grid_requires_grad, input_requires_grad, padding_mode, interpolation_mode, align_corners);
    auto out = grid_scatter_2d_cuda(
        input, grid, output_height, output_width, padding_mode, interpolation_mode, align_corners);
    return {out};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    bool grid_requires_grad;
    bool input_requires_grad;
    int64_t padding_mode;
    int64_t interpolation_mode;
    bool align_corners;
    std::tie(
        grid_requires_grad, input_requires_grad, padding_mode, interpolation_mode, align_corners) =
        ctx->saved_data["data"].to<std::tuple<bool, bool, int64_t, int64_t, bool>>();
    torch::autograd::tensor_list out;
    if (!grid_requires_grad && !input_requires_grad) {
      out.resize(7);
      return out;
    }
    const auto saved = ctx->get_saved_variables();
    const torch::Tensor& input = saved[0];
    const torch::Tensor& grid = saved[1];
    auto grad_out = grid_scatter_2d_cuda_backward(
        grad_outputs[0],
        input,
        grid,
        padding_mode,
        interpolation_mode,
        align_corners,
        grid_requires_grad,
        input_requires_grad);
    out.push_back(std::get<0>(grad_out));
    out.push_back(std::get<1>(grad_out));
    out.emplace_back();
    out.emplace_back();
    out.emplace_back();
    out.emplace_back();
    out.emplace_back();
    return out;
  }
};

torch::Tensor grid_scatter_2d_autograd(
    const torch::Tensor& input,
    const torch::Tensor& grid,
    int64_t output_height,
    int64_t output_width,
    int64_t padding_mode,
    int64_t interpolation_mode,
    bool align_corners) {
  return GridScatter2DFunction::apply(
      input, grid, output_height, output_width, padding_mode, interpolation_mode, align_corners)[0];
}

torch::Tensor grid_scatter_2d_autocast(
    const torch::Tensor& input,
    const torch::Tensor& grid,
    int64_t output_height,
    int64_t output_width,
    int64_t padding_mode,
    int64_t interpolation_mode,
    bool align_corners) {
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  return grid_scatter_2d(
      at::autocast::cached_cast(torch::kFloat32, input),
      at::autocast::cached_cast(torch::kFloat32, grid),
      output_height,
      output_width,
      padding_mode,
      interpolation_mode,
      align_corners);
}

#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(grid_scatter_ext, m) {}
#endif

TORCH_LIBRARY(grid_scatter_ext, m) {
  m.def(
      "grid_scatter_2d(Tensor input, Tensor grid, int output_height, int output_width, int padding_mode, int interpolation_mode, bool align_corners) -> Tensor");
}

TORCH_LIBRARY_IMPL(grid_scatter_ext, Autograd, m) {
  m.impl("grid_scatter_2d", &grid_scatter_2d_autograd);
}

TORCH_LIBRARY_IMPL(grid_scatter_ext, Autocast, m) {
  m.impl("grid_scatter_2d", grid_scatter_2d_autocast);
}

TORCH_LIBRARY_IMPL(grid_scatter_ext, CUDA, m) {
  m.impl("grid_scatter_2d", &grid_scatter_2d_cuda);
}