grid_scatter
Summary:
Adds a `grid_scatter` op that is similar to `grid_sample`, but the grid points to the destination location instead of the source.
`grid_scatter` is dual to `grid_sample`: the forward pass of `grid_scatter` is the backward pass of `grid_sample`, and the backward pass of `grid_scatter` is the forward pass of `grid_sample` (with the exception of the gradient with respect to `grid`). This duality is reflected in the reference implementation in `drtk/grid_scatter.py`.
```lang=python
from typing import Optional

import torch as th


def grid_scatter(
    input: th.Tensor,
    grid: th.Tensor,
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: Optional[bool] = None,
) -> th.Tensor:
```
Where:
* `input` [N x C x H x W]: the input tensor whose values will be transferred to the result.
* `grid` [N x H x W x 2]: the grid tensor that points to the locations where the values from the input tensor should be copied to. The `H` and `W` sizes of `grid` should match the corresponding sizes of the `input` tensor.
* `output_height`, `output_width`: the size of the output, which will be [N x C x `output_height` x `output_width`]. In contrast to `grid_sample`, the output size can no longer be inferred from the size of `grid`, so it must be given explicitly.
* `mode`, `padding_mode`, `align_corners`: same as for `grid_sample`, but now applied to the reverse operation, splatting (or scattering).
At the moment "nearest" mode is not supported, since it is rarely needed; it may be added later.
Ideally, we would also support autocast, where `input` and the output tensor are float16 while `grid` is float32. This is not the case at the moment; I'll add that later.
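To illustrate the duality noted above, here is a minimal pure-PyTorch sketch (not the DRTK implementation; the name `grid_scatter_ref` and the use of autograd are only illustrative): the forward pass of `grid_scatter` can be obtained as the vector-Jacobian product of `grid_sample` with respect to a zero source image of the requested output size.
```lang=python
import torch as th
import torch.nn.functional as thf


def grid_scatter_ref(
    input: th.Tensor,  # [N x C x H x W]
    grid: th.Tensor,  # [N x H x W x 2]
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: bool = False,
) -> th.Tensor:
    # Zero "source" image of the desired output size; only its gradient is needed.
    n, c = input.shape[:2]
    source = th.zeros(
        n, c, output_height, output_width,
        dtype=input.dtype, device=input.device, requires_grad=True,
    )
    # grid_sample gathers from `source` at `grid`; its vector-Jacobian product
    # with `input` scatters `input` into the source layout.
    sampled = thf.grid_sample(
        source, grid, mode=mode, padding_mode=padding_mode, align_corners=align_corners
    )
    (grad_source,) = th.autograd.grad(sampled, source, grad_outputs=input)
    return grad_source
```
Because bilinear `grid_sample` is linear in its input, the gradient does not depend on the zero-valued source, so this yields the scattered result directly. This is only a sketch for intuition; see `drtk/grid_scatter.py` for the actual reference implementation.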
## Example usage
Let's assume that we have loaded a mesh into `v, vi, vt, vti`, defined `image_width`, `image_height`, `cam_pos`, `cam_rot`, `focal`, `princpt`, and computed per-vertex normals `normals` for the mesh. We also define a shading function, e.g.:
```lang=python
def shade(
    vn_img: th.Tensor,
    light_dir: th.Tensor,
    ambient_intensity: float = 1.0,
    direct_intensity: float = 1.0,
    shadow_img: Optional[th.Tensor] = None,
):
    ambient = (vn_img[:, 1:2] * 0.5 + 0.5) * th.as_tensor([0.45, 0.5, 0.7]).cuda()[
        None, :, None, None
    ]
    direct = (
        th.sum(vn_img.mul(thf.normalize(light_dir, dim=1)), dim=1, keepdim=True).clamp(
            min=0.0
        )
        * th.as_tensor([0.65, 0.6, 0.5]).cuda()[None, :, None, None]
    )
    if shadow_img is not None:
        direct = direct * shadow_img
    return th.pow(ambient * ambient_intensity + direct * direct_intensity, 1 / 2.2)
```
And we can render the image as:
```lang=python
# transform, rasterize, render, interpolate are drtk ops
v_pix = transform(v, cam_pos, cam_rot, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)
# mask image
mask: th.Tensor = (index_img != -1)[:, None]
# compute vt image
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)
# compute normals
vn_img = interpolate(normals, vi, index_img, bary_img)
diffuse = (
    shade(vn_img, th.as_tensor([0.5, 0.5, 0.0]).cuda()[None, :, None, None]) * mask
)
```
{F1801805545}
## Shadow mapping
We can use `grid_scatter` to compute mesh visibility from the camera view:
```lang=python
texel_weight = grid_scatter(
    mask.float(),
    vt_img.permute(0, 2, 3, 1),
    output_width=512,
    output_height=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)
# texel_weight is proportional to how much pixel area each texel covers.
# We can pick a threshold for how much covered pixel area counts as visible.
threshold = 0.1
visibility = (texel_weight > threshold).float()
```
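As a quick sanity check, the op should agree with the autograd-based sketch from above (the hypothetical `grid_scatter_ref`) up to floating-point accumulation order:
```lang=python
texel_weight_ref = grid_scatter_ref(
    mask.float(),
    vt_img.permute(0, 2, 3, 1),
    output_height=512,
    output_width=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)
assert th.allclose(texel_weight, texel_weight_ref, atol=1e-4)
```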
{F1801810094}
Now we can render the scene from a different angle and use the visibility mask for shadows:
```lang=python
v_pix = transform(v, cam_pos_new, cam_rot_new, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)
# mask image
mask: th.Tensor = (index_img != -1)[:, None]
# compute vt image
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)
# compute v image (for near-field)
v_img = interpolate(v, vi, index_img, bary_img)
# shadow
shadow_img = thf.grid_sample(
    visibility,
    vt_img.permute(0, 2, 3, 1),
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)
# compute normals
vn_img = interpolate(normals, vi, index_img, bary_img)
diffuse = (
    shade(vn_img, cam_pos[:, :, None, None] - v_img, 0.05, 0.4, shadow_img) * mask
)
```
{F1801811232}
## Texture projection
Let's load a test image:
```lang=python
import skimage

# the upsampled test image is assumed to match the render size (image_height x image_width)
test_image = (
    th.as_tensor(skimage.data.coffee(), dtype=th.float32)
    .permute(2, 0, 1)[None, ...].mul(1 / 255).contiguous().cuda()
)
test_image = thf.interpolate(
    test_image, scale_factor=2.0, mode="bilinear", align_corners=False
)
```
{F1801814094}
We can use `grid_scatter` to project the image into uv space:
```lang=python
# Append a ones channel so that the scattered weights can normalize the
# accumulated colors (a coverage-weighted average).
camera_image_extended = (
    th.cat([test_image, th.ones_like(test_image[:, :1])], dim=1) * mask
)
texture_weight = grid_scatter(
    camera_image_extended,
    vt_img.permute(0, 2, 3, 1),
    output_width=512,
    output_height=512,
    mode="bilinear",
    padding_mode="border",
    align_corners=False,
)
texture = texture_weight[:, :3] / texture_weight[:, 3:4].clamp(min=1e-4)
```
{F1801816367}
Finally, we can render the scene from a different angle using the projected texture:
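A minimal sketch of this step (the camera pose names `cam_pos_new2`, `cam_rot_new2` are placeholders for whatever novel view is used; the projected texture is sampled back into screen space with `grid_sample`):
```lang=python
v_pix = transform(v, cam_pos_new2, cam_rot_new2, focal, princpt)
index_img = rasterize(v_pix, vi, image_height, image_width)
_, bary_img = render(v_pix, vi, index_img)
mask = (index_img != -1)[:, None]
vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)
# sample the projected uv-space texture for each visible pixel
textured = (
    thf.grid_sample(
        texture,
        vt_img.permute(0, 2, 3, 1),
        mode="bilinear",
        padding_mode="border",
        align_corners=False,
    )
    * mask
)
```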
{F1801817130}
Reviewed By: HapeMask
Differential Revision: D61006613
fbshipit-source-id: 98c83ba4eda531e9d73cb9e533176286dc699f63