    grid_scatter · b0810efa
    Stanislav Pidhorskyi authored
    Summary:
    Adds a `grid_scatter` op that is similar to `grid_sample`, except that the grid points to the destination location instead of the source.
    
    `grid_scatter` is dual to `grid_sample`: the forward pass of `grid_scatter` is the backward pass of `grid_sample` (with respect to its input), and the backward pass of `grid_scatter` is the forward pass of `grid_sample` (except for the gradient with respect to `grid`). The reference implementation in `drtk/grid_scatter.py` reflects this.
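Because `grid_sample` is linear in its input, the duality above can be demonstrated with PyTorch autograd alone. The sketch below is a hypothetical, slow reference: `grid_scatter_via_autograd` is an illustrative name, not the DRTK API, and not necessarily how `drtk/grid_scatter.py` is written. It obtains the scatter as the backward pass of `torch.nn.functional.grid_sample` with respect to its input, then checks the adjoint identity numerically.

```python
import torch as th
import torch.nn.functional as thf


def grid_scatter_via_autograd(
    input: th.Tensor,
    grid: th.Tensor,
    output_height: int,
    output_width: int,
    mode: str = "bilinear",
    padding_mode: str = "border",
    align_corners: bool = False,
) -> th.Tensor:
    """Hypothetical reference: splat `input` [N, C, H, W] to the locations in
    `grid` [N, H, W, 2], producing [N, C, output_height, output_width]."""
    n, c = input.shape[:2]
    # grid_sample is linear in its first argument, so its backward pass with
    # respect to that argument is the adjoint (transpose) map. The values of
    # `target` are irrelevant; only the gradient flowing into it matters.
    target = th.zeros(
        n, c, output_height, output_width,
        dtype=input.dtype, device=input.device, requires_grad=True,
    )
    with th.enable_grad():
        sampled = thf.grid_sample(
            target, grid.detach(), mode=mode,
            padding_mode=padding_mode, align_corners=align_corners,
        )
        # Seed the backward pass with `input`: this is the scatter.
        (scattered,) = th.autograd.grad(sampled, target, grad_outputs=input)
    return scattered


# Adjoint check: <grid_sample(y, grid), x> == <y, grid_scatter(x, grid)>.
th.manual_seed(0)
x = th.randn(2, 3, 4, 5, dtype=th.float64)
grid = th.rand(2, 4, 5, 2, dtype=th.float64) * 2 - 1
y = th.randn(2, 3, 6, 7, dtype=th.float64)
scattered = grid_scatter_via_autograd(x, grid, 6, 7)
lhs = (thf.grid_sample(y, grid, mode="bilinear", padding_mode="border", align_corners=False) * x).sum()
rhs = (y * scattered).sum()
```

The check holds because scattering is the transpose of sampling with the same `mode`, `padding_mode`, and `align_corners`. Unlike a dedicated op, this sketch is not itself differentiable and builds an autograd graph on every call.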
    
    ```python
    def grid_scatter(
        input: th.Tensor,
        grid: th.Tensor,
        output_height: int,
        output_width: int,
        mode: str = "bilinear",
        padding_mode: str = "border",
        align_corners: Optional[bool] = None,
    ) -> th.Tensor:
    ```
    
    Where:
    * `input` [N x C x H x W]: the input tensor whose values are transferred to the result.
    * `grid` [N x H x W x 2]: the grid tensor that points to the locations where the values from the `input` tensor should be copied. The `H` and `W` sizes of `grid` should match the corresponding sizes of the `input` tensor.
    * `output_height`, `output_width`: the size of the output, which will be [N x C x `output_height` x `output_width`]. In contrast to `grid_sample`, we can no longer rely on the sizes of `grid` for this information.
    * `mode`, `padding_mode`, `align_corners`: same as for `grid_sample`, but applied to the reverse operation, splatting (or scattering).
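As a sketch of the shape contract described in the list above, the plain-Python helper below validates the input and grid shapes and returns the expected output shape. `grid_scatter_output_shape` is a hypothetical name, not part of DRTK; it only mirrors the rules stated here, and the actual op may perform additional checks.

```python
from typing import Tuple


def grid_scatter_output_shape(
    input_shape: Tuple[int, int, int, int],
    grid_shape: Tuple[int, int, int, int],
    output_height: int,
    output_width: int,
) -> Tuple[int, int, int, int]:
    """Hypothetical helper: check grid_scatter's shape rules and return the output shape."""
    n, c, h, w = input_shape
    grid_n, grid_h, grid_w, coords = grid_shape
    if grid_n != n:
        raise ValueError(f"batch size mismatch: input has {n}, grid has {grid_n}")
    if (grid_h, grid_w) != (h, w):
        # One grid entry per input pixel: each says where that pixel goes.
        raise ValueError(f"grid is {grid_h} x {grid_w}, input is {h} x {w}")
    if coords != 2:
        raise ValueError("the last grid dimension must hold (x, y) coordinates")
    # The output size comes from the arguments, not from the grid.
    return (n, c, output_height, output_width)


# Scatter a 4-channel 800 x 1200 camera image into a 512 x 512 UV texture.
out_shape = grid_scatter_output_shape((1, 4, 800, 1200), (1, 800, 1200, 2), 512, 512)

# A grid sized like the output (instead of like the input) is rejected.
try:
    grid_scatter_output_shape((1, 4, 800, 1200), (1, 512, 512, 2), 512, 512)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```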
    
    The "nearest" mode is not supported at the moment, since it is rarely needed; it may be added later.
    
    Ideally, we would also support autocast mode, where the `input` and output tensors are float16 while `grid` is float32. This is not supported yet, but I'll add it later.
    
    ## Example usage
    
    Let's assume that we have loaded a mesh into `v, vi, vt, vti`, defined `image_width, image_height`, `cam_pos`, `cam_rot`, `focal`, `princpt`, and computed the mesh normals `normals`. We also define a shading function, e.g.:
    
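    ```lang=python
    from typing import Optional
    
    import torch as th
    import torch.nn.functional as thf
    
    
    def shade(
        vn_img: th.Tensor,
        light_dir: th.Tensor,
        ambient_intensity: float = 1.0,
        direct_intensity: float = 1.0,
        shadow_img: Optional[th.Tensor] = None,
    ) -> th.Tensor:
        # Hemispherical ambient term, blended by the y component of the normal.
        sky_color = th.as_tensor([0.45, 0.5, 0.7], device=vn_img.device)
        ambient = (vn_img[:, 1:2] * 0.5 + 0.5) * sky_color[None, :, None, None]
        # Lambertian direct term: clamped cosine between normal and light direction.
        sun_color = th.as_tensor([0.65, 0.6, 0.5], device=vn_img.device)
        direct = (
            th.sum(vn_img * thf.normalize(light_dir, dim=1), dim=1, keepdim=True).clamp(min=0.0)
            * sun_color[None, :, None, None]
        )
        if shadow_img is not None:
            direct = direct * shadow_img
        # Gamma-encode the result for display.
        return th.pow(ambient * ambient_intensity + direct * direct_intensity, 1 / 2.2)
    ```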
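To make the shading model concrete, the sketch below evaluates the same formula as `shade` for a single pixel in plain Python: a hemispherical ambient term driven by the y component of the normal, a clamped Lambertian term, an optional shadow factor, and a 1/2.2 gamma. `shade_pixel` is a hypothetical name, and like `shade` it assumes the normal is already unit length (only the light direction is normalized).

```python
import math
from typing import Sequence, Tuple

SKY_COLOR = (0.45, 0.5, 0.7)  # ambient color used by `shade`
SUN_COLOR = (0.65, 0.6, 0.5)  # direct-light color used by `shade`


def shade_pixel(
    normal: Sequence[float],
    light_dir: Sequence[float],
    ambient_intensity: float = 1.0,
    direct_intensity: float = 1.0,
    shadow: float = 1.0,
) -> Tuple[float, ...]:
    """Hypothetical per-pixel version of `shade` for one unit normal."""
    # Ambient: 0 for a normal pointing down (-y), 1 for a normal pointing up (+y).
    up = normal[1] * 0.5 + 0.5
    # Direct: cosine between the normal and the normalized light direction, clamped at 0.
    length = math.sqrt(sum(c * c for c in light_dir))
    cos_theta = max(0.0, sum(n * l for n, l in zip(normal, light_dir)) / length)
    return tuple(
        (up * sky * ambient_intensity + cos_theta * shadow * sun * direct_intensity) ** (1 / 2.2)
        for sky, sun in zip(SKY_COLOR, SUN_COLOR)
    )


lit = shade_pixel((0.0, 1.0, 0.0), (0.0, 2.0, 0.0))  # facing the light
shadowed = shade_pixel((0.0, 1.0, 0.0), (0.0, 2.0, 0.0), shadow=0.0)  # ambient only
facing_down = shade_pixel((0.0, -1.0, 0.0), (0.0, 2.0, 0.0))  # no light at all
```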
    
    We can then render the image as follows:
    
    ```lang=python
    v_pix = transform(v, cam_pos, cam_rot, focal, princpt)
    index_img = rasterize(v_pix, vi, image_height, image_width)
    _, bary_img = render(v_pix, vi, index_img)
    
    # mask image
    mask: th.Tensor = (index_img != -1)[:, None]
    
    # compute vt image (UVs mapped from [0, 1] to the [-1, 1] grid range)
    vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)
    
    # compute normals
    vn_img = interpolate(normals, vi, index_img, bary_img)
    
    diffuse = (
        shade(vn_img, th.as_tensor([0.5, 0.5, 0.0]).cuda()[None, :, None, None]) * mask
    )
    ```
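`vt.mul(2.0).sub(1.0)` maps UV coordinates from [0, 1] to the normalized [-1, 1] range that `grid_sample` expects (and, per the parameter description above, `grid_scatter` too). The sketch below, with hypothetical helper names, converts a single UV coordinate into the continuous texel coordinate it addresses, following PyTorch's `align_corners` convention. For `grid_scatter` this determines which texels a pixel's value lands on.

```python
def uv_to_grid(u: float) -> float:
    """Map a UV coordinate in [0, 1] to the normalized [-1, 1] range (vt.mul(2.0).sub(1.0))."""
    return u * 2.0 - 1.0


def grid_to_pixel(g: float, size: int, align_corners: bool) -> float:
    """Continuous pixel coordinate addressed by normalized coordinate g (PyTorch convention)."""
    if align_corners:
        # -1 and +1 are the centers of the first and last pixels.
        return (g + 1.0) / 2.0 * (size - 1)
    # -1 and +1 are the outer edges of the first and last pixels.
    return ((g + 1.0) * size - 1.0) / 2.0


texture_size = 512
# For each UV value: (pixel coordinate with align_corners=False, with align_corners=True).
samples = {
    u: (
        grid_to_pixel(uv_to_grid(u), texture_size, align_corners=False),
        grid_to_pixel(uv_to_grid(u), texture_size, align_corners=True),
    )
    for u in (0.0, 0.5, 1.0)
}
```

With `align_corners=False`, as used throughout these examples, UV 0 and 1 map to the outer edges of the texture (-0.5 and 511.5) rather than to texel centers.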
    
     {F1801805545}
    
    ## Shadow mapping
    
    We can use `grid_scatter` to compute mesh visibility from the camera view:
    
    ```lang=python
    texel_weight = grid_scatter(
        mask.float(),
        vt_img.permute(0, 2, 3, 1),
        output_width=512,
        output_height=512,
        mode="bilinear",
        padding_mode="border",
        align_corners=False,
    )
    # texel_weight is proportional to how much pixel area the texel covers. The threshold
    # specifies how much covered pixel area counts as visible.
    threshold = 0.1
    visibility = (texel_weight > threshold).float()
    ```
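Why `texel_weight` measures coverage: with `mode="bilinear"`, each pixel's value (here the mask, 1.0 for covered pixels) is split among the four texels surrounding the continuous location its UV coordinate points to, with weights that sum to one, so a texel accumulates the fractional pixel area landing on it. The plain-Python sketch below computes those splatting weights for a single point. It assumes the standard bilinear weights and `border` clamping of `grid_sample`, mirrored for scattering; `bilinear_splat_weights` is a hypothetical name, and the CUDA kernel may differ in details.

```python
import math
from typing import Dict, Tuple


def bilinear_splat_weights(
    x: float, y: float, width: int, height: int
) -> Dict[Tuple[int, int], float]:
    """Hypothetical sketch: how a unit value at continuous pixel coordinate (x, y)
    is distributed over texels, keyed by (row, column)."""
    # "border" padding: clamp the location into the valid texel range first.
    x = min(max(x, 0.0), width - 1.0)
    y = min(max(y, 0.0), height - 1.0)
    x0, y0 = math.floor(x), math.floor(y)
    fx, fy = x - x0, y - y0
    weights: Dict[Tuple[int, int], float] = {}
    for dy, wy in ((0, 1.0 - fy), (1, fy)):
        for dx, wx in ((0, 1.0 - fx), (1, fx)):
            w = wx * wy
            # Skip zero weights; only those can fall past the last row/column.
            if w > 0.0:
                weights[(y0 + dy, x0 + dx)] = w
    return weights


weights = bilinear_splat_weights(1.25, 2.5, width=4, height=4)
total = sum(weights.values())
edge = bilinear_splat_weights(10.0, -3.0, width=4, height=4)  # clamped to the corner texel
```

Summing such weights over all covered pixels gives `texel_weight`; a texel hit by only a sliver of pixels gets a small weight, which is what the threshold filters out.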
     {F1801810094}
    
    Now we can render the scene from a different angle and use the visibility mask for shadows:
    
    ```lang=python
    v_pix = transform(v, cam_pos_new, cam_rot_new, focal, princpt)
    index_img = rasterize(v_pix, vi, image_height, image_width)
    _, bary_img = render(v_pix, vi, index_img)
    
    # mask image
    mask: th.Tensor = (index_img != -1)[:, None]
    
    # compute vt image (UVs mapped from [0, 1] to the [-1, 1] grid range)
    vt_img = interpolate(vt.mul(2.0).sub(1.0)[None], vti, index_img, bary_img)
    
    # compute v image (for near-field)
    v_img = interpolate(v, vi, index_img, bary_img)
    
    # shadow: look up the visibility map at each pixel's UV coordinate
    shadow_img = thf.grid_sample(
        visibility,
        vt_img.permute(0, 2, 3, 1),
        mode="bilinear",
        padding_mode="border",
        align_corners=False,
    )
    
    # compute normals
    vn_img = interpolate(normals, vi, index_img, bary_img)
    
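    # point light at the original camera position: light_dir points from each surface point towards cam_pos
    diffuse = shade(vn_img, cam_pos[:, :, None, None] - v_img, 0.05, 0.4, shadow_img) * mask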
    ```
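The `shadow_img` lookup interpolates the binary visibility map bilinearly, so a pixel whose UV falls between a visible and an occluded texel receives a fractional shadow value, which softens the shadow boundary. A minimal, self-contained illustration with a hypothetical two-texel visibility map:

```python
import torch as th
import torch.nn.functional as thf

# A 1 x 2 visibility map: left texel occluded (0), right texel visible (1).
visibility = th.tensor([[[[0.0, 1.0]]]])  # [N=1, C=1, H=1, W=2]

# Three query points along x; y = 0 is the vertical center of the map.
# With align_corners=False, x = -0.5 and x = 0.5 hit the two texel centers
# and x = 0.0 lands halfway between them.
grid = th.tensor([[[[-0.5, 0.0], [0.0, 0.0], [0.5, 0.0]]]])  # [N=1, H=1, W=3, 2]

shadow = thf.grid_sample(
    visibility, grid, mode="bilinear", padding_mode="border", align_corners=False
)
```

The middle query returns 0.5, a half-shadowed pixel, while the outer two return the texel values 0.0 and 1.0.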
     {F1801811232}
    
    ## Texture projection
    
    Let's load a test image:
    
    ```lang=python
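    import skimage.data
    
    # skimage's "coffee" sample image is 400 x 600 (H x W) RGB uint8.
    test_image = (
        th.as_tensor(skimage.data.coffee(), dtype=th.float32)
        .permute(2, 0, 1)[None, ...]  # H x W x C -> 1 x C x H x W
        .mul(1 / 255)
        .contiguous()
        .cuda()
    )
    
    # Upsample to 800 x 1200; this must match image_height x image_width for the `* mask` below.
    test_image = thf.interpolate(test_image, scale_factor=2.0, mode="bilinear", align_corners=False)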
    ```
    
    {F1801814094}
    
    We can use `grid_scatter` to project the image into UV space:
    
    ```lang=python
    camera_image_extended = (
        th.cat([test_image, th.ones_like(test_image[:, :1])], dim=1) * mask
    )
    
    texture_weight = grid_scatter(
        camera_image_extended,
        vt_img.permute(0, 2, 3, 1),
        output_width=512,
        output_height=512,
        mode="bilinear",
        padding_mode="border",
        align_corners=False,
    )
    
    # divide the accumulated colors by the accumulated weights (the channel of ones)
    texture = texture_weight[:, :3] / texture_weight[:, 3:4].clamp(min=1e-4)
    ```
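Appending a channel of ones turns the scatter into a weighted average: after `grid_scatter`, channel 3 holds the sum of the bilinear weights that landed on each texel and the color channels hold the weighted sum of colors, so dividing them removes the dependence on how many pixels hit a texel. The plain-Python sketch below (hypothetical `normalize_splat`) does this for a single texel, given the (color, weight) contributions that reached it.

```python
from typing import List, Tuple


def normalize_splat(contributions: List[Tuple[float, float]], eps: float = 1e-4) -> float:
    """Hypothetical single-texel version of texture_weight[:, :3] / texture_weight[:, 3:4].

    Each entry is (color, weight): a camera pixel's color and the bilinear weight
    with which it landed on this texel.
    """
    color_sum = sum(weight * color for color, weight in contributions)  # color channels
    weight_sum = sum(weight for _, weight in contributions)  # the channel of ones
    # Same guard as .clamp(min=1e-4): texels that received nothing stay at 0.
    return color_sum / max(weight_sum, eps)


texel = normalize_splat([(0.2, 0.25), (0.6, 0.75)])  # two pixels overlap this texel
sliver = normalize_splat([(0.8, 0.5)])  # a partially covered texel keeps its full color
empty = normalize_splat([])
```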
    
    {F1801816367}
    
    Rendering the scene from a different angle with the projected texture gives:
    
     {F1801817130}
    
    Reviewed By: HapeMask
    
    Differential Revision: D61006613
    
    fbshipit-source-id: 98c83ba4eda531e9d73cb9e533176286dc699f63