Commit ecc1df99 authored by facebook-github-bot's avatar facebook-github-bot
Browse files

Initial commit

fbshipit-source-id: afc575e8e7d8e2796a3f77d8b1c6c4fcb999558d
parents
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence, Tuple
import torch as th
from drtk.utils import project_points
def transform(
v: th.Tensor,
campos: Optional[th.Tensor] = None,
camrot: Optional[th.Tensor] = None,
focal: Optional[th.Tensor] = None,
princpt: Optional[th.Tensor] = None,
K: Optional[th.Tensor] = None,
Rt: Optional[th.Tensor] = None,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""
v: Tensor, N x V x 3
Batch of vertex positions for vertices in the mesh.
campos: Tensor, N x 3
Camera position.
camrot: Tensor, N x 3 x 3
Camera rotation matrix.
focal: Tensor, N x 2 x 2
Focal length [[fx, 0],
[0, fy]]
princpt: Tensor, N x 2
Principal point [cx, cy]
K: Tensor, N x 3 x 3
Camera intrinsic calibration matrix. Either this or both (focal,
princpt) must be provided.
Rt: Tensor, N x 3 x 4 or N x 4 x 4
Camera extrinsic matrix. Either this or both (camrot, campos) must be
provided. Camrot is the upper 3x3 of Rt, campos = -R.T @ t.
distortion_mode: List[str]
Names of the distortion modes.
distortion_coeff: Tensor, N x 4
Distortion coefficients.
fov: Tensor, N x 1
Valid field of view of the distortion model.
"""
v_pix, _ = transform_with_v_cam(
v, campos, camrot, focal, princpt, K, Rt, distortion_mode, distortion_coeff, fov
)
return v_pix
def transform_with_v_cam(
v: th.Tensor,
campos: Optional[th.Tensor] = None,
camrot: Optional[th.Tensor] = None,
focal: Optional[th.Tensor] = None,
princpt: Optional[th.Tensor] = None,
K: Optional[th.Tensor] = None,
Rt: Optional[th.Tensor] = None,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
fov: Optional[th.Tensor] = None,
) -> Tuple[th.Tensor, th.Tensor]:
"""
Same as transform, but also returns the camera-space coordinates.
In most cases it is not needed, but renderlayer depends on it
"""
if not ((camrot is not None and campos is not None) ^ (Rt is not None)):
raise ValueError("You must provide exactly one of Rt or (campos, camrot).")
if not ((focal is not None and princpt is not None) ^ (K is not None)):
raise ValueError("You must provide exactly one of K or (focal, princpt).")
if campos is None:
assert Rt is not None
camrot = Rt[:, :3, :3]
campos = -(camrot.transpose(-2, -1) @ Rt[:, :3, 3:4])[..., 0]
if focal is None:
assert K is not None
focal = K[:, :2, :2]
princpt = K[:, :2, 2]
assert camrot is not None
assert princpt is not None
# Compute camera-space 3D coordinates and 2D pixel-space projections.
v_pix, v_cam = project_points(
v, campos, camrot, focal, princpt, distortion_mode, distortion_coeff, fov
)
return v_pix, v_cam
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from drtk.utils.geometry import ( # noqa
face_dpdt,
face_info,
vert_binormals,
vert_normals,
)
from drtk.utils.indexing import index # noqa
from drtk.utils.projection import ( # noqa
DISTORTION_MODES, # noqa
project_points, # noqa
project_points_grad, # noqa
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple, Union
import torch as th
import torch.nn.functional as thf
from drtk.utils.indexing import index
from torch import Tensor
eps = 1e-8
def face_dpdt(
v: th.Tensor, vt: th.Tensor, vi: th.Tensor, vti: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
"""
This function calculates the transposed Jacobian matrix (∂p/∂t)^T for each triangle.
Where:
- p represents the 3D coordinates of a point on the plane of the triangle,
- t denotes the UV coordinates assiciated with the point.
Args:
v: vertex position tensor
N x V x 3
vt: vertex uv tensor
N x T x 2
vi: face vertex position index list tensor
F x 3
vti: face vertex uv index list tensor
F x 3
Jacobian is computed as:
∂p/∂t = ∂p / ∂b * (∂t / ∂b)^-1
Where b - barycentric coordinates
However the implementation computes a transposed Jacobian (purely from
practical perspective - fewer permutations are needed), so the above
becomes:
(∂p/∂t)^T = ((∂t / ∂b)^T)^-1 * (∂p / ∂b)^T
Returns:
dpdt - transposed Jacobian (∂p/∂t)^T. Shape: N x F x 2 x 3
Where ∂p∂t[..., i, j] = ∂p[..., j] / ∂t[..., i]
v012 - vertex positions per triangle. Shape: N x F x 3
where: N - batch size; F - number of triangles
"""
v012 = v[:, vi]
vt012 = vt[:, vti]
dpdb_t = v012[:, :, 1:3] - v012[:, :, 0:1]
dtdb_t = vt012[:, :, 1:3] - vt012[:, :, 0:1]
# (db / dt)^T = ((dt / db)^T)^-1
dbdt_t = th.inverse(dtdb_t)
# (dp / dt)^T = (db / dt)^T) * (dp / db)^T
dpdt_t = dbdt_t @ dpdb_t
return dpdt_t, v012
def face_attribute_to_vert(v: th.Tensor, vi: th.Tensor, attr: th.Tensor) -> Tensor:
"""
For each vertex, computes a summation of the face attributes to which the
vertex belongs.
"""
attr = (
attr[:, :, None]
.expand(-1, -1, 3, -1)
.reshape(attr.shape[0], -1, attr.shape[-1])
)
vi_flat = vi.view(vi.shape[0], -1).expand(v.shape[0], -1)
vattr = th.zeros(v.shape[:-1], dtype=v.dtype, device=v.device)
vattr = th.stack(
[vattr.scatter_add(1, vi_flat, attr[..., i]) for i in range(attr.shape[-1])],
dim=-1,
)
return vattr
def face_info(
v: th.Tensor, vi: th.Tensor, to_compute: Optional[List[str]] = None
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
"""Given a set of vertices ``v`` and indices ``vi`` indexing into ``v``
defining a set of faces, compute face information (normals, edges, face
areas) for each face.
Args:
v: Vertex positions, shape [batch_size, n_vertices, 3]
vi: Vertex indices, shape [n_faces, 3]
to_compute: list of desired information. Any of: {normals, edges, areas}, defaults to all.
Returns:
Dict: Face information in the following format::
{
"normals": shape [batch_size, n_faces, 3]
"edges": shape [batch_size, n_faces, 3, 3]
"areas": shape [batch_size, n_faces, 1]
}
or just one of the above values not in a Dict if only one is
requested.
"""
if to_compute is None:
to_compute = ["normals", "edges", "areas"]
b = v.shape[0]
vi = vi.expand(b, -1, -1)
p0 = th.stack([index(v[i], vi[i, :, 0], 0) for i in range(b)])
p1 = th.stack([index(v[i], vi[i, :, 1], 0) for i in range(b)])
p2 = th.stack([index(v[i], vi[i, :, 2], 0) for i in range(b)])
v0 = p1 - p0
v1 = p0 - p2
need_normals = "normals" in to_compute
need_areas = "areas" in to_compute
need_edges = "edges" in to_compute
output = {}
if need_normals or need_areas:
normals = th.cross(v1, v0, dim=-1)
norm = th.linalg.vector_norm(normals, dim=-1, keepdim=True)
if need_areas:
output["areas"] = 0.5 * norm
if need_normals:
output["normals"] = normals / norm.clamp(min=eps)
if need_edges:
v2 = p2 - p1
output["edges"] = th.stack([v0, v1, v2], dim=2)
if len(to_compute) == 1:
return output[to_compute[0]]
else:
return output
def vert_binormals(v: Tensor, vt: Tensor, vi: Tensor, vti: Tensor) -> Tensor:
# Compute (dp/dt)^T
dpdt_t, vf = face_dpdt(v, vt, vi, vti)
# Take the dp/dt.u part. Produces u vector in 3D world-space which we use for binormal vector
fbnorms = dpdt_t[:, :, 0, :]
vbnorms = face_attribute_to_vert(v, vi, fbnorms)
return thf.normalize(vbnorms, dim=-1)
def vert_normals(
v: th.Tensor, vi: th.Tensor, fnorms: Optional[th.Tensor] = None
) -> th.Tensor:
"""Given a set of vertices ``v`` and indices ``vi`` indexing into ``v``
defining a set of faces, compute normals for each vertex by averaging the
face normals for each face which includes that vertex.
Args:
v: Vertex positions, shape [batch_size, n_vertices, 3]
vi: Vertex indices, shape [batch_size, n_faces, 3]
fnorms: Face normals. Optional, provide them if available, otherwise they will be computed
from `v` and `vi`. Shape [n_faces, 3]
Returns:
th.Tensor: Vertex normals, shape [batch_size, n_vertices, 3]
"""
if fnorms is None:
fnorms = face_info(v, vi, ["normals"])
assert isinstance(fnorms, th.Tensor)
vnorms = face_attribute_to_vert(v, vi, fnorms)
return thf.normalize(vnorms, dim=-1)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch as th
def index(x: th.Tensor, idxs: th.Tensor, dim: int) -> th.Tensor:
"""Index a tensor along a given dimension using an index tensor, replacing
the shape along the given dimension with the shape of the index tensor.
Example:
x: [8, 7306, 3]
idxs: [11000, 3]
y = index(x, idxs, dim=1) -> y: [8, 11000, 3, 3]
with each y[b, i, j, k] = x[b, idxs[i, j], k]
"""
target_shape = [*x.shape]
del target_shape[dim]
target_shape[dim:dim] = [*idxs.shape]
return x.index_select(dim, idxs.view(-1)).reshape(target_shape)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence, Set, Tuple, Union
import numpy as np
import torch as th
DISTORTION_MODES: Set[Optional[str]] = {None, "radial-tangential", "fisheye"}
def project_pinhole(
v_cam: th.Tensor, focal: th.Tensor, princpt: th.Tensor
) -> th.Tensor:
"""Project camera-space points to pixel-space points with camera
intrinsics.
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
"""
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, 0:2] / z
v_pix = (focal[:, None] @ v_proj[..., None])[..., 0] + princpt[:, None]
return v_pix
def project_pinhole_distort_rt(
v_cam: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
D: th.Tensor,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Project camera-space points to distorted pixel-space using the radial
and tangential model (4 parameters).
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
D: N x 4
fov: N x 1
"""
# See https://docs.opencv.org/2.4/doc/tutorials/calib3d/camera_calibration/camera_calibration.html
if fov is None:
with th.no_grad():
fov = estimate_rt_fov(D)
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, :2] / z
r2 = v_proj.pow(2).sum(-1)
# Clamp x, y and r to avoid wrapping behavior of the distortion model.
r2 = r2.clamp(max=fov.pow(2))
v_clamped = v_proj.clamp(min=-fov[..., None], max=fov[..., None])
assert D.shape[1] in [4, 5, 8]
# 4 param: R = (1 + k1 r^2 + k2 r^4)
R = 1 + D[:, 0:1] * r2 + D[:, 1:2] * r2.pow(2)
# 5 param: R = (1 + k1 r^2 + k2 r^4 + k3 r^6)
if D.shape[1] == 5:
R = R + D[:, 4:5] * r2.pow(3)
# 8 param: R = (1 + k1 r^2 + k2 r^4 + k3 r^6) / (1 + k4 r^2 + k5 r^4 + k6 r^6)
if D.shape[1] == 8:
R = R + D[:, 4:5] * r2.pow(3)
R = R / (1 + D[:, 5:6] * r2 + D[:, 6:7] * r2.pow(2) + D[:, 7:8] * r2.pow(3))
# [x' y'] * R
v_proj_dist = v_proj * R[..., None]
# [2 p1 x' y', 2 p2 x' y']
v_proj_dist += (
2
* v_clamped[..., 0:1]
* v_clamped[..., 1:2]
* th.stack((D[:, 2:3], D[:, 3:4]), dim=-1)
)
# [p2 r^2, p1 r^2]
v_proj_dist += r2[..., None] * th.stack((D[:, 3:4], D[:, 2:3]), dim=-1)
# [2 p2 x'^2, 2 p1 y'^2]
v_proj_dist += th.stack(
(
2 * D[:, 3:4] * v_clamped[..., 0].pow(2),
2 * D[:, 2:3] * v_clamped[..., 1].pow(2),
),
dim=-1,
)
v_pix_dist = (focal[:, None] @ v_proj_dist[..., None])[..., 0] + princpt[:, None]
return v_pix_dist
def project_fisheye_distort(
v_cam: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
D: th.Tensor,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Project camera-space points to distort pixel-space points using the
fisheye distortion model.
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
D: N x 4
fov: N x 1
"""
# See https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fisheye.cpp
if fov is None:
with th.no_grad():
fov = estimate_fisheye_fov(D)
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, :2] / z
r = v_proj.pow(2).sum(-1).sqrt()
r = r.clamp(max=fov, min=1e-8 * th.ones_like(fov))
theta = th.atan(r)
theta_d = theta * (
1
+ D[:, 0:1] * theta.pow(2)
+ D[:, 1:2] * theta.pow(4)
+ D[:, 2:3] * theta.pow(6)
+ D[:, 3:4] * theta.pow(8)
)
r = th.where(r < 0, r.clamp(max=-1e-8), r.clamp(min=1e-8))
v_proj_dist = v_proj * (theta_d / r)[..., None]
v_pix_dist = (focal[:, None] @ v_proj_dist[..., None])[..., 0] + princpt[:, None]
return v_pix_dist
def project_fisheye_distort_62(
v_cam: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
D: th.Tensor,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Project camera-space points to distort pixel-space points using the
OculusVisionFishEye62 distortion model.
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
D: N x 4
fov: N x 1
"""
# See https://www.internalfb.com/code/fbsource/[188bdaeaad64]/arvr/projects/nimble/prod/pynimble/visualization/shaders.py?lines=103-123
# a more readible version: https://euratom-software.github.io/calcam/html/intro_theory.html
if fov is None:
with th.no_grad():
fov = estimate_fisheye_fov(D)
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, :2] / z
r = v_proj.pow(2).sum(-1).sqrt() # rp
r = r.clamp(max=fov, min=1e-8 * th.ones_like(fov))
theta = th.atan(r)
theta_d = theta * (
1
+ D[:, 0:1] * theta.pow(2)
+ D[:, 1:2] * theta.pow(4)
+ D[:, 2:3] * theta.pow(6)
+ D[:, 3:4] * theta.pow(8)
+ D[:, 4:5] * theta.pow(10)
+ D[:, 5:6] * theta.pow(12)
)
r = th.where(r < 0, r.clamp(max=-1e-8), r.clamp(min=1e-8))
v_proj_dist = v_proj * (theta_d / r)[..., None]
# Tangential Distortion
x = v_proj_dist[:, :, 0]
y = v_proj_dist[:, :, 1]
xtan = D[:, 6:7] * (r.pow(2) + 2 * x.pow(2)) + 2 * D[:, 7:8] * x * y
ytan = 2 * D[:, 6:7] * x * y + D[:, 7:8] * (r.pow(2) + 2 * y.pow(2))
pTangential = th.cat([xtan[..., None], ytan[..., None]], dim=-1)
v_proj_dist = v_proj_dist + pTangential
v_pix_dist = (focal[:, None] @ v_proj_dist[..., None])[..., 0] + princpt[:, None]
return v_pix_dist
def estimate_rt_fov(D: Union[np.ndarray, th.Tensor]) -> th.Tensor:
"""Estimate the maximum field of view based on the assumption that the 5th order
polynomial for fish-eye effect is non-decreasing.
D: N x 4
"""
if th.is_tensor(D):
coefs = D.cpu().numpy()
else:
coefs = D
ones = np.ones_like(coefs[:, 0])
zeros = np.zeros_like(coefs[:, 0])
coefs = np.stack(
[
5 * coefs[:, 1],
zeros,
3 * coefs[:, 0],
zeros,
ones,
],
axis=-1,
)
fov = []
for coef in coefs:
roots = np.roots(coef)
real_valued = roots.real[abs(roots.imag) < 1e-5]
positive_roots = real_valued[real_valued > 0]
if len(positive_roots) == 0:
fov.append(np.inf)
else:
fov.append(positive_roots.min())
fov = np.asarray(fov, dtype=np.float32)[..., None]
if th.is_tensor(D):
fov = th.from_numpy(fov).to(D)
return fov
def estimate_fisheye_fov(D: Union[np.ndarray, th.Tensor]) -> th.Tensor:
"""Estimate the maximum field of view based on the assumption that the 9th order
polynomial is non-decreasing.
D: N x 4
"""
if th.is_tensor(D):
coefs = D.cpu().numpy()
else:
coefs = D
ones = np.ones_like(coefs[:, 0])
zeros = np.zeros_like(coefs[:, 0])
coefs = np.stack(
[
9 * coefs[:, -1],
zeros,
7 * coefs[:, -2],
zeros,
5 * coefs[:, -3],
zeros,
3 * coefs[:, -4],
zeros,
ones,
],
axis=-1,
)
fov = []
for coef in coefs:
roots = np.roots(coef)
real_valued = roots.real[abs(roots.imag) < 1e-5]
positive_roots = real_valued[real_valued > 0]
if len(positive_roots) == 0:
fov.append(np.pi / 2)
else:
fov.append(min(positive_roots.min(), np.pi / 2))
fov = np.asarray(np.tan(fov), dtype=np.float32)[..., None]
if th.is_tensor(D):
fov = th.from_numpy(fov).to(D)
return fov
def project_points(
v: th.Tensor,
campos: th.Tensor,
camrot: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
fov: Optional[th.Tensor] = None,
) -> Tuple[th.Tensor, th.Tensor]:
"""Project 3D world-space vertices to pixel-space, optionally applying a
distortion model with provided coefficients.
Returns v_pix, v_cam, both N x V x 3 since we preserve the camera-space
Z-coordinate for v_pix.
v: N x V x 3
camrot: N x 3 x 3
campos: N x 3
focal: N x 2 x 2
princpt: N x 2
distortion_coeff: N x 4
fov: N x 1
"""
if distortion_mode is not None:
assert distortion_coeff is not None, "Missing distortion coefficients."
v_cam = (camrot[:, None] @ (v - campos[:, None])[..., None])[..., 0]
# Fall back to single distortion mode if all the distortion modes are the same.
if isinstance(distortion_mode, (list, tuple)):
modes = list(set(distortion_mode))
if len(modes) == 0:
distortion_mode = None
elif len(modes) == 1:
distortion_mode = modes[0]
if distortion_mode is None:
v_pix = project_pinhole(v_cam, focal, princpt)
elif isinstance(distortion_mode, str):
assert distortion_coeff is not None
# Single distortion model
if distortion_mode == "radial-tangential":
v_pix = project_pinhole_distort_rt(
v_cam, focal, princpt, distortion_coeff, fov
)
elif distortion_mode == "fisheye":
v_pix = project_fisheye_distort(
v_cam, focal, princpt, distortion_coeff, fov
)
elif distortion_mode == "fisheye62":
v_pix = project_fisheye_distort_62(
v_cam, focal, princpt, distortion_coeff, fov
)
else:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
elif isinstance(distortion_mode, (list, tuple)):
assert distortion_coeff is not None
# A mix of multiple distortion modes
modes = set(distortion_mode)
if not modes <= DISTORTION_MODES:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
v_pix = th.empty_like(v_cam[..., :2])
if None in modes:
idx = th.tensor(
[mode is None for mode in distortion_mode], device=v_pix.device
)
v_pix[idx] = project_pinhole(v_cam[idx], focal[idx], princpt[idx])
if "radial-tangential" in modes:
idx = th.tensor(
[mode == "radial-tangential" for mode in distortion_mode],
device=v_pix.device,
)
v_pix[idx] = project_pinhole_distort_rt(
v_cam[idx],
focal[idx],
princpt[idx],
distortion_coeff[idx],
fov[idx] if fov is not None else None,
)
if "fisheye" in modes:
idx = th.tensor(
[mode == "fisheye" for mode in distortion_mode], device=v_pix.device
)
v_pix[idx] = project_fisheye_distort(
v_cam[idx],
focal[idx],
princpt[idx],
distortion_coeff[idx],
fov[idx] if fov is not None else None,
)
else:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
v_pix = th.cat((v_pix[:, :, 0:2], v_cam[:, :, 2:3]), dim=-1)
return v_pix, v_cam
def project_points_grad(
v_grad: th.Tensor,
v: th.Tensor,
campos: th.Tensor,
camrot: th.Tensor,
focal: th.Tensor,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Computes the gradient of projected (pixel-space) vertex positions with
respect to the 3D world-space vertex positions given the gradient of the 3D
world-space vertex positions.
project_points_grad(dv, v) = d project_points(v) / dv * dv
Args:
v_grad: Gradient of 3D world-space vertices. Shape: N x V x 3
v: 3D world-space vertices. Shape: N x V x 3
camrot: Camera rotation. Shape: N x 3 x 3
camrot: Camera position. Shape: N x 3
focal: Focal length. Shape: N x 2 x 2
distortion_mode: Distortion currently not implemented and must be None.
distortion_coeff: Distortion currently not implemented and must be None.
Returns:
Gradient of 2D pixel-space vertices: N x V x 2
"""
if distortion_mode is not None:
assert distortion_coeff is not None, "Missing distortion coefficients."
# d v_cam = d (Rv + T) = Rdv
v_cam_grad = (camrot[:, None] @ v_grad[..., None])[..., 0]
v_cam = (camrot[:, None] @ (v - campos[:, None])[..., None])[..., 0]
if distortion_mode is None:
z = v_cam[:, :, 2:3]
z_grad = v_cam_grad[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
# Using quotient rule:
# d (v_cam / z) = (d v_cam * z - v_cam * dz) / z^2
v_proj_grad = (v_cam_grad[:, :, 0:2] * z - v_cam[:, :, 0:2] * z_grad) / z**2.0
# d v_pix = d (Kv + cp) = Kdv
v_pix_grad = (focal[:, None] @ v_proj_grad[..., None])[..., 0]
elif distortion_mode == "radial-tangential":
raise NotImplementedError
elif distortion_mode == "fisheye":
raise NotImplementedError
else:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
return v_pix_grad
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import platform
import re
import sys
from pkg_resources import DistributionNotFound, get_distribution
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def main(debug: bool) -> None:
extra_link_args = {
"linux": ["-static-libgcc"] + ([] if debug else ["-flto"]),
"win32": ["/DEBUG"] if debug else [],
}
cxx_args = {
"linux": ["-std=c++17", "-Wall"]
+ (["-O0", "-g3", "-DDEBUG"] if debug else ["-O3", "--fast-math"]),
"win32": (
["/std:c++17", "/MT", "/GR-", "/EHsc", '/D "NOMINMAX"']
+ (
["/Od", '/D "_DEBUG"']
if debug
else ["/O2", "/fp:fast", "/GL", '/D "NDEBUG"']
)
),
}
nvcc_args = [
"-gencode=arch=compute_72,code=sm_72",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_86,code=sm_86",
"-gencode=arch=compute_90,code=sm_90",
] + (["-O0", "-g", "-DDEBUG"] if debug else ["-O3", "--use_fast_math"])
# There is som issue effecting latest NVCC and pytorch 2.3.0 https://github.com/pytorch/pytorch/issues/122169
# The workaround is adding -std=c++20 to NVCC args
nvcc_args.append("-std=c++20")
def get_dist(name):
try:
return get_distribution(name)
except DistributionNotFound:
return None
root_path = os.path.dirname(os.path.abspath(__file__))
package_name = "drtk"
with open(os.path.join(root_path, "drtk", "__init__.py")) as f:
init_file = f.read()
pattern = re.compile(r"__version__\s*=\s*\"(\d*\.\d*.\d*)\"")
groups = pattern.findall(init_file)
assert len(groups) == 1
version = groups[0]
if get_dist("torch") is None:
raise RuntimeError("Setup requires torch package to be installed")
import torch as th
assert th.cuda.is_available()
target_os = "none"
if sys.platform == "darwin":
target_os = "macos"
elif os.name == "posix":
target_os = "linux"
elif platform.system() == "Windows":
target_os = "win32"
if target_os == "none":
raise RuntimeError("Could not detect platform")
if target_os == "macos":
raise RuntimeError("Platform is not supported")
include_dir = [os.path.join(root_path, "src", "include")]
with open("README.md") as f:
readme = f.read()
pillow = "pillow" if get_dist("pillow-simd") is None else "pillow-simd"
setup(
name=package_name,
version=version,
author="Reality Labs, Meta",
description="Differentiable Rendering Toolkit",
long_description=readme,
long_description_content_type="text/markdown",
license="MIT",
install_requires=["numpy", "torch", "torchvision", pillow],
ext_modules=[
CUDAExtension(
name="drtk.rasterize_ext",
sources=[
"src/rasterize/rasterize_module.cpp",
"src/rasterize/rasterize_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
extra_link_args=extra_link_args[target_os],
include_dirs=include_dir,
),
CUDAExtension(
"drtk.render_ext",
sources=["src/render/render_kernel.cu", "src/render/render_module.cpp"],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.edge_grad_ext",
sources=[
"src/edge_grad/edge_grad_module.cpp",
"src/edge_grad/edge_grad_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.mipmap_grid_sampler_ext",
sources=[
"src/mipmap_grid_sampler/mipmap_grid_sampler_module.cpp",
"src/mipmap_grid_sampler/mipmap_grid_sampler_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.msi_ext",
sources=[
"src/msi/msi_module.cpp",
"src/msi/msi_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.interpolate_ext",
sources=[
"src/interpolate/interpolate_module.cpp",
"src/interpolate/interpolate_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
],
cmdclass={"build_ext": BuildExtension},
packages=["drtk", "drtk.utils"],
)
if __name__ == "__main__":
main(any(x in sys.argv for x in ["debug", "-debug", "--debug"]))
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <torch/types.h>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <kernel_utils.h>
#include "edge_grad_kernel.h"
using namespace math;
using at::native::fastAtomicAdd;
template <typename scalar_t>
struct TriInfo {
typedef typename math::TVec2<scalar_t> scalar2_t;
const scalar2_t p_0;
const scalar2_t p_1;
const scalar2_t v_01;
const scalar2_t v_02;
const scalar2_t v_12;
const scalar_t denominator;
};
template <typename scalar_t>
__device__ bool pix_in_tri(const TriInfo<scalar_t>& tri, const int x, const int y) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
if (tri.denominator != 0.f) {
const scalar2_t p = {(scalar_t)x, (scalar_t)y};
const scalar2_t vp0p = p - tri.p_0;
const scalar2_t vp1p = p - tri.p_1;
scalar3_t bary = scalar3_t({
vp1p.y * tri.v_12.x - vp1p.x * tri.v_12.y,
vp0p.x * tri.v_02.y - vp0p.y * tri.v_02.x,
vp0p.y * tri.v_01.x - vp0p.x * tri.v_01.y,
});
bary *= sign(tri.denominator);
const bool on_edge_or_inside = (bary.x >= 0.f) && (bary.y >= 0.f) && (bary.z >= 0.f);
bool on_edge_0 = bary.x == 0.f;
bool on_edge_1 = bary.y == 0.f;
bool on_edge_2 = bary.z == 0.f;
const bool is_top_left_0 = (tri.denominator > 0)
? (tri.v_12.y < 0.f || tri.v_12.y == 0.0f && tri.v_12.x > 0.f)
: (tri.v_12.y > 0.f || tri.v_12.y == 0.0f && tri.v_12.x < 0.f);
const bool is_top_left_1 = (tri.denominator > 0)
? (tri.v_02.y > 0.f || tri.v_02.y == 0.0f && tri.v_02.x < 0.f)
: (tri.v_02.y < 0.f || tri.v_02.y == 0.0f && tri.v_02.x > 0.f);
const bool is_top_left_2 = (tri.denominator > 0)
? (tri.v_01.y < 0.f || tri.v_01.y == 0.0f && tri.v_01.x > 0.f)
: (tri.v_01.y > 0.f || tri.v_01.y == 0.0f && tri.v_01.x < 0.f);
const bool is_top_left_or_inside = on_edge_or_inside &&
!(on_edge_0 && !is_top_left_0 || on_edge_1 && !is_top_left_1 ||
on_edge_2 && !is_top_left_2);
return is_top_left_or_inside;
}
return false;
}
template <typename scalar_t, typename index_t>
__device__ TriInfo<scalar_t>
get_tri_info(const scalar_t* v_ptr, index_t v_sV, index_t v_sC, int3 vi) {
typedef typename math::TVec2<scalar_t> scalar2_t;
const scalar2_t p_0 = {v_ptr[v_sV * vi.x + v_sC * 0], v_ptr[v_sV * vi.x + v_sC * 1]};
const scalar2_t p_1 = {v_ptr[v_sV * vi.y + v_sC * 0], v_ptr[v_sV * vi.y + v_sC * 1]};
const scalar2_t p_2 = {v_ptr[v_sV * vi.z + v_sC * 0], v_ptr[v_sV * vi.z + v_sC * 1]};
const scalar2_t v_01 = p_1 - p_0;
const scalar2_t v_02 = p_2 - p_0;
const scalar2_t v_12 = p_2 - p_1;
const scalar_t denominator = v_01.x * v_02.y - v_01.y * v_02.x;
return {p_0, p_1, v_01, v_02, v_12, denominator};
}
template <typename scalar_t, typename index_t>
__device__ math::TVec3<scalar_t>
get_tri_normal(const scalar_t* v_ptr, index_t v_sV, index_t v_sC, int3 vi) {
typedef typename math::TVec3<scalar_t> scalar3_t;
const scalar3_t p_0 = {
v_ptr[v_sV * vi.x + v_sC * 0], v_ptr[v_sV * vi.x + v_sC * 1], v_ptr[v_sV * vi.x + v_sC * 2]};
const scalar3_t p_1 = {
v_ptr[v_sV * vi.y + v_sC * 0], v_ptr[v_sV * vi.y + v_sC * 1], v_ptr[v_sV * vi.y + v_sC * 2]};
const scalar3_t p_2 = {
v_ptr[v_sV * vi.z + v_sC * 0], v_ptr[v_sV * vi.z + v_sC * 1], v_ptr[v_sV * vi.z + v_sC * 2]};
return normalize(cross(p_0 - p_2, p_1 - p_0));
}
template <typename scalar_t>
__device__ math::TVec2<scalar_t> get_db_dp(
const math::TVec2<scalar_t>& n_varying_,
const math::TVec2<scalar_t>& n_fixed_) {
/*
Computes derivative of the point position with respect to edge displacement
Args:
- n_varying_: Projection of the normal of the movable triangle onto the plane of
consideration (XZ or YZ) N x 3 x H x W.
- n_fixed_: Projection of the normal of the fixed triangle onto the plane of consideration
(XZ or YZ) N x 3 x H x W.
Please refer to the paper "Rasterized Edge Gradients: Handling Discontinuities Differentiably"
for details.
*/
typedef typename math::TVec2<scalar_t> scalar2_t;
const auto n_varying = normalize(n_varying_);
const auto n_fixed = normalize(n_fixed_);
const scalar2_t b = {-n_fixed.y, n_fixed.x};
const auto b_dot_varyingg = dot(b, n_varying);
return b.x / epsclamp(b_dot_varyingg) * n_varying;
}
template <typename scalar_t, typename index_t>
__device__ math::TVec3<scalar_t> load_vec3_if_valid(
const scalar_t* __restrict ptr,
index_t stride,
bool valid,
const math::TVec3<scalar_t>& def) {
if (valid) {
return {ptr[0 * stride], ptr[1 * stride], ptr[2 * stride]};
}
return def;
}
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void edge_grad_backward_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> v_pix,
TensorInfo<scalar_t, index_t> img,
TensorInfo<int32_t, index_t> index_img,
TensorInfo<int32_t, index_t> vi,
TensorInfo<scalar_t, index_t> grad_output,
TensorInfo<scalar_t, index_t> grad_v_pix_img,
const index_t memory_span) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
const index_t v_pix_sN = v_pix.strides[0];
const index_t v_pix_sV = v_pix.strides[1];
const index_t v_pix_sC = v_pix.strides[2];
const index_t C = img.sizes[1];
const index_t H = img.sizes[2];
const index_t W = img.sizes[3];
const index_t V = v_pix.sizes[1];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
const index_t index_img_sW = index_img.strides[2];
const index_t img_sN = img.strides[0];
const index_t img_sC = img.strides[1];
const index_t img_sH = img.strides[2];
const index_t img_sW = img.strides[3];
const index_t grad_output_sN = grad_output.strides[0];
const index_t grad_output_sC = grad_output.strides[1];
const index_t grad_output_sH = grad_output.strides[2];
const index_t grad_output_sW = grad_output.strides[3];
const index_t grad_v_pix_img_sN = grad_v_pix_img.strides[0];
const index_t grad_v_pix_img_sC = grad_v_pix_img.strides[1];
const index_t grad_v_pix_img_sH = grad_v_pix_img.strides[2];
const index_t grad_v_pix_img_sW = grad_v_pix_img.strides[3];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t x = index % W;
const index_t y = (index / W) % H;
const index_t n = index / (H * W);
if (x < (W - 1) && y < (H - 1)) {
// center-right-down (CRD)
//
// *--------*--------*
// | center | right |
// | (0, 0) | (1, 0) |
// *--------*--------*
// | down |
// | (0, 1) |
// *--------*
// Computing indicator variables
// BEGIN
// triangle indices of CRD pixels
const int32_t* __restrict index_img_ptr = index_img.data + n * index_img_sN;
const int32_t center_index = index_img_ptr[(y + 0) * index_img_sH + (x + 0) * index_img_sW];
const int32_t right_index = index_img_ptr[(y + 0) * index_img_sH + (x + 1) * index_img_sW];
const int32_t down_index = index_img_ptr[(y + 1) * index_img_sH + (x + 0) * index_img_sW];
// valid mask
const bool c_valid = (center_index >= 0);
const bool r_valid = (right_index >= 0);
const bool d_valid = (down_index >= 0);
// vertex indices of triangles of CRD pixels
// 0,0,0 - if not valid
const int3 vi_pt_center = load_vec3_if_valid<int32_t, index_t>(
vi.data + center_index * vi_sV, vi_sF, c_valid, {0, 0, 0});
const int3 vi_pt_right = load_vec3_if_valid<int32_t, index_t>(
vi.data + right_index * vi_sV, vi_sF, r_valid, {0, 0, 0});
const int3 vi_pt_down = load_vec3_if_valid<int32_t, index_t>(
vi.data + down_index * vi_sV, vi_sF, d_valid, {0, 0, 0});
// center <-> right differ
const bool lr_diff = (center_index != right_index);
// center <-> down differ
const bool ud_diff = (center_index != down_index);
// if horizontal pair (vertical edge) composed of two triangles
const bool x_both_valid = c_valid && r_valid;
// if vertical pair (horizontal edge) composed of two triangles
const bool y_both_valid = c_valid && d_valid;
// Get CRD triangle info
const scalar_t* __restrict v_pix_ptr = v_pix.data + n * v_pix_sN;
const auto tri_center = get_tri_info(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_center);
const auto tri_right = get_tri_info(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_right);
const auto tri_down = get_tri_info(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_down);
// Compute indicators of edge type
const bool center_pix_in_right_tri = lr_diff && x_both_valid && pix_in_tri(tri_right, x, y);
const bool right_pix_in_center_tri =
lr_diff && x_both_valid && pix_in_tri(tri_center, x + 1, y);
const bool center_pix_in_down_tri = ud_diff && y_both_valid && pix_in_tri(tri_down, x, y);
const bool down_pix_in_center_tri =
ud_diff && y_both_valid && pix_in_tri(tri_center, x, y + 1);
// Overlap flags
const bool l_over_r = center_pix_in_right_tri && (!right_pix_in_center_tri);
const bool r_over_l = right_pix_in_center_tri && (!center_pix_in_right_tri);
const bool u_over_d = center_pix_in_down_tri && (!down_pix_in_center_tri);
const bool d_over_u = down_pix_in_center_tri && (!center_pix_in_down_tri);
// Intersection flags
const bool horiz_int = center_pix_in_right_tri && right_pix_in_center_tri;
const bool vert_int = center_pix_in_down_tri && down_pix_in_center_tri;
// Intersection flags
const bool horiz_adjacent =
lr_diff && x_both_valid && (!center_pix_in_right_tri && !right_pix_in_center_tri);
const bool vert_adjacent =
ud_diff && y_both_valid && (!center_pix_in_down_tri && !down_pix_in_center_tri);
// END
// Compute image gradient dot output gradient from backward
// This is computed regardless of the edge type as long as there is an edge (lr_diff or
// ud_diff) BEGIN
const scalar_t* __restrict img_ptr = img.data + img_sN * n;
const scalar_t* __restrict grad_output_ptr = grad_output.data + grad_output_sN * n;
scalar_t grad_dot_x = 0.f;
scalar_t grad_dot_y = 0.f;
if (lr_diff) {
const scalar_t* __restrict img_ptr_right = img_ptr + y * img_sH + (x + 1) * img_sW;
const scalar_t* __restrict img_ptr_center = img_ptr + y * img_sH + (x + 0) * img_sW;
const scalar_t* __restrict grad_output_ptr_right =
grad_output_ptr + y * grad_output_sH + (x + 1) * grad_output_sW;
const scalar_t* __restrict grad_output_ptr_center =
grad_output_ptr + y * grad_output_sH + (x + 0) * grad_output_sW;
for (size_t c = 0; c < C; ++c) {
grad_dot_x += (img_ptr_right[c * img_sC] - img_ptr_center[c * img_sC]) *
(0.5f *
(grad_output_ptr_right[c * grad_output_sC] +
grad_output_ptr_center[c * grad_output_sC]));
}
}
if (ud_diff) {
const scalar_t* __restrict img_ptr_down = img_ptr + (y + 1) * img_sH + x * img_sW;
const scalar_t* __restrict img_ptr_center = img_ptr + (y + 0) * img_sH + x * img_sW;
const scalar_t* __restrict grad_output_ptr_down =
grad_output_ptr + (y + 1) * grad_output_sH + x * grad_output_sW;
const scalar_t* __restrict grad_output_ptr_center =
grad_output_ptr + (y + 0) * grad_output_sH + x * grad_output_sW;
for (size_t c = 0; c < C; ++c) {
grad_dot_y += (img_ptr_down[c * img_sC] - img_ptr_center[c * img_sC]) *
(0.5f *
(grad_output_ptr_down[c * grad_output_sC] +
grad_output_ptr_center[c * grad_output_sC]));
}
}
// END
scalar3_t grad_v_pix_center = {0.f, 0.f, 0.f};
scalar3_t grad_v_pix_right = {0.f, 0.f, 0.f};
scalar3_t grad_v_pix_down = {0.f, 0.f, 0.f};
const scalar3_t center_normal = get_tri_normal(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_center);
const scalar3_t right_normal = get_tri_normal(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_right);
const scalar3_t down_normal = get_tri_normal(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_down);
if (!horiz_int) {
grad_v_pix_center.x += (!c_valid || r_over_l || horiz_adjacent) ? 0.f : grad_dot_x;
grad_v_pix_right.x += (!r_valid || l_over_r || horiz_adjacent) ? 0.f : grad_dot_x;
} else {
// Center triangle moves, right fixed.
scalar2_t dbx_dp = get_db_dp<scalar_t>(
{center_normal.x, center_normal.z}, {right_normal.x, right_normal.z});
grad_v_pix_center.x += grad_dot_x * dbx_dp.x;
grad_v_pix_center.z += grad_dot_x * dbx_dp.y;
// Center triangle fixed, right moves.
dbx_dp = get_db_dp<scalar_t>(
{right_normal.x, right_normal.z}, {center_normal.x, center_normal.z});
grad_v_pix_right.x += grad_dot_x * dbx_dp.x;
grad_v_pix_right.z += grad_dot_x * dbx_dp.y;
}
if (!vert_int) {
grad_v_pix_center.y += (!c_valid || d_over_u || vert_adjacent) ? 0.f : grad_dot_y;
grad_v_pix_down.y += (!d_valid || u_over_d || vert_adjacent) ? 0.f : grad_dot_y;
} else {
// Center triangle moves, lower fixed.
scalar2_t dby_dp =
get_db_dp<scalar_t>({center_normal.y, center_normal.z}, {down_normal.y, down_normal.z});
grad_v_pix_center.y += grad_dot_y * dby_dp.x;
grad_v_pix_center.z += grad_dot_y * dby_dp.x;
// Center triangle fixed, lower moves.
dby_dp =
get_db_dp<scalar_t>({down_normal.y, down_normal.z}, {center_normal.y, center_normal.z});
grad_v_pix_down.y += grad_dot_y * dby_dp.x;
grad_v_pix_down.z += grad_dot_y * dby_dp.x;
}
// Writing grads out
// BEGIN
scalar_t* __restrict grad_v_pix_img_ptr = grad_v_pix_img.data + grad_v_pix_img_sN * n;
// center
auto* ptr_c = grad_v_pix_img_ptr + (y + 0) * grad_v_pix_img_sH + (x + 0) * grad_v_pix_img_sW;
atomicAdd(ptr_c + 0 * grad_v_pix_img_sC, -grad_v_pix_center.x);
atomicAdd(ptr_c + 1 * grad_v_pix_img_sC, -grad_v_pix_center.y);
atomicAdd(ptr_c + 2 * grad_v_pix_img_sC, -grad_v_pix_center.z);
// right
auto* ptr_r = grad_v_pix_img_ptr + (y + 0) * grad_v_pix_img_sH + (x + 1) * grad_v_pix_img_sW;
atomicAdd(ptr_r + 0 * grad_v_pix_img_sC, -grad_v_pix_right.x);
atomicAdd(ptr_r + 1 * grad_v_pix_img_sC, -grad_v_pix_right.y);
atomicAdd(ptr_r + 2 * grad_v_pix_img_sC, -grad_v_pix_right.z);
// down
auto* ptr_d = grad_v_pix_img_ptr + (y + 1) * grad_v_pix_img_sH + (x + 0) * grad_v_pix_img_sW;
atomicAdd(ptr_d + 0 * grad_v_pix_img_sC, -grad_v_pix_down.x);
atomicAdd(ptr_d + 1 * grad_v_pix_img_sC, -grad_v_pix_down.y);
atomicAdd(ptr_d + 2 * grad_v_pix_img_sC, -grad_v_pix_down.z);
// END
}
}
}
template <typename scalar_t, typename index_type>
void edge_grad_estimator_cuda_backward_(
const int64_t count,
const torch::Tensor& v_pix,
const torch::Tensor& img,
const torch::Tensor& index_img,
const torch::Tensor& vi,
const torch::Tensor& grad_outputs,
const torch::Tensor& grad_v_pix_img) {
edge_grad_backward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(v_pix),
getTensorInfo<scalar_t, index_type>(img),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<scalar_t, index_type>(grad_outputs),
getTensorInfo<scalar_t, index_type>(grad_v_pix_img),
grad_v_pix_img.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
torch::Tensor edge_grad_estimator_cuda_backward(
const torch::Tensor& v_pix,
const torch::Tensor& img,
const torch::Tensor& index_img,
const torch::Tensor& vi,
const torch::Tensor& grad_outputs) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(img));
const auto N = img.sizes()[0];
const auto C = img.sizes()[1];
const auto H = img.sizes()[2];
const auto W = img.sizes()[3];
const auto V = v_pix.sizes()[1];
const auto count = N * H * W;
auto grad_v_pix_img = torch::zeros({N, 3, H, W}, v_pix.options());
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES(v_pix.scalar_type(), "edge_grad_estimator_kernel", [&] {
if (at::native::canUse32BitIndexMath(v_pix) && at::native::canUse32BitIndexMath(img) &&
at::native::canUse32BitIndexMath(index_img) && at::native::canUse32BitIndexMath(vi) &&
at::native::canUse32BitIndexMath(grad_outputs) &&
at::native::canUse32BitIndexMath(grad_v_pix_img)) {
edge_grad_estimator_cuda_backward_<scalar_t, int>(
count, v_pix, img, index_img, vi, grad_outputs, grad_v_pix_img);
} else {
edge_grad_estimator_cuda_backward_<scalar_t, int64_t>(
count, v_pix, img, index_img, vi, grad_outputs, grad_v_pix_img);
}
});
}
return grad_v_pix_img;
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
torch::Tensor edge_grad_estimator_cuda_backward(
const torch::Tensor& v_pix,
const torch::Tensor& img,
const torch::Tensor& index_img,
const torch::Tensor& vi,
const torch::Tensor& grad_outputs);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "edge_grad_kernel.h"
// Dispatch function
torch::Tensor edge_grad_estimator(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("edge_grad_ext::edge_grad_estimator", "")
.typed<decltype(edge_grad_estimator)>();
return op.call(v_pix, v_pix_img, vi, img, index_img);
}
torch::Tensor edge_grad_estimator_fwd(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
TORCH_CHECK(
v_pix.defined() && v_pix_img.defined() && vi.defined() && img.defined() &&
index_img.defined(),
"edge_grad_estimator(): expected all inputs to be defined");
TORCH_CHECK(
(v_pix.device() == v_pix_img.device()) && (v_pix.device() == vi.device()) &&
(v_pix.device() == img.device()) && (v_pix.device() == index_img.device()) &&
(v_pix.is_cuda()),
"edge_grad_estimator(): expected all inputs to be on same cuda device");
TORCH_CHECK(
v_pix.is_floating_point() && v_pix_img.is_floating_point() && img.is_floating_point(),
"edge_grad_estimator(): expected v_pix, v_pix_img, and img to have floating point type, but v_pix has ",
v_pix.dtype(),
" v_pix has ",
v_pix_img.dtype(),
" img has ",
img.dtype());
TORCH_CHECK(
vi.dtype() == torch::kInt32,
"edge_grad_estimator(): expected vi to have int32 type, but vi has ",
vi.dtype());
TORCH_CHECK(
index_img.dtype() == torch::kInt32,
"edge_grad_estimator(): expected index_img to have int32 type, but index_img has ",
index_img.dtype());
TORCH_CHECK(
v_pix.layout() == torch::kStrided && v_pix_img.layout() == torch::kStrided &&
vi.layout() == torch::kStrided && img.layout() == torch::kStrided &&
index_img.layout() == torch::kStrided,
"edge_grad_estimator(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(v_pix.dim() == 3) && (v_pix_img.dim() == 4) && (vi.dim() == 2) && (img.dim() == 4) &&
(index_img.dim() == 3),
"edge_grad_estimator(): expected v_pix.ndim == 3, v_pix_img.ndim == 4, vi.ndim == 2, img.ndim == 4, index_img.ndim == 3, "
"but got v_pix with sizes ",
v_pix.sizes(),
" and v_pix_img with sizes ",
v_pix_img.sizes(),
" and vi with sizes ",
vi.sizes(),
" and img with sizes ",
img.sizes(),
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v_pix.size(0) == v_pix_img.size(0) && v_pix.size(0) == img.size(0) &&
v_pix.size(0) == index_img.size(0),
"edge_grad_estimator(): expected v and index_img to have same batch size, "
"but got v_pix with sizes ",
v_pix.sizes(),
", v_pix_img with sizes ",
v_pix_img.sizes(),
", img with sizes ",
img.sizes(),
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v_pix.size(2) == 3 && v_pix_img.size(1) == 3 && vi.size(1) == 3,
"edge_grad_estimator(): expected third dim of v_pix to be of size 3, and second dim of vi to be of size 3, but got ",
v_pix.size(2),
" in the third dim of v_pix, and ",
v_pix_img.size(1),
" in the second dim of v_pix_img, and ",
vi.size(1),
" in the second dim of vi");
TORCH_CHECK(
v_pix_img.size(3) == img.size(3) && v_pix_img.size(3) == index_img.size(2) &&
v_pix_img.size(2) == img.size(2) && v_pix_img.size(2) == index_img.size(1),
"edge_grad_estimator(): expected width and height of v_pix_img, img, and index_img to match, but got size of v_pix_img: ",
v_pix_img.sizes(),
", size of img: ",
img.sizes(),
", size of index_img: ",
index_img.sizes());
return img;
}
// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
class EdgeGradEstimatorFunction : public torch::autograd::Function<EdgeGradEstimatorFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
ctx->set_materialize_grads(false);
ctx->save_for_backward({v_pix, img, index_img, vi});
ctx->saved_data["v_pix_img_requires_grad"] = v_pix_img.requires_grad();
return {img};
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
// If v_pix_img doesn't require grad, we don't need to do anything.
if (!ctx->saved_data["v_pix_img_requires_grad"].toBool()) {
return {torch::Tensor(), torch::Tensor(), torch::Tensor(), grad_outputs[0], torch::Tensor()};
}
const auto saved = ctx->get_saved_variables();
const auto& v_pix = saved[0];
const auto& img = saved[1];
const auto& index_img = saved[2];
const auto& vi = saved[3];
auto grad_v_pix_img =
edge_grad_estimator_cuda_backward(v_pix, img, index_img, vi, grad_outputs[0]);
return {torch::Tensor(), grad_v_pix_img, torch::Tensor(), grad_outputs[0], torch::Tensor()};
}
};
torch::Tensor edge_grad_estimator_autograd(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
return EdgeGradEstimatorFunction::apply(v_pix, v_pix_img, vi, img, index_img)[0];
}
torch::Tensor edge_grad_estimator_autocast(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return edge_grad_estimator(
at::autocast::cached_cast(torch::kFloat32, v_pix),
at::autocast::cached_cast(torch::kFloat32, v_pix_img),
vi,
at::autocast::cached_cast(torch::kFloat32, img),
index_img)[0];
}
#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(edge_grad_ext, m) {}
#endif
TORCH_LIBRARY(edge_grad_ext, m) {
m.def(
"edge_grad_estimator(Tensor v_pix, Tensor v_pix_img, Tensor vi, Tensor img, Tensor index_img) -> Tensor");
}
TORCH_LIBRARY_IMPL(edge_grad_ext, Autograd, m) {
m.impl("edge_grad_estimator", &edge_grad_estimator_autograd);
}
TORCH_LIBRARY_IMPL(edge_grad_ext, Autocast, m) {
m.impl("edge_grad_estimator", edge_grad_estimator_autocast);
}
TORCH_LIBRARY_IMPL(edge_grad_ext, CUDA, m) {
m.impl("edge_grad_estimator", &edge_grad_estimator_fwd);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <cuda_runtime.h>
#include <algorithm>
#include <cassert>
#include <limits>
// The header provides uniform GLSL-like math API for the following three cases:
// - non-NVCC compiler
// - NVCC compiler, host code
// - NVCC compiler, device code
// Designed to be a more flexible replacement of similar header from NVidia
#ifndef __CUDACC__
#define MH_NON_NVCC
#else
#define MH_NVCC
#ifdef __CUDA_ARCH__
#define MH_NVCC_DEVICE
#else
#define MH_NVCC_HOST
#endif
#endif
#if defined(MH_NVCC_HOST) || defined(MH_NON_NVCC)
#define HOST_DEVICE_DISPATCH(HOST_CODE, DEVICE_CODE) (HOST_CODE)
#elif defined(MH_NVCC_DEVICE)
#define HOST_DEVICE_DISPATCH(HOST_CODE, DEVICE_CODE) (DEVICE_CODE)
#else
#error Dispatch failed
#endif
// if not NVCC, need to include cmath, since certain builtin NVCC functions have
// equivalent ones in cmath
#ifdef MH_NON_NVCC
#include <cmath>
#endif
#define CHD_FUNC constexpr inline __host__ __device__
#define HD_FUNC inline __host__ __device__
namespace math {
template <typename T>
struct epsilon;
template <>
struct epsilon<float> {
static constexpr float value = 1e-8f;
};
template <>
struct epsilon<double> {
static constexpr double value = 1e-16;
};
// Host and device version of saturate
// Note that unfortunately `__saturatef` aka `saturate` is a device only
// function. If you do `using namespace math` you would still have to use math
// namespace for scalars: math::saturate
HD_FUNC float saturate(float a) {
return HOST_DEVICE_DISPATCH(
fminf(fmaxf(a, 0.0f), 1.0f),
__saturatef(a) // __saturatef is a device only function
);
}
// There is no CUDA intrinsic for saturate for double type
HD_FUNC double saturate(double a) {
return fmin(fmax(a, 0.0), 1.0);
}
// If NVCC then use builtin abs/max/min/sqrt/rsqrt.
// All of them have overloads for ints, floats, and doubles,defined in
// `cuda/crt/math_functions.hpp` thus no need for explicit usage of e.g. fabsf
#if defined(MH_NVCC)
using ::abs;
using ::max;
using ::min;
using ::rsqrt;
using ::sqrt;
#else
// Otherwise use the ones from cmath
using std::abs;
using std::max;
using std::min;
using std::sqrt;
inline double rsqrt(double v) {
return 1.0 / std::sqrt(v);
}
inline float rsqrt(float v) {
return 1.0f / std::sqrt(v);
}
#endif
namespace detail {
// Provide overloads of norm3d/norm4d for floats and doubles
HD_FUNC float norm3d(float a, float b, float c) {
return HOST_DEVICE_DISPATCH(
sqrt(a * a + b * b + c * c), ::norm3df(a, b, c) // norm3df is device only
);
}
HD_FUNC double norm3d(double a, double b, double c) {
return HOST_DEVICE_DISPATCH(
sqrt(a * a + b * b + c * c), ::norm3d(a, b, c) // norm3d is device only
);
}
HD_FUNC float rnorm3d(float a, float b, float c) {
return HOST_DEVICE_DISPATCH(
1.0f / sqrt(a * a + b * b + c * c), ::rnorm3df(a, b, c) // rnorm3df is device only
);
}
HD_FUNC double rnorm3d(double a, double b, double c) {
return HOST_DEVICE_DISPATCH(
1.0 / sqrt(a * a + b * b + c * c), ::rnorm3d(a, b, c) // rnorm3d is device only
);
}
HD_FUNC float norm4d(float a, float b, float c, float d) {
return HOST_DEVICE_DISPATCH(
sqrt(a * a + b * b + c * c + d * d), ::norm4df(a, b, c, d) // norm4df is device only
);
}
HD_FUNC double norm4d(double a, double b, double c, double d) {
return HOST_DEVICE_DISPATCH(
sqrt(a * a + b * b + c * c + d * d), ::norm4d(a, b, c, d) // norm4d is device only
);
}
HD_FUNC float rnorm4d(float a, float b, float c, float d) {
return HOST_DEVICE_DISPATCH(
1.0f / sqrt(a * a + b * b + c * c + d * d), ::rnorm4df(a, b, c, d) // rnorm4df is device only
);
}
HD_FUNC double rnorm4d(double a, double b, double c, double d) {
return HOST_DEVICE_DISPATCH(
1.0 / sqrt(a * a + b * b + c * c + d * d), ::rnorm4d(a, b, c, d) // rnorm4d is device only
);
}
} // namespace detail
// Unary operators
#define UNARY_OP(T, T2, T3, T4) \
CHD_FUNC T2 operator+(T2 const& v) { \
return v; \
} \
CHD_FUNC T2 operator-(T2 const& v) { \
return {-v.x, -v.y}; \
} \
CHD_FUNC T3 operator+(T3 const& v) { \
return v; \
} \
CHD_FUNC T3 operator-(T3 const& v) { \
return {-v.x, -v.y, -v.z}; \
} \
CHD_FUNC T4 operator+(T4 const& v) { \
return v; \
} \
CHD_FUNC T4 operator-(T4 const& v) { \
return {-v.x, -v.y, -v.z, -v.w}; \
}
// -- Binary arithmetic operators --
#define BINARY_ARITHM_OP(T, T2, T3, T4) \
CHD_FUNC T2 operator+(T2 const& v, T scalar) { \
return {v.x + scalar, v.y + scalar}; \
} \
CHD_FUNC T2 operator+(T scalar, T2 const& v) { \
return {scalar + v.x, scalar + v.y}; \
} \
CHD_FUNC T2 operator+(T2 const& v1, T2 const& v2) { \
return {v1.x + v2.x, v1.y + v2.y}; \
} \
CHD_FUNC T2 operator+=(T2& v, T scalar) { \
v.x += scalar; \
v.y += scalar; \
return v; \
} \
CHD_FUNC T2 operator+=(T2& v, T2 const& v2) { \
v.x += v2.x; \
v.y += v2.y; \
return v; \
} \
CHD_FUNC T2 operator-(T2 const& v, T scalar) { \
return {v.x - scalar, v.y - scalar}; \
} \
CHD_FUNC T2 operator-(T scalar, T2 const& v) { \
return {scalar - v.x, scalar - v.y}; \
} \
CHD_FUNC T2 operator-(T2 const& v1, T2 const& v2) { \
return {v1.x - v2.x, v1.y - v2.y}; \
} \
CHD_FUNC T2 operator-=(T2& v, T scalar) { \
v.x -= scalar; \
v.y -= scalar; \
return v; \
} \
CHD_FUNC T2 operator-=(T2& v, T2 const& v2) { \
v.x -= v2.x; \
v.y -= v2.y; \
return v; \
} \
CHD_FUNC T2 operator*(T2 const& v, T scalar) { \
return {v.x * scalar, v.y * scalar}; \
} \
CHD_FUNC T2 operator*(T scalar, T2 const& v) { \
return {scalar * v.x, scalar * v.y}; \
} \
CHD_FUNC T2 operator*(T2 const& v1, T2 const& v2) { \
return {v1.x * v2.x, v1.y * v2.y}; \
} \
CHD_FUNC T2 operator*=(T2& v, T scalar) { \
v.x *= scalar; \
v.y *= scalar; \
return v; \
} \
CHD_FUNC T2 operator*=(T2& v, T2 const& v2) { \
v.x *= v2.x; \
v.y *= v2.y; \
return v; \
} \
CHD_FUNC T2 operator/(T2 const& v, T scalar) { \
return {v.x / scalar, v.y / scalar}; \
} \
CHD_FUNC T2 operator/(T scalar, T2 const& v) { \
return {scalar / v.x, scalar / v.y}; \
} \
CHD_FUNC T2 operator/(T2 const& v1, T2 const& v2) { \
return {v1.x / v2.x, v1.y / v2.y}; \
} \
CHD_FUNC T2 operator/=(T2& v, T scalar) { \
v.x /= scalar; \
v.y /= scalar; \
return v; \
} \
CHD_FUNC T2 operator/=(T2& v, T2 const& v2) { \
v.x /= v2.x; \
v.y /= v2.y; \
return v; \
} \
CHD_FUNC T3 operator+(T3 const& v, T scalar) { \
return {v.x + scalar, v.y + scalar, v.z + scalar}; \
} \
CHD_FUNC T3 operator+(T scalar, T3 const& v) { \
return {scalar + v.x, scalar + v.y, scalar + v.z}; \
} \
CHD_FUNC T3 operator+(T3 const& v1, T3 const& v2) { \
return {v1.x + v2.x, v1.y + v2.y, v1.z + v2.z}; \
} \
CHD_FUNC T3 operator+=(T3& v, T scalar) { \
v.x += scalar; \
v.y += scalar; \
v.z += scalar; \
return v; \
} \
CHD_FUNC T3 operator+=(T3& v, T3 const& v2) { \
v.x += v2.x; \
v.y += v2.y; \
v.z += v2.z; \
return v; \
} \
CHD_FUNC T3 operator-(T3 const& v, T scalar) { \
return {v.x - scalar, v.y - scalar, v.z - scalar}; \
} \
CHD_FUNC T3 operator-(T scalar, T3 const& v) { \
return {scalar - v.x, scalar - v.y, scalar - v.z}; \
} \
CHD_FUNC T3 operator-(T3 const& v1, T3 const& v2) { \
return {v1.x - v2.x, v1.y - v2.y, v1.z - v2.z}; \
} \
CHD_FUNC T3 operator-=(T3& v, T scalar) { \
v.x -= scalar; \
v.y -= scalar; \
v.z -= scalar; \
return v; \
} \
CHD_FUNC T3 operator-=(T3& v, T3 const& v2) { \
v.x -= v2.x; \
v.y -= v2.y; \
v.z -= v2.z; \
return v; \
} \
CHD_FUNC T3 operator*(T3 const& v, T scalar) { \
return {v.x * scalar, v.y * scalar, v.z * scalar}; \
} \
CHD_FUNC T3 operator*(T scalar, T3 const& v) { \
return {scalar * v.x, scalar * v.y, scalar * v.z}; \
} \
CHD_FUNC T3 operator*(T3 const& v1, T3 const& v2) { \
return {v1.x * v2.x, v1.y * v2.y, v1.z * v2.z}; \
} \
CHD_FUNC T3 operator*=(T3& v, T scalar) { \
v.x *= scalar; \
v.y *= scalar; \
v.z *= scalar; \
return v; \
} \
CHD_FUNC T3 operator*=(T3& v, T3 const& v2) { \
v.x *= v2.x; \
v.y *= v2.y; \
v.z *= v2.z; \
return v; \
} \
CHD_FUNC T3 operator/(T3 const& v, T scalar) { \
return {v.x / scalar, v.y / scalar, v.z / scalar}; \
} \
CHD_FUNC T3 operator/(T scalar, T3 const& v) { \
return {scalar / v.x, scalar / v.y, scalar / v.z}; \
} \
CHD_FUNC T3 operator/(T3 const& v1, T3 const& v2) { \
return {v1.x / v2.x, v1.y / v2.y, v1.z / v2.z}; \
} \
CHD_FUNC T3 operator/=(T3& v, T scalar) { \
v.x /= scalar; \
v.y /= scalar; \
v.z /= scalar; \
return v; \
} \
CHD_FUNC T3 operator/=(T3& v, T3 const& v2) { \
v.x /= v2.x; \
v.y /= v2.y; \
v.z /= v2.z; \
return v; \
} \
CHD_FUNC T4 operator+(T4 const& v, T scalar) { \
return {v.x + scalar, v.y + scalar, v.z + scalar, v.w + scalar}; \
} \
CHD_FUNC T4 operator+(T scalar, T4 const& v) { \
return {scalar + v.x, scalar + v.y, scalar + v.z, scalar + v.w}; \
} \
CHD_FUNC T4 operator+(T4 const& v1, T4 const& v2) { \
return {v1.x + v2.x, v1.y + v2.y, v1.z + v2.z, v1.w + v2.w}; \
} \
CHD_FUNC T4 operator+=(T4& v, T scalar) { \
v.x += scalar; \
v.y += scalar; \
v.z += scalar; \
v.w += scalar; \
return v; \
} \
CHD_FUNC T4 operator+=(T4& v, T4 const& v2) { \
v.x += v2.x; \
v.y += v2.y; \
v.z += v2.z; \
v.w += v2.w; \
return v; \
} \
CHD_FUNC T4 operator-(T4 const& v, T scalar) { \
return {v.x - scalar, v.y - scalar, v.z - scalar, v.w - scalar}; \
} \
CHD_FUNC T4 operator-(T scalar, T4 const& v) { \
return {scalar - v.x, scalar - v.y, scalar - v.z, scalar - v.w}; \
} \
CHD_FUNC T4 operator-(T4 const& v1, T4 const& v2) { \
return {v1.x - v2.x, v1.y - v2.y, v1.z - v2.z, v1.w - v2.w}; \
} \
CHD_FUNC T4 operator-=(T4& v, T scalar) { \
v.x -= scalar; \
v.y -= scalar; \
v.z -= scalar; \
v.w -= scalar; \
return v; \
} \
CHD_FUNC T4 operator-=(T4& v, T4 const& v2) { \
v.x -= v2.x; \
v.y -= v2.y; \
v.z -= v2.z; \
v.w -= v2.w; \
return v; \
} \
CHD_FUNC T4 operator*(T4 const& v, T scalar) { \
return {v.x * scalar, v.y * scalar, v.z * scalar, v.w * scalar}; \
} \
CHD_FUNC T4 operator*(T scalar, T4 const& v) { \
return {scalar * v.x, scalar * v.y, scalar * v.z, scalar * v.w}; \
} \
CHD_FUNC T4 operator*(T4 const& v1, T4 const& v2) { \
return {v1.x * v2.x, v1.y * v2.y, v1.z * v2.z, v1.w * v2.w}; \
} \
CHD_FUNC T4 operator*=(T4& v, T scalar) { \
v.x *= scalar; \
v.y *= scalar; \
v.z *= scalar; \
v.w *= scalar; \
return v; \
} \
CHD_FUNC T4 operator*=(T4& v, T4 const& v2) { \
v.x *= v2.x; \
v.y *= v2.y; \
v.z *= v2.z; \
v.w *= v2.w; \
return v; \
} \
CHD_FUNC T4 operator/(T4 const& v, T scalar) { \
return {v.x / scalar, v.y / scalar, v.z / scalar, v.w / scalar}; \
} \
CHD_FUNC T4 operator/(T scalar, T4 const& v) { \
return {scalar / v.x, scalar / v.y, scalar / v.z, scalar / v.w}; \
} \
CHD_FUNC T4 operator/(T4 const& v1, T4 const& v2) { \
return {v1.x / v2.x, v1.y / v2.y, v1.z / v2.z, v1.w / v2.w}; \
} \
CHD_FUNC T4 operator/=(T4& v, T scalar) { \
v.x /= scalar; \
v.y /= scalar; \
v.z /= scalar; \
v.w /= scalar; \
return v; \
} \
CHD_FUNC T4 operator/=(T4& v, T4 const& v2) { \
v.x /= v2.x; \
v.y /= v2.y; \
v.z /= v2.z; \
v.w /= v2.w; \
return v; \
}
#define BINARY_INT_OP(T, T2, T3, T4) \
CHD_FUNC T2 operator%(T2 const& v, T scalar) { \
return {(T)(v.x % scalar), (T)(v.y % scalar)}; \
} \
CHD_FUNC T2 operator%(T scalar, T2 const& v) { \
return {(T)(scalar % v.x), (T)(scalar % v.y)}; \
} \
CHD_FUNC T2 operator%(T2 const& v1, T2 const& v2) { \
return {(T)(v1.x % v2.x), (T)(v1.y % v2.y)}; \
} \
CHD_FUNC T3 operator%(T3 const& v, T scalar) { \
return {(T)(v.x % scalar), (T)(v.y % scalar), (T)(v.z % scalar)}; \
} \
CHD_FUNC T3 operator%(T scalar, T3 const& v) { \
return {(T)(scalar % v.x), (T)(scalar % v.y), (T)(scalar % v.z)}; \
} \
CHD_FUNC T3 operator%(T3 const& v1, T3 const& v2) { \
return {(T)(v1.x % v2.x), (T)(v1.y % v2.y), (T)(v1.z % v2.z)}; \
} \
CHD_FUNC T4 operator%(T4 const& v, T scalar) { \
return {(T)(v.x % scalar), (T)(v.y % scalar), (T)(v.z % scalar), (T)(v.w % scalar)}; \
} \
CHD_FUNC T4 operator%(T scalar, T4 const& v) { \
return {(T)(scalar % v.x), (T)(scalar % v.y), (T)(scalar % v.z), (T)(scalar % v.w)}; \
} \
CHD_FUNC T4 operator%(T4 const& v1, T4 const& v2) { \
return {(T)(v1.x % v2.x), (T)(v1.y % v2.y), (T)(v1.z % v2.z), (T)(v1.w % v2.w)}; \
}
// -- Binary bit operators --
#define BINARY_BIT_OP(T, T2, T3, T4) \
CHD_FUNC T2 operator&(T2 const& v, T scalar) { \
return {v.x & scalar, v.y & scalar}; \
} \
CHD_FUNC T2 operator&(T scalar, T2 const& v) { \
return {scalar & v.x, scalar & v.y}; \
} \
CHD_FUNC T2 operator&(T2 const& v1, T2 const& v2) { \
return {v1.x & v2.x, v1.y & v2.y}; \
} \
CHD_FUNC T2 operator|(T2 const& v, T scalar) { \
return {v.x | scalar, v.y | scalar}; \
} \
CHD_FUNC T2 operator|(T scalar, T2 const& v) { \
return {scalar | v.x, scalar | v.y}; \
} \
CHD_FUNC T2 operator|(T2 const& v1, T2 const& v2) { \
return {v1.x | v2.x, v1.y | v2.y}; \
} \
CHD_FUNC T2 operator^(T2 const& v, T scalar) { \
return {v.x ^ scalar, v.y ^ scalar}; \
} \
CHD_FUNC T2 operator^(T scalar, T2 const& v) { \
return {scalar ^ v.x, scalar ^ v.y}; \
} \
CHD_FUNC T2 operator^(T2 const& v1, T2 const& v2) { \
return {v1.x ^ v2.x, v1.y ^ v2.y}; \
} \
CHD_FUNC T2 operator<<(T2 const& v, T scalar) { \
return {v.x << scalar, v.y << scalar}; \
} \
CHD_FUNC T2 operator<<(T scalar, T2 const& v) { \
return {scalar << v.x, scalar << v.y}; \
} \
CHD_FUNC T2 operator<<(T2 const& v1, T2 const& v2) { \
return {v1.x << v2.x, v1.y << v2.y}; \
} \
CHD_FUNC T2 operator>>(T2 const& v, T scalar) { \
return {v.x >> scalar, v.y >> scalar}; \
} \
CHD_FUNC T2 operator>>(T scalar, T2 const& v) { \
return {scalar >> v.x, scalar >> v.y}; \
} \
CHD_FUNC T2 operator>>(T2 const& v1, T2 const& v2) { \
return {v1.x >> v2.x, v1.y >> v2.y}; \
} \
CHD_FUNC T2 operator~(T2 const& v) { \
return {~v.x, ~v.y}; \
} \
CHD_FUNC T3 operator&(T3 const& v, T scalar) { \
return {v.x & scalar, v.y & scalar, v.z & scalar}; \
} \
CHD_FUNC T3 operator&(T scalar, T3 const& v) { \
return {scalar & v.x, scalar & v.y, scalar & v.z}; \
} \
CHD_FUNC T3 operator&(T3 const& v1, T3 const& v2) { \
return {v1.x & v2.x, v1.y & v2.y, v1.z & v2.z}; \
} \
CHD_FUNC T3 operator|(T3 const& v, T scalar) { \
return {v.x | scalar, v.y | scalar, v.z | scalar}; \
} \
CHD_FUNC T3 operator|(T scalar, T3 const& v) { \
return {scalar | v.x, scalar | v.y, scalar | v.z}; \
} \
CHD_FUNC T3 operator|(T3 const& v1, T3 const& v2) { \
return {v1.x | v2.x, v1.y | v2.y, v1.z | v2.z}; \
} \
CHD_FUNC T3 operator^(T3 const& v, T scalar) { \
return {v.x ^ scalar, v.y ^ scalar, v.z ^ scalar}; \
} \
CHD_FUNC T3 operator^(T scalar, T3 const& v) { \
return {scalar ^ v.x, scalar ^ v.y, scalar ^ v.z}; \
} \
CHD_FUNC T3 operator^(T3 const& v1, T3 const& v2) { \
return {v1.x ^ v2.x, v1.y ^ v2.y, v1.z ^ v2.z}; \
} \
CHD_FUNC T3 operator<<(T3 const& v, T scalar) { \
return {v.x << scalar, v.y << scalar, v.z << scalar}; \
} \
CHD_FUNC T3 operator<<(T scalar, T3 const& v) { \
return {scalar << v.x, scalar << v.y, scalar << v.z}; \
} \
CHD_FUNC T3 operator<<(T3 const& v1, T3 const& v2) { \
return {v1.x << v2.x, v1.y << v2.y, v1.z << v2.z}; \
} \
CHD_FUNC T3 operator>>(T3 const& v, T scalar) { \
return {v.x >> scalar, v.y >> scalar, v.z >> scalar}; \
} \
CHD_FUNC T3 operator>>(T scalar, T3 const& v) { \
return {scalar >> v.x, scalar >> v.y, scalar >> v.z}; \
} \
CHD_FUNC T3 operator>>(T3 const& v1, T3 const& v2) { \
return {v1.x >> v2.x, v1.y >> v2.y, v1.z >> v2.z}; \
} \
CHD_FUNC T3 operator~(T3 const& v) { \
return {~v.x, ~v.y, ~v.z}; \
} \
CHD_FUNC T4 operator&(T4 const& v, T scalar) { \
return {v.x & scalar, v.y & scalar, v.z & scalar, v.w & scalar}; \
} \
CHD_FUNC T4 operator&(T scalar, T4 const& v) { \
return {scalar & v.x, scalar & v.y, scalar & v.z, scalar & v.w}; \
} \
CHD_FUNC T4 operator&(T4 const& v1, T4 const& v2) { \
return {v1.x & v2.x, v1.y & v2.y, v1.z & v2.z, v1.w & v2.w}; \
} \
CHD_FUNC T4 operator|(T4 const& v, T scalar) { \
return {v.x | scalar, v.y | scalar, v.z | scalar, v.w | scalar}; \
} \
CHD_FUNC T4 operator|(T scalar, T4 const& v) { \
return {scalar | v.x, scalar | v.y, scalar | v.z, scalar | v.w}; \
} \
CHD_FUNC T4 operator|(T4 const& v1, T4 const& v2) { \
return {v1.x | v2.x, v1.y | v2.y, v1.z | v2.z, v1.w | v2.w}; \
} \
CHD_FUNC T4 operator^(T4 const& v, T scalar) { \
return {v.x ^ scalar, v.y ^ scalar, v.z ^ scalar, v.w ^ scalar}; \
} \
CHD_FUNC T4 operator^(T scalar, T4 const& v) { \
return {scalar ^ v.x, scalar ^ v.y, scalar ^ v.z, scalar ^ v.w}; \
} \
CHD_FUNC T4 operator^(T4 const& v1, T4 const& v2) { \
return {v1.x ^ v2.x, v1.y ^ v2.y, v1.z ^ v2.z, v1.w ^ v2.w}; \
} \
CHD_FUNC T4 operator<<(T4 const& v, T scalar) { \
return {v.x << scalar, v.y << scalar, v.z << scalar, v.w << scalar}; \
} \
CHD_FUNC T4 operator<<(T scalar, T4 const& v) { \
return {scalar << v.x, scalar << v.y, scalar << v.z, scalar << v.w}; \
} \
CHD_FUNC T4 operator<<(T4 const& v1, T4 const& v2) { \
return {v1.x << v2.x, v1.y << v2.y, v1.z << v2.z, v1.w << v2.w}; \
} \
CHD_FUNC T4 operator>>(T4 const& v, T scalar) { \
return {v.x >> scalar, v.y >> scalar, v.z >> scalar, v.w >> scalar}; \
} \
CHD_FUNC T4 operator>>(T scalar, T4 const& v) { \
return {scalar >> v.x, scalar >> v.y, scalar >> v.z, scalar >> v.w}; \
} \
CHD_FUNC T4 operator>>(T4 const& v1, T4 const& v2) { \
return {v1.x >> v2.x, v1.y >> v2.y, v1.z >> v2.z, v1.w >> v2.w}; \
} \
CHD_FUNC T4 operator~(T4 const& v) { \
return {~v.x, ~v.y, ~v.z, ~v.w}; \
}
#define BINARY_EQ_OP(T, T2, T3, T4) \
CHD_FUNC bool operator==(T2 const& v1, T2 const& v2) { \
return v1.x == v2.x && v1.y == v2.y; \
} \
CHD_FUNC bool operator!=(T2 const& v1, T2 const& v2) { \
return !(v1 == v2); \
} \
CHD_FUNC bool operator==(T3 const& v1, T3 const& v2) { \
return v1.x == v2.x && v1.y == v2.y && v1.z == v2.z; \
} \
CHD_FUNC bool operator!=(T3 const& v1, T3 const& v2) { \
return !(v1 == v2); \
} \
CHD_FUNC bool operator==(T4 const& v1, T4 const& v2) { \
return v1.x == v2.x && v1.y == v2.y && v1.z == v2.z && v1.w == v2.w; \
} \
CHD_FUNC bool operator!=(T4 const& v1, T4 const& v2) { \
return !(v1 == v2); \
}
// These apply for all types
#define OTHER_FUNC_ALL(T, T2, T3, T4) \
CHD_FUNC bool all_less(T2 const& v1, T2 const& v2) { \
return (v1.x < v2.x) && (v1.y < v2.y); \
} \
CHD_FUNC bool all_less_or_eq(T2 const& v1, T2 const& v2) { \
return (v1.x <= v2.x) && (v1.y <= v2.y); \
} \
CHD_FUNC bool all_greater(T2 const& v1, T2 const& v2) { \
return (v1.x > v2.x) && (v1.y > v2.y); \
} \
CHD_FUNC bool all_greater_or_eq(T2 const& v1, T2 const& v2) { \
return (v1.x >= v2.x) && (v1.y >= v2.y); \
} \
CHD_FUNC bool all_less(T3 const& v1, T3 const& v2) { \
return (v1.x < v2.x) && (v1.y < v2.y) && (v1.z < v2.z); \
} \
CHD_FUNC bool all_less_or_eq(T3 const& v1, T3 const& v2) { \
return (v1.x <= v2.x) && (v1.y <= v2.y) && (v1.z <= v2.z); \
} \
CHD_FUNC bool all_greater(T3 const& v1, T3 const& v2) { \
return (v1.x > v2.x) && (v1.y > v2.y) && (v1.z > v2.z); \
} \
CHD_FUNC bool all_greater_or_eq(T3 const& v1, T3 const& v2) { \
return (v1.x >= v2.x) && (v1.y >= v2.y) && (v1.z >= v2.z); \
} \
CHD_FUNC bool all_less(T4 const& v1, T4 const& v2) { \
return (v1.x < v2.x) && (v1.y < v2.y) && (v1.z < v2.z) && (v1.w < v2.w); \
} \
CHD_FUNC bool all_less_or_eq(T4 const& v1, T4 const& v2) { \
return (v1.x <= v2.x) && (v1.y <= v2.y) && (v1.z <= v2.z) && (v1.w <= v2.w); \
} \
CHD_FUNC bool all_greater(T4 const& v1, T4 const& v2) { \
return (v1.x > v2.x) && (v1.y > v2.y) && (v1.z > v2.z) && (v1.w > v2.w); \
} \
CHD_FUNC bool all_greater_or_eq(T4 const& v1, T4 const& v2) { \
return (v1.x >= v2.x) && (v1.y >= v2.y) && (v1.z >= v2.z) && (v1.w >= v2.w); \
} \
CHD_FUNC bool any_less(T2 const& v1, T2 const& v2) { \
return (v1.x < v2.x) || (v1.y < v2.y); \
} \
CHD_FUNC bool any_less_or_eq(T2 const& v1, T2 const& v2) { \
return (v1.x <= v2.x) || (v1.y <= v2.y); \
} \
CHD_FUNC bool any_greater(T2 const& v1, T2 const& v2) { \
return (v1.x > v2.x) || (v1.y > v2.y); \
} \
CHD_FUNC bool any_greater_or_eq(T2 const& v1, T2 const& v2) { \
return (v1.x >= v2.x) || (v1.y >= v2.y); \
} \
CHD_FUNC bool any_less(T3 const& v1, T3 const& v2) { \
return (v1.x < v2.x) || (v1.y < v2.y) || (v1.z < v2.z); \
} \
CHD_FUNC bool any_less_or_eq(T3 const& v1, T3 const& v2) { \
return (v1.x <= v2.x) || (v1.y <= v2.y) || (v1.z <= v2.z); \
} \
CHD_FUNC bool any_greater(T3 const& v1, T3 const& v2) { \
return (v1.x > v2.x) || (v1.y > v2.y) || (v1.z > v2.z); \
} \
CHD_FUNC bool any_greater_or_eq(T3 const& v1, T3 const& v2) { \
return (v1.x >= v2.x) || (v1.y >= v2.y) || (v1.z >= v2.z); \
} \
CHD_FUNC bool any_less(T4 const& v1, T4 const& v2) { \
return (v1.x < v2.x) || (v1.y < v2.y) || (v1.z < v2.z) || (v1.w < v2.w); \
} \
CHD_FUNC bool any_less_or_eq(T4 const& v1, T4 const& v2) { \
return (v1.x <= v2.x) || (v1.y <= v2.y) || (v1.z <= v2.z) || (v1.w <= v2.w); \
} \
CHD_FUNC bool any_greater(T4 const& v1, T4 const& v2) { \
return (v1.x > v2.x) || (v1.y > v2.y) || (v1.z > v2.z) || (v1.w > v2.w); \
} \
CHD_FUNC bool any_greater_or_eq(T4 const& v1, T4 const& v2) { \
return (v1.x >= v2.x) || (v1.y >= v2.y) || (v1.z >= v2.z) || (v1.w >= v2.w); \
} \
HD_FUNC T2 max(T2 const& v1, T const& v2) { \
return {max(v1.x, v2), max(v1.y, v2)}; \
} \
HD_FUNC T2 max(T2 const& v1, T2 const& v2) { \
return {max(v1.x, v2.x), max(v1.y, v2.y)}; \
} \
HD_FUNC T2 min(T2 const& v1, T const& v2) { \
return {min(v1.x, v2), min(v1.y, v2)}; \
} \
HD_FUNC T2 min(T2 const& v1, T2 const& v2) { \
return {min(v1.x, v2.x), min(v1.y, v2.y)}; \
} \
HD_FUNC T3 max(T3 const& v1, T const& v2) { \
return {max(v1.x, v2), max(v1.y, v2), max(v1.z, v2)}; \
} \
HD_FUNC T3 max(T3 const& v1, T3 const& v2) { \
return {max(v1.x, v2.x), max(v1.y, v2.y), max(v1.z, v2.z)}; \
} \
HD_FUNC T3 min(T3 const& v1, T const& v2) { \
return {min(v1.x, v2), min(v1.y, v2), min(v1.z, v2)}; \
} \
HD_FUNC T3 min(T3 const& v1, T3 const& v2) { \
return {min(v1.x, v2.x), min(v1.y, v2.y), min(v1.z, v2.z)}; \
} \
HD_FUNC T4 max(T4 const& v1, T const& v2) { \
return {max(v1.x, v2), max(v1.y, v2), max(v1.z, v2), max(v1.w, v2)}; \
} \
HD_FUNC T4 max(T4 const& v1, T4 const& v2) { \
return {max(v1.x, v2.x), max(v1.y, v2.y), max(v1.z, v2.z), max(v1.w, v2.w)}; \
} \
HD_FUNC T4 min(T4 const& v1, T const& v2) { \
return {min(v1.x, v2), min(v1.y, v2), min(v1.z, v2), min(v1.w, v2)}; \
} \
HD_FUNC T4 min(T4 const& v1, T4 const& v2) { \
return {min(v1.x, v2.x), min(v1.y, v2.y), min(v1.z, v2.z), min(v1.w, v2.w)}; \
} \
HD_FUNC T clamp(T v, T _min, T _max) { \
return min(max(v, _min), _max); \
} \
HD_FUNC T2 clamp(T2 v, T2 _min, T2 _max) { \
return min(max(v, _min), _max); \
} \
HD_FUNC T3 clamp(T3 v, T3 _min, T3 _max) { \
return min(max(v, _min), _max); \
} \
HD_FUNC T4 clamp(T4 v, T4 _min, T4 _max) { \
return min(max(v, _min), _max); \
} \
HD_FUNC T2 clamp(T2 v, T _min, T _max) { \
return min(max(v, _min), _max); \
} \
HD_FUNC T3 clamp(T3 v, T _min, T _max) { \
return min(max(v, _min), _max); \
} \
HD_FUNC T4 clamp(T4 v, T _min, T _max) { \
return min(max(v, _min), _max); \
} \
CHD_FUNC T mix(T v1, T v2, bool a) { \
return a ? v2 : v1; \
} \
CHD_FUNC T2 mix(T2 v1, T2 v2, bool a) { \
return a ? v2 : v1; \
} \
CHD_FUNC T3 mix(T3 v1, T3 v2, bool a) { \
return a ? v2 : v1; \
} \
CHD_FUNC T4 mix(T4 v1, T4 v2, bool a) { \
return a ? v2 : v1; \
}
// These apply for all types, but unsigned ones
#define ABS_FUNC(T, T2, T3, T4) \
HD_FUNC T2 abs(T2 const& v) { \
return {abs(v.x), abs(v.y)}; \
} \
HD_FUNC T3 abs(T3 const& v) { \
return {abs(v.x), abs(v.y), abs(v.z)}; \
} \
HD_FUNC T4 abs(T4 const& v) { \
return {abs(v.x), abs(v.y), abs(v.z), abs(v.w)}; \
}
// Make functions
#define MAKE_FUNC(T, T2, T3, T4) \
HD_FUNC T2 make_##T2(T scalar) { \
return {scalar, scalar}; \
} \
HD_FUNC T3 make_##T3(T scalar) { \
return {scalar, scalar, scalar}; \
} \
HD_FUNC T4 make_##T4(T scalar) { \
return {scalar, scalar, scalar, scalar}; \
} \
HD_FUNC T3 make_##T3(T2 const& v, T scalar) { \
return {v.x, v.y, scalar}; \
} \
HD_FUNC T3 make_##T3(T scalar, T2 const& v) { \
return {scalar, v.x, v.y}; \
} \
HD_FUNC T4 make_##T4(T2 const& v1, T2 const& v2) { \
return {v1.x, v1.y, v2.x, v2.y}; \
} \
HD_FUNC T4 make_##T4(T2 const& v, T scalar1, T scalar2) { \
return {v.x, v.y, scalar1, scalar2}; \
} \
HD_FUNC T4 make_##T4(T scalar1, T scalar2, T2 const& v) { \
return {scalar1, scalar2, v.x, v.y}; \
} \
HD_FUNC T4 make_##T4(T scalar1, T2 const& v, T scalar2) { \
return {scalar1, v.x, v.y, scalar2}; \
} \
HD_FUNC T4 make_##T4(T3 const& v, T scalar) { \
return {v.x, v.y, v.z, scalar}; \
} \
HD_FUNC T4 make_##T4(T scalar, T3 const& v) { \
return {scalar, v.x, v.y, v.z}; \
} \
HD_FUNC T2 make_##T2(T3 const& v) { \
return {v.x, v.y}; \
} \
HD_FUNC T2 make_##T2(T4 const& v) { \
return {v.x, v.y}; \
} \
HD_FUNC T3 make_##T3(T4 const& v) { \
return {v.x, v.y, v.z}; \
}
#define OTHER_FUNC_INT(T, T2, T3, T4) \
CHD_FUNC T floor_div(T a, T b) { \
T t = 1 - a / b; \
return (a + t * b) / b - t; \
} \
CHD_FUNC T2 floor_div(T2 const& v1, T2 const& v2) { \
return {floor_div(v1.x, v2.x), floor_div(v1.y, v2.y)}; \
} \
CHD_FUNC T2 floor_div(T2 const& v1, T v2) { \
return {floor_div(v1.x, v2), floor_div(v1.y, v2)}; \
} \
CHD_FUNC T3 floor_div(T3 const& v1, T3 const& v2) { \
return {floor_div(v1.x, v2.x), floor_div(v1.y, v2.y), floor_div(v1.z, v2.z)}; \
} \
CHD_FUNC T3 floor_div(T3 const& v1, T v2) { \
return {floor_div(v1.x, v2), floor_div(v1.y, v2), floor_div(v1.z, v2)}; \
} \
CHD_FUNC T4 floor_div(T4 const& v1, T4 const& v2) { \
return { \
floor_div(v1.x, v2.x), \
floor_div(v1.y, v2.y), \
floor_div(v1.z, v2.z), \
floor_div(v1.w, v2.w)}; \
} \
CHD_FUNC T4 floor_div(T4 const& v1, T v2) { \
return {floor_div(v1.x, v2), floor_div(v1.y, v2), floor_div(v1.z, v2), floor_div(v1.w, v2)}; \
}
#define OTHER_FUNC_FP(T, T2, T3, T4) \
CHD_FUNC T dot(T2 a, T2 b) { \
return a.x * b.x + a.y * b.y; \
} \
CHD_FUNC T dot(T3 a, T3 b) { \
return a.x * b.x + a.y * b.y + a.z * b.z; \
} \
CHD_FUNC T dot(T4 a, T4 b) { \
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w; \
} \
CHD_FUNC T cross(T2 a, T2 b) { \
return a.x * b.y - a.y * b.x; \
} \
CHD_FUNC T3 cross(T3 a, T3 b) { \
return {a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x}; \
} \
HD_FUNC T norm(T2 a) { \
return sqrt(dot(a, a)); \
} \
HD_FUNC T norm(T3 a) { \
return detail::norm3d(a.x, a.y, a.z); \
} \
HD_FUNC T norm(T4 a) { \
return detail::norm4d(a.x, a.y, a.z, a.w); \
} \
HD_FUNC T rnorm(T2 a) { \
return rsqrt(dot(a, a)); \
} \
HD_FUNC T rnorm(T3 a) { \
return detail::rnorm3d(a.x, a.y, a.z); \
} \
HD_FUNC T rnorm(T4 a) { \
return detail::rnorm4d(a.x, a.y, a.z, a.w); \
} \
HD_FUNC T2 normalize(T2 v) { \
T invLen = rnorm(v); \
return v * invLen; \
} \
HD_FUNC T3 normalize(T3 v) { \
T invLen = rnorm(v); \
return v * invLen; \
} \
HD_FUNC T4 normalize(T4 v) { \
T invLen = rnorm(v); \
return v * invLen; \
} \
HD_FUNC T2 saturate(T2 v) { \
return {saturate(v.x), saturate(v.y)}; \
} \
HD_FUNC T3 saturate(T3 v) { \
return {saturate(v.x), saturate(v.y), saturate(v.z)}; \
} \
HD_FUNC T4 saturate(T4 v) { \
return {saturate(v.x), saturate(v.y), saturate(v.z), saturate(v.w)}; \
} \
CHD_FUNC T sign(T v) { \
return v > 0 ? 1 : (v < 0 ? -1 : 0); \
} \
CHD_FUNC T2 sign(T2 v) { \
return {sign(v.x), sign(v.y)}; \
} \
CHD_FUNC T3 sign(T3 v) { \
return {sign(v.x), sign(v.y), sign(v.z)}; \
} \
CHD_FUNC T4 sign(T4 v) { \
return {sign(v.x), sign(v.y), sign(v.z), sign(v.w)}; \
} \
CHD_FUNC T mix(T v1, T v2, T a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T2 mix(T2 v1, T2 v2, T a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T3 mix(T3 v1, T3 v2, T a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T4 mix(T4 v1, T4 v2, T a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T2 mix(T2 v1, T2 v2, T2 a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T3 mix(T3 v1, T3 v2, T3 a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T4 mix(T4 v1, T4 v2, T4 a) { \
return v1 * (T(1.0) - a) + v2 * a; \
} \
CHD_FUNC T sum(T2 const& v) { \
return v.x + v.y; \
} \
CHD_FUNC T sum(T3 const& v) { \
return v.x + v.y + v.z; \
} \
CHD_FUNC T sum(T4 const& v) { \
return v.x + v.y + v.z + v.w; \
} \
HD_FUNC T epsclamp(T v, T eps) { \
return (v < 0) ? min(v, -eps) : max(v, eps); \
} \
HD_FUNC T epsclamp(T v) { \
return epsclamp(v, epsilon<T>::value); \
} \
HD_FUNC T2 epsclamp(T2 v) { \
return {epsclamp(v.x), epsclamp(v.y)}; \
} \
HD_FUNC T3 epsclamp(T3 v) { \
return {epsclamp(v.x), epsclamp(v.y), epsclamp(v.z)}; \
} \
HD_FUNC T4 epsclamp(T4 v) { \
return {epsclamp(v.x), epsclamp(v.y), epsclamp(v.z), epsclamp(v.w)}; \
} \
HD_FUNC T2 epsclamp(T2 v, T eps) { \
return {epsclamp(v.x, eps), epsclamp(v.y, eps)}; \
} \
HD_FUNC T3 epsclamp(T3 v, T eps) { \
return {epsclamp(v.x, eps), epsclamp(v.y, eps), epsclamp(v.z, eps)}; \
} \
HD_FUNC T4 epsclamp(T4 v, T eps) { \
return {epsclamp(v.x, eps), epsclamp(v.y, eps), epsclamp(v.z, eps), epsclamp(v.w, eps)}; \
} \
CHD_FUNC void inverse(const T2(&m)[2], T2(&out)[2]) { \
T det_m = T(1.0) / (m[0].x * m[1].y - m[0].y * m[1].x); \
out[0] = det_m * T2({m[1].y, -m[0].y}); \
out[1] = det_m * T2({-m[1].x, m[0].x}); \
} \
CHD_FUNC void inverse(const T3(&m)[3], T3(&out)[3]) { \
T det_m = T(1.0) / \
(+m[0].x * (m[1].y * m[2].z - m[1].z * m[2].y) - \
m[0].y * (m[1].x * m[2].z - m[1].z * m[2].x) + \
m[0].z * (m[1].x * m[2].y - m[1].y * m[2].x)); \
out[0] = det_m * \
T3({ \
+(m[1].y * m[2].z - m[2].y * m[1].z), \
-(m[0].y * m[2].z - m[2].y * m[0].z), \
+(m[0].y * m[1].z - m[1].y * m[0].z), \
}); \
out[1] = det_m * \
T3({ \
-(m[1].x * m[2].z - m[2].x * m[1].z), \
+(m[0].x * m[2].z - m[2].x * m[0].z), \
-(m[0].x * m[1].z - m[1].x * m[0].z), \
}); \
out[2] = det_m * \
T3({ \
+(m[1].x * m[2].y - m[2].x * m[1].y), \
-(m[0].x * m[2].y - m[2].x * m[0].y), \
+(m[0].x * m[1].y - m[1].x * m[0].y), \
}); \
} \
CHD_FUNC T2 mul(const T2(&r)[2], T2 v) { \
return T2({dot(r[0], v), dot(r[1], v)}); \
} \
CHD_FUNC T3 mul(const T3(&r)[3], T3 v) { \
return T3({dot(r[0], v), dot(r[1], v), dot(r[2], v)}); \
} \
CHD_FUNC T4 mul(const T4(&r)[4], T4 v) { \
return T4({dot(r[0], v), dot(r[1], v), dot(r[2], v), dot(r[3], v)}); \
} \
CHD_FUNC void mul(const T2(&a)[2], const T2(&b)[2], T2(&out)[2]) { \
out[0] = T2({dot(a[0], T2({b[0].x, b[1].x})), dot(a[0], T2({b[0].y, b[1].y}))}); \
out[1] = T2({dot(a[1], T2({b[0].x, b[1].x})), dot(a[1], T2({b[0].y, b[1].y}))}); \
} \
CHD_FUNC void mul(const T2(&a)[2], const T3(&b)[2], T3(&out)[2]) { \
out[0] = \
T3({dot(a[0], T2({b[0].x, b[1].x})), \
dot(a[0], T2({b[0].y, b[1].y})), \
dot(a[0], T2({b[0].z, b[1].z}))}); \
out[1] = \
T3({dot(a[1], T2({b[0].x, b[1].x})), \
dot(a[1], T2({b[0].y, b[1].y})), \
dot(a[1], T2({b[0].z, b[1].z}))}); \
} \
CHD_FUNC void mul(const T3(&a)[3], const T3(&b)[3], T3(&out)[3]) { \
out[0] = \
T3({dot(a[0], T3({b[0].x, b[1].x, b[2].x})), \
dot(a[0], T3({b[0].y, b[1].y, b[2].y})), \
dot(a[0], T3({b[0].z, b[1].z, b[2].z}))}); \
out[1] = \
T3({dot(a[1], T3({b[0].x, b[1].x, b[2].x})), \
dot(a[1], T3({b[0].y, b[1].y, b[2].y})), \
dot(a[1], T3({b[0].z, b[1].z, b[2].z}))}); \
out[2] = \
T3({dot(a[2], T3({b[0].x, b[1].x, b[2].x})), \
dot(a[2], T3({b[0].y, b[1].y, b[2].y})), \
dot(a[2], T3({b[0].z, b[1].z, b[2].z}))}); \
}
#define DEFINE_FUNC_FOR_UNSIGNED_INT(T, T2, T3, T4) \
UNARY_OP(T, T2, T3, T4) \
BINARY_ARITHM_OP(T, T2, T3, T4) \
BINARY_BIT_OP(T, T2, T3, T4) \
BINARY_EQ_OP(T, T2, T3, T4) \
BINARY_INT_OP(T, T2, T3, T4) \
OTHER_FUNC_ALL(T, T2, T3, T4) \
OTHER_FUNC_INT(T, T2, T3, T4) \
MAKE_FUNC(T, T2, T3, T4)
#define DEFINE_FUNC_FOR_SIGNED_INT(T, T2, T3, T4) \
DEFINE_FUNC_FOR_UNSIGNED_INT(T, T2, T3, T4) \
ABS_FUNC(T, T2, T3, T4)
#define DEFINE_FUNC_FOR_FLOAT(T, T2, T3, T4) \
UNARY_OP(T, T2, T3, T4) \
BINARY_ARITHM_OP(T, T2, T3, T4) \
BINARY_EQ_OP(T, T2, T3, T4) \
OTHER_FUNC_ALL(T, T2, T3, T4) \
OTHER_FUNC_FP(T, T2, T3, T4) \
ABS_FUNC(T, T2, T3, T4) \
MAKE_FUNC(T, T2, T3, T4)
DEFINE_FUNC_FOR_UNSIGNED_INT(unsigned int, uint2, uint3, uint4);
DEFINE_FUNC_FOR_SIGNED_INT(int, int2, int3, int4);
DEFINE_FUNC_FOR_FLOAT(float, float2, float3, float4);
DEFINE_FUNC_FOR_FLOAT(double, double2, double3, double4);
namespace detail {
template <typename scalar_t>
struct VecType;
}
// Type inference utils for writing templates
//
// Derive vector type given the scalar type:
// math::TVec2<float> a; // `a` is of type `float2`
// math::TVec3<int> b; // `b` is of type `int3`;
//
// Derive vector type given the vector size and scalar type:
// math::TVec<double, 4> c; // `c` is of type `double4`;
template <typename scalar_t>
using TVec1 = typename detail::VecType<scalar_t>::scalar1_t;
template <typename scalar_t>
using TVec2 = typename detail::VecType<scalar_t>::scalar2_t;
template <typename scalar_t>
using TVec3 = typename detail::VecType<scalar_t>::scalar3_t;
template <typename scalar_t>
using TVec4 = typename detail::VecType<scalar_t>::scalar4_t;
template <typename scalar_t, int D>
using TVec = typename detail::VecType<scalar_t>::template dim<D>::type;
namespace detail {
template <int D, template <typename scalar_t> class Vec, typename scalar_t>
struct VecD;
template <template <typename scalar_t> class Vec, typename scalar_t>
struct VecD<1, Vec, scalar_t> {
typedef typename Vec<scalar_t>::scalar1_t type;
};
template <template <typename scalar_t> class Vec, typename scalar_t>
struct VecD<2, Vec, scalar_t> {
typedef typename Vec<scalar_t>::scalar2_t type;
};
template <template <typename scalar_t> class Vec, typename scalar_t>
struct VecD<3, Vec, scalar_t> {
typedef typename Vec<scalar_t>::scalar3_t type;
};
template <template <typename scalar_t> class Vec, typename scalar_t>
struct VecD<4, Vec, scalar_t> {
typedef typename Vec<scalar_t>::scalar4_t type;
};
#define MH_TYPE_DECLARATION(TYPE, NAME) \
template <> \
struct VecType<TYPE> { \
typedef TYPE scalar_t; \
typedef NAME##1 scalar1_t; \
typedef NAME##2 scalar2_t; \
typedef NAME##3 scalar3_t; \
typedef NAME##4 scalar4_t; \
template <int D> \
struct dim { \
typedef typename detail::VecD<D, VecType, scalar_t>::type type; \
}; \
};
MH_TYPE_DECLARATION(float, float)
MH_TYPE_DECLARATION(double, double)
MH_TYPE_DECLARATION(char, char)
MH_TYPE_DECLARATION(unsigned char, uchar)
MH_TYPE_DECLARATION(short, short)
MH_TYPE_DECLARATION(unsigned short, ushort)
MH_TYPE_DECLARATION(int, int)
MH_TYPE_DECLARATION(unsigned int, uint)
} // namespace detail
} // namespace math
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/native/cuda/GridSampler.cuh>
#include <ATen/native/cuda/UpSample.cuh>
enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic };
using at::native::clip_coordinates;
using at::native::cubic_interp1d;
using at::native::fastAtomicAdd;
using at::native::get_cubic_upsampling_coefficients;
using at::native::grid_sampler_compute_source_index;
using at::native::grid_sampler_compute_source_index_set_grad;
using at::native::grid_sampler_unnormalize;
using at::native::grid_sampler_unnormalize_set_grad;
using at::native::reflect_coordinates;
using at::native::safe_downgrade_to_int_range;
using at::native::within_bounds_2d;
using at::native::detail::GridSamplerPadding;
template <typename scalar_t>
static __forceinline__ __device__ scalar_t
area_pixel_compute_source_index(scalar_t scale, int64_t dst_index, bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
scalar_t src_idx = scale * (dst_index + 0.5) - 0.5;
return (src_idx < 0) ? scalar_t(0) : src_idx;
}
}
template <typename scalar_t>
static __forceinline__ __device__ scalar_t
area_pixel_compute_scale(int64_t input_size, int64_t output_size, bool align_corners) {
// see Note [area_pixel_compute_scale]
if (align_corners) {
if (output_size > 1) {
return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
} else {
return static_cast<scalar_t>(0);
}
} else {
return static_cast<scalar_t>(input_size) / output_size;
}
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void safe_add_2d(
scalar_t* data,
int h,
int w,
int sH,
int sW,
int H,
int W,
scalar_t delta,
const index_t NC_offset,
const index_t memory_span) {
if (within_bounds_2d(h, w, H, W)) {
fastAtomicAdd(data, NC_offset + h * sH + w * sW, memory_span, delta, true);
}
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void add_2d(
scalar_t* data,
int h,
int w,
int sH,
int sW,
scalar_t delta,
const index_t NC_offset,
const index_t memory_span) {
fastAtomicAdd(data, NC_offset + h * sH + w * sW, memory_span, delta, true);
}
template <typename scalar_t>
static __forceinline__ __device__ scalar_t
compute_coordinates(scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) {
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2 * (size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2 * size - 1);
}
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
}
coord = safe_downgrade_to_int_range(coord);
return coord;
}
template <typename scalar_t>
static __forceinline__ __device__ scalar_t get_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int W,
int H,
int sW,
int sH,
GridSamplerPadding padding_mode,
bool align_corners) {
x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);
int ix = static_cast<int>(x);
int iy = static_cast<int>(y);
if (within_bounds_2d(iy, ix, H, W)) {
return data[iy * sH + ix * sW];
}
return static_cast<scalar_t>(0);
}
// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
template <typename scalar_t>
static __forceinline__ __device__ void get_cubic_coefficients_grad(scalar_t coeffs[4], scalar_t t) {
// Must be the same as forward calculation in
// aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
scalar_t A = -0.75;
scalar_t x;
x = -1 - t; // 1 < x = |-1 - tx| < 2
coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A;
x = -t; // x = |0 - tx| <= 1
coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 1 - t; // x = |1 - tx| <= 1
coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 2 - t; // 1 < x = |2 - tx| < 2
coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void add_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int W,
int H,
int sW,
int sH,
scalar_t delta,
GridSamplerPadding padding_mode,
bool align_corners,
const index_t NC_offset,
const index_t memory_span) {
x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);
int ix = static_cast<int>(x);
int iy = static_cast<int>(y);
safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
}
template <typename scalar_t, int D>
__device__ __forceinline__ static math::TVec<scalar_t, D> cubic_interp1d(
math::TVec<scalar_t, D> x0,
math::TVec<scalar_t, D> x1,
math::TVec<scalar_t, D> x2,
math::TVec<scalar_t, D> x3,
scalar_t t) {
scalar_t coeffs[4];
get_cubic_upsampling_coefficients<scalar_t>(coeffs, t);
using namespace math;
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename scalar_t, typename index_t>
inline __device__ typename math::TVec4<scalar_t> load4(scalar_t* ptr, index_t stride) {
return {ptr[0 * stride], ptr[1 * stride], ptr[2 * stride], ptr[3 * stride]};
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void safe_add_2d4(
scalar_t* data,
index_t stride,
int h,
int w,
int sH,
int sW,
int H,
int W,
math::TVec4<scalar_t> delta,
const index_t N_offset,
const index_t memory_span) {
if (within_bounds_2d(h, w, H, W)) {
auto ptr = N_offset + h * sH + w * sW;
fastAtomicAdd(data, ptr + 0 * stride, memory_span, delta.x, true);
fastAtomicAdd(data, ptr + 1 * stride, memory_span, delta.y, true);
fastAtomicAdd(data, ptr + 2 * stride, memory_span, delta.z, true);
fastAtomicAdd(data, ptr + 3 * stride, memory_span, delta.w, true);
}
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
using at::cuda::detail::getTensorInfo;
using at::cuda::detail::TensorInfo;
#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUDA_NUM_THREADS) {
TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
constexpr int64_t max_int = std::numeric_limits<int>::max();
// Round up division for positive number that cannot cause integer overflow
auto block_num = (N - 1) / max_threads_per_block + 1;
TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
return static_cast<int>(block_num);
}
// Dispatch macroses are updated in current pytorch.
// Which causes that the same code is compilable on DGX with the older pytorch
// but no longer compilable on prod
// Thus keeping these macroses here.
#undef AT_PRIVATE_CASE_TYPE
#undef AT_DISPATCH_FLOATING_TYPES
#undef AT_DISPATCH_FLOATING_TYPES_AND_HALF
#undef DISPATCH_FLOAT_AND_HALF
#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
case enum_type: { \
using scalar_t = type; \
return __VA_ARGS__(); \
}
// Dispatches for float and double
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
// Dispatches for float, double, and half
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
// Dispatches for float, double, and half
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, __VA_ARGS__)
// A simple stub to match the dispathcing for multimple types structure, but only for
// for float.
#define DISPATCH_FLOAT(NAME, ...) \
[&] { \
using scalar_t = float; \
return __VA_ARGS__(); \
}()
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
using at::cuda::detail::getTensorInfo;
using at::cuda::detail::TensorInfo;
// TensorInfoCompact is similar to TensorInfo but has fixed number of dims same as
// PackedTensorAccessor. It is supposed to be used on for CUDA `Tensor`s on the host when default
// constructor, assignment and copy constructors are needed, e.g. using in arrays in order to
// transfer them on the device when calling kernels. TensorInfo has a default, assignment and copy
// constructors, but PackedTensorAccessor does not. However TensorInfo is too large to be
// transferred in arrays when calling kernels. On the device, indexing of multidimensional tensors
// produces `TensorAccessor`s. Using RestrictPtrTraits as a default. If aliasing is possible (likely
// to be a very rare case) please use DefaultPtrTraits. Default constructor, assignment and copy
// constructors are only needed on the host aren't available on the device
template <
typename T,
typename index_t,
int N_DIMS,
template <typename> class PtrTraits = at::RestrictPtrTraits>
struct TensorInfoCompact {
typedef typename PtrTraits<T>::PtrType PtrType;
TensorInfoCompact(){};
__host__ TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& operator=(
const TensorInfoCompact<T, index_t, N_DIMS>& other) {
data = other.data;
for (int i = 0; i < N_DIMS; ++i) {
sizes[i] = other.sizes[i];
strides[i] = other.strides[i];
}
return *this;
};
__host__ TensorInfoCompact(const TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& other)
: data(other.data) {
for (int i = 0; i < N_DIMS; ++i) {
sizes[i] = other.sizes[i];
strides[i] = other.strides[i];
}
};
__host__ TensorInfoCompact(const TensorInfo<T, index_t>& other) : data(other.data) {
for (int i = 0; i < N_DIMS; ++i) {
sizes[i] = other.sizes[i];
strides[i] = other.strides[i];
}
}
__device__ at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t> operator[](index_t i) {
index_t* new_sizes = sizes + 1;
index_t* new_strides = strides + 1;
return at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t>(
data + strides[0] * i, new_sizes, new_strides);
}
__device__ const at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t> operator[](
index_t i) const {
const index_t* new_sizes = sizes + 1;
const index_t* new_strides = strides + 1;
return at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t>(
data + strides[0] * i, new_sizes, new_strides);
}
PtrType data;
index_t sizes[N_DIMS];
index_t strides[N_DIMS];
};
template <
typename scalar_t,
typename index_t,
int N_DIMS,
template <typename> class PtrTraits = at::RestrictPtrTraits>
TensorInfoCompact<scalar_t, index_t, N_DIMS, PtrTraits> getTensorInfoCompact(const at::Tensor& x) {
auto out = getTensorInfo<scalar_t, index_t>(x);
assert(out.dims == N_DIMS);
return out;
}
template <
typename T,
typename index_t,
int N,
int N_DIMS,
template <typename> class PtrTraits = at::RestrictPtrTraits>
struct TensorInfoList {
__device__ __host__ TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& operator[](int i) {
return data[i];
}
__device__ __host__ const TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& operator[](
int i) const {
return data[i];
}
TensorInfoCompact<T, index_t, N_DIMS, PtrTraits> data[N];
};
template <typename IndexType, int N>
struct IndexList {
__device__ __host__ IndexType& operator[](int i) {
return data[i];
}
__device__ __host__ const IndexType& operator[](int i) const {
return data[i];
}
IndexType data[N] = {0};
};
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <torch/types.h>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <cub/cub.cuh>
#include <kernel_utils.h>
using at::native::fastAtomicAdd;
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void interpolate_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> vert_attributes,
TensorInfo<int32_t, index_t> vi,
TensorInfo<int32_t, index_t> index_img,
TensorInfo<scalar_t, index_t> bary_img,
TensorInfo<scalar_t, index_t> out_img) {
const index_t C = vert_attributes.sizes[2];
const index_t H = bary_img.sizes[2];
const index_t W = bary_img.sizes[3];
const index_t vert_attributes_sN = vert_attributes.strides[0];
const index_t vert_attributes_sV = vert_attributes.strides[1];
const index_t vert_attributes_sC = vert_attributes.strides[2];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
const index_t index_img_sW = index_img.strides[2];
const index_t bary_img_sN = bary_img.strides[0];
const index_t bary_img_sB = bary_img.strides[1];
const index_t bary_img_sH = bary_img.strides[2];
const index_t bary_img_sW = bary_img.strides[3];
const index_t out_img_sN = out_img.strides[0];
const index_t out_img_sC = out_img.strides[1];
const index_t out_img_sH = out_img.strides[2];
const index_t out_img_sW = out_img.strides[3];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % W;
const index_t h = (index / W) % H;
const index_t n = index / (H * W);
const int32_t tr_index = index_img.data[n * index_img_sN + h * index_img_sH + w * index_img_sW];
scalar_t* __restrict out_ptr = out_img.data + out_img_sN * n + out_img_sH * h + out_img_sW * w;
if (tr_index != -1) {
const int32_t* __restrict vi_ptr = vi.data + tr_index * vi_sV;
const int32_t vi_0 = vi_ptr[0 * vi_sF];
const int32_t vi_1 = vi_ptr[1 * vi_sF];
const int32_t vi_2 = vi_ptr[2 * vi_sF];
const scalar_t* __restrict vert_ptr = vert_attributes.data + vert_attributes_sN * n;
const scalar_t* vert_0_ptr = vert_ptr + vert_attributes_sV * vi_0;
const scalar_t* vert_1_ptr = vert_ptr + vert_attributes_sV * vi_1;
const scalar_t* vert_2_ptr = vert_ptr + vert_attributes_sV * vi_2;
const scalar_t* __restrict bary_ptr =
bary_img.data + bary_img_sN * n + bary_img_sH * h + bary_img_sW * w;
const scalar_t bary_0 = bary_ptr[0 * bary_img_sB];
const scalar_t bary_1 = bary_ptr[1 * bary_img_sB];
const scalar_t bary_2 = bary_ptr[2 * bary_img_sB];
for (int i = 0; i < C; ++i) {
scalar_t v0 = vert_0_ptr[i * vert_attributes_sC];
scalar_t v1 = vert_1_ptr[i * vert_attributes_sC];
scalar_t v2 = vert_2_ptr[i * vert_attributes_sC];
out_ptr[out_img_sC * i] = v0 * bary_0 + v1 * bary_1 + v2 * bary_2;
}
} else {
for (int i = 0; i < C; ++i) {
const scalar_t v[2] = {(w * 2.0f + 1.0f) / W - 1.0f, (h * 2.0f + 1.0f) / H - 1.0f};
out_ptr[out_img_sC * i] = v[i % 2];
}
}
}
}
template <typename scalar_t, typename index_t, bool bary_img_requires_grad, bool vert_requires_grad>
C10_LAUNCH_BOUNDS_1(256)
__global__ void interpolate_backward_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> grad_out,
TensorInfo<scalar_t, index_t> vert_attributes,
TensorInfo<int32_t, index_t> vi,
TensorInfo<int32_t, index_t> index_img,
TensorInfo<scalar_t, index_t> bary_img,
TensorInfo<scalar_t, index_t> vert_attributes_grad,
TensorInfo<scalar_t, index_t> bary_img_grad,
const index_t memory_span) {
index_t C = vert_attributes.sizes[2];
index_t H = bary_img.sizes[2];
index_t W = bary_img.sizes[3];
index_t vert_attributes_sN = vert_attributes.strides[0];
index_t vert_attributes_sV = vert_attributes.strides[1];
index_t vert_attributes_sC = vert_attributes.strides[2];
index_t vert_attributes_grad_sN = vert_attributes_grad.strides[0];
index_t vert_attributes_grad_sV = vert_attributes_grad.strides[1];
index_t vert_attributes_grad_sC = vert_attributes_grad.strides[2];
index_t vi_sV = vi.strides[0];
index_t vi_sF = vi.strides[1];
index_t index_img_sN = index_img.strides[0];
index_t index_img_sH = index_img.strides[1];
index_t index_img_sW = index_img.strides[2];
index_t bary_img_sN = bary_img.strides[0];
index_t bary_img_sB = bary_img.strides[1];
index_t bary_img_sH = bary_img.strides[2];
index_t bary_img_sW = bary_img.strides[3];
index_t bary_img_grad_sN = bary_img_grad.strides[0];
index_t bary_img_grad_sB = bary_img_grad.strides[1];
index_t bary_img_grad_sH = bary_img_grad.strides[2];
index_t bary_img_grad_sW = bary_img_grad.strides[3];
index_t grad_out_sN = grad_out.strides[0];
index_t grad_out_sC = grad_out.strides[1];
index_t grad_out_sH = grad_out.strides[2];
index_t grad_out_sW = grad_out.strides[3];
int64_t index = blockIdx.x * blockDim.x + threadIdx.x;
constexpr int warp_size = 32;
int lane = threadIdx.x % warp_size;
__shared__ typename cub::WarpReduce<scalar_t>::TempStorage temp_storage_0;
__shared__ typename cub::WarpReduce<scalar_t>::TempStorage temp_storage_1;
__shared__ typename cub::WarpReduce<scalar_t>::TempStorage temp_storage_2;
{
const index_t w = index % W;
const index_t h = (index / W) % H;
const index_t n = index / (H * W);
int32_t tr_index = -1;
if (index < nthreads)
tr_index = index_img.data[n * index_img_sN + h * index_img_sH + w * index_img_sW];
const scalar_t* __restrict grad_out_ptr =
grad_out.data + grad_out_sN * n + grad_out_sH * h + grad_out_sW * w;
scalar_t* __restrict bary_grad_ptr =
bary_img_grad.data + bary_img_grad_sN * n + bary_img_grad_sH * h + bary_img_grad_sW * w;
bool thread_is_used = tr_index != -1;
// True if at least one thread in the warp is used.
bool warp_is_used = __any_sync(0xFFFFFFFFU, thread_is_used);
if (warp_is_used) {
int32_t vi_0 = -1, vi_1 = -1, vi_2 = -1;
if (thread_is_used) {
vi_0 = vi.data[tr_index * vi_sV + 0 * vi_sF];
vi_1 = vi.data[tr_index * vi_sV + 1 * vi_sF];
vi_2 = vi.data[tr_index * vi_sV + 2 * vi_sF];
}
unsigned m = 0xFFFFFFFFU;
int vi_0_head = (__shfl_up_sync(m, vi_0, 1) != vi_0) || (lane == 0);
int vi_0_tail = (__shfl_down_sync(m, vi_0, 1) != vi_0) || (lane == (warp_size - 1));
int vi_1_head = (__shfl_up_sync(m, vi_1, 1) != vi_1) || (lane == 0);
int vi_1_tail = (__shfl_down_sync(m, vi_1, 1) != vi_1) || (lane == (warp_size - 1));
int vi_2_head = (__shfl_up_sync(m, vi_2, 1) != vi_2) || (lane == 0);
int vi_2_tail = (__shfl_down_sync(m, vi_2, 1) != vi_2) || (lane == (warp_size - 1));
const scalar_t* __restrict vert_ptr = vert_attributes.data + vert_attributes_sN * n;
const scalar_t* vert_0_ptr = vert_ptr + vert_attributes_sV * vi_0;
const scalar_t* vert_1_ptr = vert_ptr + vert_attributes_sV * vi_1;
const scalar_t* vert_2_ptr = vert_ptr + vert_attributes_sV * vi_2;
scalar_t* __restrict vert_grad_ptr = vert_attributes_grad.data + vert_attributes_grad_sN * n;
scalar_t* vert_0_grad_ptr = vert_grad_ptr + vert_attributes_grad_sV * vi_0;
scalar_t* vert_1_grad_ptr = vert_grad_ptr + vert_attributes_grad_sV * vi_1;
scalar_t* vert_2_grad_ptr = vert_grad_ptr + vert_attributes_grad_sV * vi_2;
const scalar_t* __restrict bary_ptr =
bary_img.data + bary_img_sN * n + bary_img_sH * h + bary_img_sW * w;
scalar_t bary_0, bary_1, bary_2;
if (thread_is_used && vert_requires_grad) {
bary_0 = bary_ptr[0 * bary_img_sB];
bary_1 = bary_ptr[1 * bary_img_sB];
bary_2 = bary_ptr[2 * bary_img_sB];
}
auto bary_0_grad = scalar_t(0.);
auto bary_1_grad = scalar_t(0.);
auto bary_2_grad = scalar_t(0.);
for (int i = 0; i < C; ++i) {
scalar_t g_out = grad_out_ptr[i * grad_out_sC];
if (thread_is_used && bary_img_requires_grad) {
scalar_t v0 = vert_0_ptr[i * vert_attributes_sC];
scalar_t v1 = vert_1_ptr[i * vert_attributes_sC];
scalar_t v2 = vert_2_ptr[i * vert_attributes_sC];
bary_0_grad += g_out * v0;
bary_1_grad += g_out * v1;
bary_2_grad += g_out * v2;
}
if (vert_requires_grad) {
scalar_t grad_v_0 =
cub::WarpReduce<scalar_t>(temp_storage_0).TailSegmentedSum(g_out * bary_0, vi_0_tail);
scalar_t grad_v_1 =
cub::WarpReduce<scalar_t>(temp_storage_1).TailSegmentedSum(g_out * bary_1, vi_1_tail);
scalar_t grad_v_2 =
cub::WarpReduce<scalar_t>(temp_storage_2).TailSegmentedSum(g_out * bary_2, vi_2_tail);
__syncthreads();
if (vi_0_head && thread_is_used)
fastAtomicAdd(
vert_0_grad_ptr, i * vert_attributes_grad_sC, memory_span, grad_v_0, true);
if (vi_1_head && thread_is_used)
fastAtomicAdd(
vert_1_grad_ptr, i * vert_attributes_grad_sC, memory_span, grad_v_1, true);
if (vi_2_head && thread_is_used)
fastAtomicAdd(
vert_2_grad_ptr, i * vert_attributes_grad_sC, memory_span, grad_v_2, true);
}
}
if (thread_is_used && bary_img_requires_grad) {
bary_grad_ptr[0 * bary_img_grad_sB] = bary_0_grad;
bary_grad_ptr[1 * bary_img_grad_sB] = bary_1_grad;
bary_grad_ptr[2 * bary_img_grad_sB] = bary_2_grad;
}
} else if ((index < nthreads) && bary_img_requires_grad) {
bary_grad_ptr[0 * bary_img_grad_sB] = scalar_t(0.);
bary_grad_ptr[1 * bary_img_grad_sB] = scalar_t(0.);
bary_grad_ptr[2 * bary_img_grad_sB] = scalar_t(0.);
}
}
}
torch::Tensor interpolate_cuda(
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img) {
TORCH_CHECK(
vert_attributes.defined() && vi.defined() && index_img.defined() && bary_img.defined(),
"interpolate(): expected all inputs to be defined");
auto vert_attributes_opt = vert_attributes.options();
auto vi_opt = vi.options();
auto index_img_opt = index_img.options();
auto bary_img_opt = bary_img.options();
TORCH_CHECK(
(vert_attributes.device() == vi.device()) &&
(vert_attributes.device() == index_img.device()) &&
(vert_attributes.device() == bary_img.device()),
"interpolate(): expected all inputs to be on same device");
TORCH_CHECK(
vert_attributes.dtype() == bary_img.dtype(),
"interpolate(): expected vert_attributes and bary_img to have same dtype, but vert_attributes has ",
vert_attributes.dtype(),
" and bary_img has ",
bary_img.dtype());
TORCH_CHECK(
vert_attributes.is_floating_point(),
"interpolate(): expected vert_attributes to have floating point type, but v has ",
vert_attributes.dtype());
TORCH_CHECK(
vi.dtype() == torch::kInt32,
"interpolate(): expected vi to have int32 type, but vi has ",
vi.dtype());
TORCH_CHECK(
index_img.dtype() == torch::kInt32,
"interpolate(): expected index_img to have int32 type, but index_img has ",
index_img.dtype());
TORCH_CHECK(
vert_attributes.layout() == torch::kStrided && vi.layout() == torch::kStrided &&
index_img.layout() == torch::kStrided && bary_img.layout() == torch::kStrided,
"interpolate(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(vert_attributes.dim() == 3) && (vi.dim() == 2) && (index_img.dim() == 3) &&
(bary_img.dim() == 4),
"interpolate(): expected vert_attributes.ndim == 3, vi.ndim == 2, index_img.ndim == 3, bary_img.ndim == 4, "
"but got vert_attributes with sizes ",
vert_attributes.sizes(),
" and vi with sizes ",
vi.sizes(),
" and index_img with sizes ",
index_img.sizes(),
" and bary_img with sizes ",
bary_img.sizes());
TORCH_CHECK(
vert_attributes.size(0) == index_img.size(0) && vert_attributes.size(0) == bary_img.size(0),
"interpolate(): expected vert_attributes, index_img and bary_img to have same batch size, "
"but got vert_attributes with sizes ",
vert_attributes.sizes(),
" and index_img with sizes ",
index_img.sizes(),
" and bary_img with sizes ",
bary_img.sizes());
TORCH_CHECK(
vi.size(1) == 3 && bary_img.size(1) == 3,
"interpolate(): expected second dim of vi to be of size 3, and second dim of bary_img to be of size 3, but got ",
vi.size(1),
" in the second dim of vi, and ",
bary_img.size(1),
" in the second dim of bary_img");
TORCH_CHECK(
index_img.size(1) == bary_img.size(2) && index_img.size(2) == bary_img.size(3),
"interpolate(): expected H and W dims of index_img and bary_img to match");
const at::cuda::OptionalCUDAGuard device_guard(device_of(vert_attributes));
auto N = vert_attributes.size(0);
auto V = vert_attributes.size(1);
auto C = vert_attributes.size(2);
auto H = bary_img.size(2);
auto W = bary_img.size(3);
int64_t count = N * H * W;
auto output = at::empty({N, C, H, W}, vert_attributes.options());
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES(vert_attributes.scalar_type(), "interpolate_kernel", [&] {
if (at::native::canUse32BitIndexMath(vert_attributes) &&
at::native::canUse32BitIndexMath(bary_img) &&
at::native::canUse32BitIndexMath(index_img) && at::native::canUse32BitIndexMath(vi)) {
typedef int index_type;
interpolate_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(vert_attributes),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<scalar_t, index_type>(bary_img),
getTensorInfo<scalar_t, index_type>(output));
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
interpolate_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(vert_attributes),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<scalar_t, index_type>(bary_img),
getTensorInfo<scalar_t, index_type>(output));
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return output;
}
template <typename scalar_t, typename index_t, bool bary_img_requires_grad, bool vert_requires_grad>
void _interpolate_cuda_backward(
int64_t count,
const torch::Tensor& grad_out,
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img,
const torch::Tensor& vert_attributes_grad,
const torch::Tensor& bary_img_grad) {
interpolate_backward_kernel<scalar_t, index_t, bary_img_requires_grad, vert_requires_grad>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_t>(count),
getTensorInfo<scalar_t, index_t>(grad_out),
getTensorInfo<scalar_t, index_t>(vert_attributes),
getTensorInfo<int32_t, index_t>(vi),
getTensorInfo<int32_t, index_t>(index_img),
getTensorInfo<scalar_t, index_t>(bary_img),
vert_requires_grad ? getTensorInfo<scalar_t, index_t>(vert_attributes_grad)
: TensorInfo<scalar_t, index_t>({nullptr, {0}, {0}, 0}),
bary_img_requires_grad ? getTensorInfo<scalar_t, index_t>(bary_img_grad)
: TensorInfo<scalar_t, index_t>({nullptr, {0}, {0}, 0}),
vert_attributes_grad.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template <typename scalar_t, typename index_t>
void _interpolate_cuda_backward(
int64_t count,
const torch::Tensor& grad_out,
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img,
const torch::Tensor& vert_attributes_grad,
const torch::Tensor& bary_img_grad,
bool bary_img_requires_grad,
bool vert_requires_grad) {
if (bary_img_requires_grad && vert_requires_grad)
_interpolate_cuda_backward<scalar_t, index_t, true, true>(
count,
grad_out,
vert_attributes,
vi,
index_img,
bary_img,
vert_attributes_grad,
bary_img_grad);
else if (bary_img_requires_grad)
_interpolate_cuda_backward<scalar_t, index_t, true, false>(
count,
grad_out,
vert_attributes,
vi,
index_img,
bary_img,
vert_attributes_grad,
bary_img_grad);
else if (vert_requires_grad)
_interpolate_cuda_backward<scalar_t, index_t, false, true>(
count,
grad_out,
vert_attributes,
vi,
index_img,
bary_img,
vert_attributes_grad,
bary_img_grad);
}
std::tuple<torch::Tensor, torch::Tensor> interpolate_cuda_backward(
const torch::Tensor& grad_out,
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(vert_attributes));
auto N = vert_attributes.size(0);
auto V = vert_attributes.size(1);
auto C = vert_attributes.size(2);
auto H = bary_img.size(2);
auto W = bary_img.size(3);
int64_t count = N * H * W;
bool bary_img_requires_grad = bary_img.requires_grad();
bool vert_requires_grad = vert_attributes.requires_grad();
auto vert_attributes_grad =
vert_requires_grad ? at::zeros({N, V, C}, vert_attributes.options()) : torch::Tensor();
auto bary_img_grad =
bary_img_requires_grad ? at::empty({N, 3, H, W}, bary_img.options()) : torch::Tensor();
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES(vert_attributes.scalar_type(), "interpolate_kernel", [&] {
if (at::native::canUse32BitIndexMath(vert_attributes) &&
at::native::canUse32BitIndexMath(bary_img) &&
at::native::canUse32BitIndexMath(index_img) && at::native::canUse32BitIndexMath(vi)) {
_interpolate_cuda_backward<scalar_t, int>(
count,
grad_out,
vert_attributes,
vi,
index_img,
bary_img,
vert_attributes_grad,
bary_img_grad,
bary_img_requires_grad,
vert_requires_grad);
} else {
_interpolate_cuda_backward<scalar_t, int64_t>(
count,
grad_out,
vert_attributes,
vi,
index_img,
bary_img,
vert_attributes_grad,
bary_img_grad,
bary_img_requires_grad,
vert_requires_grad);
}
});
}
return std::make_tuple(vert_attributes_grad, bary_img_grad);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
torch::Tensor interpolate_cuda(
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img);
std::tuple<torch::Tensor, torch::Tensor> interpolate_cuda_backward(
const torch::Tensor& grad_out,
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "interpolate_kernel.h"
// Dispatch function
torch::Tensor interpolate(
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("interpolate_ext::interpolate", "")
.typed<decltype(interpolate)>();
return op.call(vert_attributes, vi, index_img, bary_img);
}
// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
class InterpolateFunction : public torch::autograd::Function<InterpolateFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img) {
ctx->set_materialize_grads(false);
std::vector<torch::Tensor> save_list;
save_list.push_back(vert_attributes);
save_list.push_back(vi);
save_list.push_back(index_img);
save_list.push_back(bary_img);
ctx->save_for_backward(save_list);
return {interpolate_cuda(vert_attributes, vi, index_img, bary_img)};
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
const auto saved = ctx->get_saved_variables();
const torch::Tensor& vert_attributes = saved[0];
const torch::Tensor& vi = saved[1];
const torch::Tensor& index_img = saved[2];
const torch::Tensor& bary_img = saved[3];
bool bary_img_requires_grad = bary_img.requires_grad();
bool vert_requires_grad = vert_attributes.requires_grad();
torch::autograd::tensor_list out;
if ((!bary_img_requires_grad && !vert_requires_grad) || !grad_outputs[0].defined()) {
out.resize(4);
return out;
}
auto grad_out =
interpolate_cuda_backward(grad_outputs[0], vert_attributes, vi, index_img, bary_img);
out.push_back(std::get<0>(grad_out));
out.emplace_back();
out.emplace_back();
out.push_back(std::get<1>(grad_out));
return out;
}
};
torch::Tensor interpolate_autograd(
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img) {
return InterpolateFunction::apply(vert_attributes, vi, index_img, bary_img)[0];
}
torch::Tensor interpolate_autocast(
const torch::Tensor& vert_attributes,
const torch::Tensor& vi,
const torch::Tensor& index_img,
const torch::Tensor& bary_img) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return interpolate(
at::autocast::cached_cast(torch::kFloat32, vert_attributes),
vi,
index_img,
at::autocast::cached_cast(torch::kFloat32, bary_img));
}
#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(interpolate_ext, m) {}
#endif
TORCH_LIBRARY(interpolate_ext, m) {
m.def(
"interpolate(Tensor vert_attributes, Tensor vi, Tensor index_img, Tensor bary_img) -> Tensor");
}
TORCH_LIBRARY_IMPL(interpolate_ext, Autograd, m) {
m.impl("interpolate", &interpolate_autograd);
}
TORCH_LIBRARY_IMPL(interpolate_ext, Autocast, m) {
m.impl("interpolate", interpolate_autocast);
}
TORCH_LIBRARY_IMPL(interpolate_ext, CUDA, m) {
m.impl("interpolate", &interpolate_cuda);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <torch/types.h>
#include <grid_utils.h>
#include <kernel_utils.h>
#include <tensor_list.h>
using namespace math;
constexpr int max_mipmap_count = 11;
constexpr int tex_ndim = 4;
constexpr int uv_jacobian_ndim = 5;
template <typename scalar_t, typename index_t>
__device__ void sample_bilinear(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
const TensorInfo<scalar_t, index_t>& output,
scalar_t alpha,
const GridSamplerPadding padding_mode,
bool align_corners) {
index_t out_sN = output.strides[0];
index_t out_sC = output.strides[1];
index_t out_sH = output.strides[2];
index_t out_sW = output.strides[3];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t inp_sN = input.strides[0];
index_t inp_sC = input.strides[1];
index_t inp_sH = input.strides[2];
index_t inp_sW = input.strides[3];
scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
// get NE, NW, SE, SW pixel values from (x, y)
index_t ix_nw = static_cast<index_t>(::floor(ix));
index_t iy_nw = static_cast<index_t>(::floor(iy));
index_t ix_ne = ix_nw + 1;
index_t iy_ne = iy_nw;
index_t ix_sw = ix_nw;
index_t iy_sw = iy_nw + 1;
index_t ix_se = ix_nw + 1;
index_t iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy);
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
// calculate bilinear weighted pixel value and set output pixel
auto inp_ptr_NC = input.data + n * inp_sN;
auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW] * nw * alpha;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW] * ne * alpha;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW] * sw * alpha;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
*out_ptr_NCHW += inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW] * se * alpha;
}
}
}
template <typename scalar_t, typename index_t>
__device__ void sample_bicubic(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
const TensorInfo<scalar_t, index_t>& output,
scalar_t alpha,
const GridSamplerPadding padding_mode,
bool align_corners) {
index_t out_sN = output.strides[0];
index_t out_sC = output.strides[1];
index_t out_sH = output.strides[2];
index_t out_sW = output.strides[3];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t inp_sN = input.strides[0];
index_t inp_sC = input.strides[1];
index_t inp_sH = input.strides[2];
index_t inp_sW = input.strides[3];
scalar_t ix = grid_sampler_unnormalize(x, inp_W, align_corners);
scalar_t iy = grid_sampler_unnormalize(y, inp_H, align_corners);
// get NE, NW, SE, SW pixel values from (x, y)
scalar_t ix_nw = ::floor(ix);
scalar_t iy_nw = ::floor(iy);
const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;
auto inp_ptr_NC = input.data + n * inp_sN;
auto out_ptr_NCHW = output.data + n * out_sN + h * out_sH + w * out_sW;
for (index_t c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCHW += out_sC) {
scalar_t coefficients[4];
#pragma unroll 4
for (index_t i = 0; i < 4; ++i) {
coefficients[i] = cubic_interp1d(
get_value_bounded<scalar_t>(
inp_ptr_NC,
ix_nw - 1,
iy_nw - 1 + i,
inp_W,
inp_H,
inp_sW,
inp_sH,
padding_mode,
align_corners),
get_value_bounded<scalar_t>(
inp_ptr_NC,
ix_nw + 0,
iy_nw - 1 + i,
inp_W,
inp_H,
inp_sW,
inp_sH,
padding_mode,
align_corners),
get_value_bounded<scalar_t>(
inp_ptr_NC,
ix_nw + 1,
iy_nw - 1 + i,
inp_W,
inp_H,
inp_sW,
inp_sH,
padding_mode,
align_corners),
get_value_bounded<scalar_t>(
inp_ptr_NC,
ix_nw + 2,
iy_nw - 1 + i,
inp_W,
inp_H,
inp_sW,
inp_sH,
padding_mode,
align_corners),
tx);
}
*out_ptr_NCHW +=
cubic_interp1d(coefficients[0], coefficients[1], coefficients[2], coefficients[3], ty) *
alpha;
}
}
template <typename scalar_t, typename index_t>
__device__ TVec2<scalar_t> sample_bilinear_backward(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_output,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
scalar_t alpha,
const GridSamplerPadding padding_mode,
bool align_corners,
index_t grad_input_memory_span) {
index_t gOut_sN = grad_output.strides[0];
index_t gOut_sC = grad_output.strides[1];
index_t gOut_sH = grad_output.strides[2];
index_t gOut_sW = grad_output.strides[3];
index_t gInp_sN = grad_input.strides[0];
index_t gInp_sC = grad_input.strides[1];
index_t gInp_sH = grad_input.strides[2];
index_t gInp_sW = grad_input.strides[3];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t inp_sN = input.strides[0];
index_t inp_sC = input.strides[1];
index_t inp_sH = input.strides[2];
index_t inp_sW = input.strides[3];
// multipliers for gradients on ix and iy
TVec2<scalar_t> gi_mult;
scalar_t ix =
grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gi_mult.x);
scalar_t iy =
grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &gi_mult.y);
// get NE, NW, SE, SW pixel values from (x, y)
index_t ix_nw = static_cast<index_t>(::floor(ix));
index_t iy_nw = static_cast<index_t>(::floor(iy));
index_t ix_ne = ix_nw + 1;
index_t iy_ne = iy_nw;
index_t ix_sw = ix_nw;
index_t iy_sw = iy_nw + 1;
index_t ix_se = ix_nw + 1;
index_t iy_se = iy_nw + 1;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - ix) * (iy_se - iy);
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
TVec2<scalar_t> gi = {scalar_t(0), scalar_t(0)};
scalar_t* gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
index_t NC_offset = n * gInp_sN;
scalar_t* inp_ptr_NC = input.data + n * inp_sN;
for (index_t c = 0; c < C;
++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
scalar_t gOut = *gOut_ptr_NCHW * alpha;
// calculate and set grad_input. See Note [Passing pointer and offset to fastAtomicAdd].
safe_add_2d(
grad_input.data,
iy_nw,
ix_nw,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
nw * gOut,
NC_offset,
grad_input_memory_span);
safe_add_2d(
grad_input.data,
iy_ne,
ix_ne,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
ne * gOut,
NC_offset,
grad_input_memory_span);
safe_add_2d(
grad_input.data,
iy_sw,
ix_sw,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
sw * gOut,
NC_offset,
grad_input_memory_span);
safe_add_2d(
grad_input.data,
iy_se,
ix_se,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
se * gOut,
NC_offset,
grad_input_memory_span);
// calculate grad_grid
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
scalar_t nw_val = inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW];
gi.x -= nw_val * (iy_se - iy) * gOut;
gi.y -= nw_val * (ix_se - ix) * gOut;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
scalar_t ne_val = inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW];
gi.x += ne_val * (iy_sw - iy) * gOut;
gi.y -= ne_val * (ix - ix_sw) * gOut;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
scalar_t sw_val = inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW];
gi.x -= sw_val * (iy - iy_ne) * gOut;
gi.y += sw_val * (ix_ne - ix) * gOut;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
scalar_t se_val = inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW];
gi.x += se_val * (iy - iy_nw) * gOut;
gi.y += se_val * (ix - ix_nw) * gOut;
}
}
return gi_mult * gi;
}
template <typename scalar_t, typename index_t>
__device__ TVec2<scalar_t> sample_bicubic_backward(
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_input,
const TensorInfoCompact<scalar_t, index_t, tex_ndim>& grad_output,
scalar_t x,
scalar_t y,
const index_t w,
const index_t h,
const index_t n,
const index_t C,
scalar_t alpha,
const GridSamplerPadding padding_mode,
bool align_corners,
index_t grad_input_memory_span) {
index_t gOut_sN = grad_output.strides[0];
index_t gOut_sC = grad_output.strides[1];
index_t gOut_sH = grad_output.strides[2];
index_t gOut_sW = grad_output.strides[3];
index_t gInp_sN = grad_input.strides[0];
index_t gInp_sC = grad_input.strides[1];
index_t gInp_sH = grad_input.strides[2];
index_t gInp_sW = grad_input.strides[3];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t inp_sN = input.strides[0];
index_t inp_sC = input.strides[1];
index_t inp_sH = input.strides[2];
index_t inp_sW = input.strides[3];
// multipliers for gradients on ix and iy
TVec2<scalar_t> gi_mult;
scalar_t ix = grid_sampler_unnormalize_set_grad(x, inp_W, align_corners, &gi_mult.x);
scalar_t iy = grid_sampler_unnormalize_set_grad(y, inp_H, align_corners, &gi_mult.y);
scalar_t ix_nw = ::floor(ix);
scalar_t iy_nw = ::floor(iy);
const scalar_t tx = ix - ix_nw;
const scalar_t ty = iy - iy_nw;
scalar_t x_coeffs[4];
scalar_t y_coeffs[4];
scalar_t x_coeffs_grad[4];
scalar_t y_coeffs_grad[4];
get_cubic_upsampling_coefficients<scalar_t>(x_coeffs, tx);
get_cubic_upsampling_coefficients<scalar_t>(y_coeffs, ty);
get_cubic_coefficients_grad<scalar_t>(x_coeffs_grad, tx);
get_cubic_coefficients_grad<scalar_t>(y_coeffs_grad, ty);
TVec2<scalar_t> gi = {scalar_t(0), scalar_t(0)};
scalar_t* gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
index_t NC_offset = n * gInp_sN;
scalar_t* inp_ptr_NC = input.data + n * inp_sN;
for (index_t c = 0; c < C;
++c, inp_ptr_NC += inp_sC, NC_offset += gInp_sC, gOut_ptr_NCHW += gOut_sC) {
scalar_t gOut = *gOut_ptr_NCHW * alpha;
#pragma unroll 4
for (index_t i = 0; i < 4; ++i) {
#pragma unroll 4
for (index_t j = 0; j < 4; ++j) {
// set input gradient. See Note [Passing pointer and offset to fastAtomicAdd].
add_value_bounded<scalar_t>(
grad_input.data,
ix_nw - 1 + i,
iy_nw - 1 + j,
inp_W,
inp_H,
gInp_sW,
gInp_sH,
gOut * x_coeffs[i] * y_coeffs[j],
padding_mode,
align_corners,
NC_offset,
grad_input_memory_span);
// set grid gradient
scalar_t val = get_value_bounded<scalar_t>(
inp_ptr_NC,
ix_nw - 1 + i,
iy_nw - 1 + j,
inp_W,
inp_H,
inp_sW,
inp_sH,
padding_mode,
align_corners);
gi -= gOut * val *
TVec2<scalar_t>({x_coeffs_grad[i] * y_coeffs[j], y_coeffs_grad[j] * x_coeffs[i]});
}
}
}
return gi_mult * gi;
}
template <typename scalar_t, typename index_t, GridSamplerInterpolation interpolation_mode>
C10_LAUNCH_BOUNDS_1(256)
__global__ void mipmap_aniso_grid_sampler_2d_kernel(
const index_t nthreads,
TensorInfoList<scalar_t, index_t, max_mipmap_count, tex_ndim> inputs,
const int mipmaps,
TensorInfo<scalar_t, index_t> grid,
TensorInfo<scalar_t, index_t> vt_dxdy_img,
TensorInfo<scalar_t, index_t> output,
const GridSamplerPadding padding_mode,
int max_aniso,
bool align_corners,
bool force_max_aniso,
bool clip_grad) {
align_corners = false;
index_t C = output.sizes[1];
index_t inp_H = inputs[0].sizes[2];
index_t inp_W = inputs[0].sizes[3];
index_t out_H = grid.sizes[1];
index_t out_W = grid.sizes[2];
index_t grid_sN = grid.strides[0];
index_t grid_sH = grid.strides[1];
index_t grid_sW = grid.strides[2];
index_t grid_sCoor = grid.strides[3];
index_t vt_dxdy_img_sN = vt_dxdy_img.strides[0];
index_t vt_dxdy_img_sH = vt_dxdy_img.strides[1];
index_t vt_dxdy_img_sW = vt_dxdy_img.strides[2];
index_t vt_dxdy_img_s3 = vt_dxdy_img.strides[3];
index_t vt_dxdy_img_s4 = vt_dxdy_img.strides[4];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % out_W;
const index_t h = (index / out_W) % out_H;
const index_t n = index / (out_H * out_W);
const index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
const index_t vt_dxdy_img_offset = n * vt_dxdy_img_sN + h * vt_dxdy_img_sH + w * vt_dxdy_img_sW;
// get the corresponding input x, y co-ordinates from grid
scalar_t u = grid.data[grid_offset];
scalar_t v = grid.data[grid_offset + grid_sCoor];
scalar_t dudx = vt_dxdy_img.data[vt_dxdy_img_offset];
scalar_t dvdx = vt_dxdy_img.data[vt_dxdy_img_offset + vt_dxdy_img_s4];
scalar_t dudy = vt_dxdy_img.data[vt_dxdy_img_offset + vt_dxdy_img_s3];
scalar_t dvdy = vt_dxdy_img.data[vt_dxdy_img_offset + vt_dxdy_img_s3 + vt_dxdy_img_s4];
scalar_t px = pow(pow(abs(dudx * inp_W), 2.0f) + pow(abs(dvdx * inp_H), 2.0f) + 1e-12f, 0.5f);
scalar_t py = pow(pow(abs(dudy * inp_W), 2.0f) + pow(abs(dvdy * inp_H), 2.0f) + 1e-12f, 0.5f);
scalar_t p_max = max(px, py);
scalar_t p_min = min(px, py);
// # See p.255 of OpenGL Core Profile
// # N = min(ceil(Pmax/Pmin),maxAniso)
scalar_t N = min(ceil(p_max / p_min), (scalar_t)max_aniso);
if (p_min == 0.0 || N == 0) {
N = 1;
}
// Lambda' = log2(Pmax/N)
scalar_t lambda_ = log2(p_max / N);
if (isnan(lambda_) || isinf(lambda_)) {
lambda_ = 0.0f;
}
// See eq. 8.15, 8.16
// Substract small number (1e-6) so that `l` is always < mipmaps - 1
scalar_t l = min(lambda_, mipmaps - 1 - 1e-6);
// The following correction is divergence from the specification
// The reason is that it is typically assumed that the full pyramid is available, but if not,
// clipping of the level happens as in the line above, which causes taps to be spread with
// distances higher than the size of the texel. Which in turn causes aliasing and not desirable
// long-range sampling So if clipping happens, we recompute clipped Pmax and scale gradients
// accordingly
if (clip_grad && lambda_ > mipmaps - 1) {
scalar_t p_max_corrected = exp2(l) * N;
scalar_t scaling = p_max_corrected / p_max;
dudx *= scaling;
dvdx *= scaling;
dudy *= scaling;
dvdy *= scaling;
}
l = max(l, 0.0);
auto d1 = (index_t)floor(l);
scalar_t a = l - (scalar_t)d1;
index_t N_int = index_t(N);
if (force_max_aniso) {
N_int = max_aniso;
}
if (px > py) {
for (int i = 0; i < N_int; ++i) {
scalar_t u_offset = dudx * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t v_offset = dvdx * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t alpha_1 = a / N_int;
scalar_t alpha_2 = (1.0 - a) / N_int;
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
sample_bilinear(
inputs[d1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_2,
padding_mode,
align_corners);
if (mipmaps > 1)
sample_bilinear(
inputs[d1 + 1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_1,
padding_mode,
align_corners);
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
sample_bicubic(
inputs[d1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_2,
padding_mode,
align_corners);
if (mipmaps > 1)
sample_bicubic(
inputs[d1 + 1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_1,
padding_mode,
align_corners);
}
}
} else {
for (int i = 0; i < N_int; ++i) {
scalar_t u_offset = dudy * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t v_offset = dvdy * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t alpha_1 = a / N_int;
scalar_t alpha_2 = (1.0 - a) / N_int;
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
sample_bilinear(
inputs[d1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_2,
padding_mode,
align_corners);
if (mipmaps > 1)
sample_bilinear(
inputs[d1 + 1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_1,
padding_mode,
align_corners);
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
sample_bicubic(
inputs[d1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_2,
padding_mode,
align_corners);
if (mipmaps > 1)
sample_bicubic(
inputs[d1 + 1],
u + u_offset,
v + v_offset,
w,
h,
n,
C,
output,
alpha_1,
padding_mode,
align_corners);
}
}
}
}
}
template <typename scalar_t, typename index_t, GridSamplerInterpolation interpolation_mode>
C10_LAUNCH_BOUNDS_1(256)
__global__ void mipmap_aniso_grid_sampler_2d_backward_kernel(
const index_t nthreads,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_output,
TensorInfoList<scalar_t, index_t, max_mipmap_count, tex_ndim> inputs,
const int mipmaps,
TensorInfoCompact<scalar_t, index_t, tex_ndim> grid,
TensorInfoCompact<scalar_t, index_t, tex_ndim + 1> vt_dxdy_img,
TensorInfoList<scalar_t, index_t, max_mipmap_count, tex_ndim>
grad_inputs, // initialized to zeros
TensorInfoCompact<scalar_t, index_t, tex_ndim> grad_grid, // initialized to empty
const GridSamplerPadding padding_mode,
int max_aniso,
bool align_corners,
bool force_max_aniso,
bool clip_grad,
IndexList<index_t, max_mipmap_count> grad_input_memory_span) {
index_t C = inputs[0].sizes[1];
index_t inp_H = inputs[0].sizes[2];
index_t inp_W = inputs[0].sizes[3];
index_t out_H = grid.sizes[1];
index_t out_W = grid.sizes[2];
index_t grid_sN = grid.strides[0];
index_t grid_sH = grid.strides[1];
index_t grid_sW = grid.strides[2];
index_t grid_sCoor = grid.strides[3];
index_t gGrid_sW = grad_grid.strides[2];
index_t vt_dxdy_img_sN = vt_dxdy_img.strides[0];
index_t vt_dxdy_img_sH = vt_dxdy_img.strides[1];
index_t vt_dxdy_img_sW = vt_dxdy_img.strides[2];
index_t vt_dxdy_img_s3 = vt_dxdy_img.strides[3];
index_t vt_dxdy_img_s4 = vt_dxdy_img.strides[4];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % out_W;
const index_t h = (index / out_W) % out_H;
const index_t n = index / (out_H * out_W);
const auto grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
const index_t vt_dxdy_img_offset = n * vt_dxdy_img_sN + h * vt_dxdy_img_sH + w * vt_dxdy_img_sW;
// get the corresponding input x, y co-ordinates from grid
scalar_t u = grid.data[grid_offset];
scalar_t v = grid.data[grid_offset + grid_sCoor];
scalar_t dudx = vt_dxdy_img.data[vt_dxdy_img_offset];
scalar_t dvdx = vt_dxdy_img.data[vt_dxdy_img_offset + vt_dxdy_img_s4];
scalar_t dudy = vt_dxdy_img.data[vt_dxdy_img_offset + vt_dxdy_img_s3];
scalar_t dvdy = vt_dxdy_img.data[vt_dxdy_img_offset + vt_dxdy_img_s3 + vt_dxdy_img_s4];
scalar_t px = pow(pow(abs(dudx * inp_W), 2.0f) + pow(abs(dvdx * inp_H), 2.0f) + 1e-12f, 0.5f);
scalar_t py = pow(pow(abs(dudy * inp_W), 2.0f) + pow(abs(dvdy * inp_H), 2.0f) + 1e-12f, 0.5f);
scalar_t p_max = max(px, py);
scalar_t p_min = min(px, py);
// # See p.255 of OpenGL Core Profile
// # N = min(ceil(Pmax/Pmin),maxAniso)
scalar_t N = min(ceil(p_max / p_min), (scalar_t)max_aniso);
if (p_min == 0.0 || N == 0) {
N = 1;
}
// Lambda' = log2(Pmax/N)
scalar_t lambda_ = log2(p_max / N);
if (isnan(lambda_) || isinf(lambda_)) {
lambda_ = 0.0f;
}
// See eq. 8.15, 8.16
// Substract small number (1e-6) so that `l` is always < mipmaps - 1
scalar_t l = min(lambda_, mipmaps - 1 - 1e-6);
// The following correction is divergence from the specification
// The reason is that it is typically assumed that the full pyramid is available, but if not,
// clipping of the level happens as in the line above, which causes taps to be spread with
// distances higher than the size of the texel. Which in turn causes aliasing and not desirable
// long-range sampling So if clipping happens, we recompute clipped Pmax and scale gradients
// accordingly
if (clip_grad && lambda_ > mipmaps - 1) {
scalar_t p_max_corrected = exp2(l) * N;
scalar_t scaling = p_max_corrected / p_max;
dudx *= scaling;
dvdx *= scaling;
dudy *= scaling;
dvdy *= scaling;
}
l = max(l, 0.0);
auto d1 = (index_t)floor(l);
scalar_t a = l - (scalar_t)d1;
index_t N_int = index_t(N);
if (force_max_aniso) {
N_int = max_aniso;
}
TVec2<scalar_t> gi_acc = {scalar_t(0), scalar_t(0)};
if (px > py) {
for (int i = 0; i < N_int; ++i) {
scalar_t u_offset = dudx * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t v_offset = dvdx * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t alpha_1 = a / N_int;
scalar_t alpha_2 = (1.0 - a) / N_int;
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
auto ggrad = sample_bilinear_backward(
inputs[d1],
grad_inputs[d1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_2,
padding_mode,
align_corners,
grad_input_memory_span[d1]);
gi_acc += ggrad;
if (mipmaps > 1) {
auto ggrad2 = sample_bilinear_backward(
inputs[d1 + 1],
grad_inputs[d1 + 1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_1,
padding_mode,
align_corners,
grad_input_memory_span[d1 + 1]);
gi_acc += ggrad2;
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
auto ggrad = sample_bicubic_backward(
inputs[d1],
grad_inputs[d1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_2,
padding_mode,
align_corners,
grad_input_memory_span[d1]);
gi_acc += ggrad;
if (mipmaps > 1) {
auto ggrad2 = sample_bicubic_backward(
inputs[d1 + 1],
grad_inputs[d1 + 1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_1,
padding_mode,
align_corners,
grad_input_memory_span[d1 + 1]);
gi_acc += ggrad2;
}
}
}
} else {
for (int i = 0; i < N_int; ++i) {
scalar_t u_offset = dudy * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t v_offset = dvdy * ((i + 1.0) / (N_int + 1.0) * 2.0 - 1.0);
scalar_t alpha_1 = a / N_int;
scalar_t alpha_2 = (1.0 - a) / N_int;
if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
auto ggrad = sample_bilinear_backward(
inputs[d1],
grad_inputs[d1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_2,
padding_mode,
align_corners,
grad_input_memory_span[d1]);
gi_acc += ggrad;
if (mipmaps > 1) {
auto ggrad2 = sample_bilinear_backward(
inputs[d1 + 1],
grad_inputs[d1 + 1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_1,
padding_mode,
align_corners,
grad_input_memory_span[d1 + 1]);
gi_acc += ggrad2;
}
} else if (interpolation_mode == GridSamplerInterpolation::Bicubic) {
auto ggrad = sample_bicubic_backward(
inputs[d1],
grad_inputs[d1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_2,
padding_mode,
align_corners,
grad_input_memory_span[d1]);
gi_acc += ggrad;
if (mipmaps > 1) {
auto ggrad2 = sample_bicubic_backward(
inputs[d1 + 1],
grad_inputs[d1 + 1],
grad_output,
u + u_offset,
v + v_offset,
w,
h,
n,
C,
alpha_1,
padding_mode,
align_corners,
grad_input_memory_span[d1 + 1]);
gi_acc += ggrad2;
}
}
}
}
// assuming grad_grid is contiguous
// thus we can
// 1. use index with gGrid_sW to directly compute gGrid_ptr_NHW
// 2. directly assign to gGrid_ptr_NHW[0], gGrid_ptr_NHW[1]
scalar_t* gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
gGrid_ptr_NHW[0] = gi_acc.x;
gGrid_ptr_NHW[1] = gi_acc.y;
}
}
__host__ torch::Tensor mipmap_aniso_grid_sampler_2d_cuda(
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad) {
int mipmaps = input.size();
TORCH_CHECK(
mipmaps >= 1,
"mipmap_aniso_grid_sampler_2d(): expected input to have at least one mipmap level");
TORCH_CHECK(
input[0].defined() && grid.defined(),
"mipmap_aniso_grid_sampler_2d(): expected input and grid to not be undefined, but input is ",
input,
" and grid is ",
grid);
auto input_opt = input[0].options();
auto grid_opt = grid.options();
TORCH_CHECK(
input_opt.device() == grid_opt.device(),
"mipmap_aniso_grid_sampler_2d(): expected input and grid to be on same device, but input is on ",
input_opt.device(),
" and grid is on ",
grid_opt.device());
TORCH_CHECK(
input_opt.dtype() == grid_opt.dtype(),
"mipmap_aniso_grid_sampler_2d(): expected input and grid to have same dtype, but input has ",
input_opt.dtype(),
" and grid has ",
grid_opt.dtype());
TORCH_CHECK(
input_opt.layout() == torch::kStrided && grid_opt.layout() == torch::kStrided,
"mipmap_aniso_grid_sampler_2d(): expected input and grid to have torch.strided layout, but "
"input has ",
input_opt.layout(),
" and grid has ",
grid_opt.layout());
TORCH_CHECK(
(input[0].dim() == 4) && input[0].dim() == grid.dim() &&
input[0].dim() + 1 == vt_dxdy_img.dim(),
"mipmap_aniso_grid_sampler_2d(): expected 4D input and grid with same number of "
"dimensions and 5D vt_dxdy_img, but got input with sizes ",
input[0].sizes(),
" and grid with sizes ",
grid.sizes(),
" and vt_dxdy_img with sizes ",
vt_dxdy_img.sizes());
TORCH_CHECK(
input[0].size(0) == grid.size(0) && input[0].size(0) == vt_dxdy_img.size(0),
"mipmap_aniso_grid_sampler_2d(): expected grid, vt_dxdy_img and input to have same batch size, "
"but got input with sizes ",
input[0].sizes(),
" and grid with sizes ",
grid.sizes(),
" and vt_dxdy_img with sizes ",
vt_dxdy_img.sizes());
TORCH_CHECK(
grid.size(-1) == input[0].dim() - 2,
"mipmap_aniso_grid_sampler_2d(): expected grid to have size ",
input[0].dim() - 2,
" in last dimension, but got grid with sizes ",
grid.sizes());
TORCH_CHECK(
vt_dxdy_img.size(-1) == input[0].dim() - 2 && vt_dxdy_img.size(-2) == input[0].dim() - 2,
"mipmap_aniso_grid_sampler_2d(): expected vt_dxdy_img to have size ",
input[0].dim() - 2,
" in last "
"two dimension, but got grid with sizes ",
grid.sizes());
for (int64_t i = 1; i < mipmaps; i++) {
TORCH_CHECK(
input_opt.device() == input[i].options().device() &&
input_opt.dtype() == input[i].options().dtype() &&
input_opt.layout() == input[i].options().layout() && input[0].dim() == input[i].dim() &&
input[0].size(0) == input[i].size(0) && input[0].size(1) == input[i].size(1),
"mipmap_aniso_grid_sampler_2d(): expected all inputs to have same device, dtype, layout, and "
"first two dimensions");
}
for (int64_t i = 2; i < input[0].dim(); i++) {
TORCH_CHECK(
input[0].size(i) > 0,
"grid_sampler(): expected input to have non-empty spatial dimensions, "
"but input has sizes ",
input[0].sizes(),
" with dimension ",
i,
" being empty");
}
const at::cuda::OptionalCUDAGuard device_guard(device_of(input[0]));
auto N = input[0].size(0);
auto C = input[0].size(1);
auto H = grid.size(1);
auto W = grid.size(2);
auto output = at::zeros({N, C, H, W}, input[0].options());
int64_t count = N * H * W;
if (count > 0) {
// Should be AT_DISPATCH_FLOATING_TYPES_AND_HALF, but half is broken on prod
AT_DISPATCH_FLOATING_TYPES(input[0].scalar_type(), "mipmap_aniso_grid_sampler_2d_kernel", [&] {
if (at::native::canUse32BitIndexMath(input[0]) && at::native::canUse32BitIndexMath(grid) &&
at::native::canUse32BitIndexMath(output)) {
typedef int index_type;
TensorInfoList<scalar_t, index_type, max_mipmap_count, tex_ndim> inputs;
for (int i = 0; i < mipmaps; ++i) {
inputs[i] = getTensorInfo<scalar_t, index_type>(input[i]);
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bilinear) {
mipmap_aniso_grid_sampler_2d_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bilinear>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
inputs,
mipmaps,
getTensorInfo<scalar_t, index_type>(grid),
getTensorInfo<scalar_t, index_type>(vt_dxdy_img),
getTensorInfo<scalar_t, index_type>(output),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad);
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bicubic) {
mipmap_aniso_grid_sampler_2d_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bicubic>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
inputs,
mipmaps,
getTensorInfo<scalar_t, index_type>(grid),
getTensorInfo<scalar_t, index_type>(vt_dxdy_img),
getTensorInfo<scalar_t, index_type>(output),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
TensorInfoList<scalar_t, index_type, max_mipmap_count, tex_ndim> inputs;
for (int i = 0; i < mipmaps; ++i) {
inputs[i] = getTensorInfo<scalar_t, index_type>(input[i]);
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bilinear) {
mipmap_aniso_grid_sampler_2d_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bilinear>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
inputs,
mipmaps,
getTensorInfo<scalar_t, index_type>(grid),
getTensorInfo<scalar_t, index_type>(vt_dxdy_img),
getTensorInfo<scalar_t, index_type>(output),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad);
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bicubic) {
mipmap_aniso_grid_sampler_2d_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bicubic>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
inputs,
mipmaps,
getTensorInfo<scalar_t, index_type>(grid),
getTensorInfo<scalar_t, index_type>(vt_dxdy_img),
getTensorInfo<scalar_t, index_type>(output),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return output;
}
__host__ std::tuple<std::vector<torch::Tensor>, torch::Tensor>
mipmap_aniso_grid_sampler_2d_cuda_backward(
const torch::Tensor& grad_output,
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad) {
int mipmaps = input.size();
auto N = input[0].size(0);
auto H = grid.size(1);
auto W = grid.size(2);
const at::cuda::OptionalCUDAGuard device_guard(device_of(input[0]));
std::vector<torch::Tensor> grad_input;
for (int i = 0; i < mipmaps; ++i) {
grad_input.push_back(at::zeros_like(input[i], LEGACY_CONTIGUOUS_MEMORY_FORMAT));
}
auto grad_grid = at::empty_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
int64_t count = N * H * W;
if (count > 0) {
// Should be AT_DISPATCH_FLOATING_TYPES_AND_HALF, but half is broken on prod
AT_DISPATCH_FLOATING_TYPES(
input[0].scalar_type(), "mipmap_aniso_grid_sampler_2d_backward_kernel", [&] {
if (at::native::canUse32BitIndexMath(input[0]) &&
at::native::canUse32BitIndexMath(grid) &&
at::native::canUse32BitIndexMath(grad_output)) {
typedef int index_type;
TensorInfoList<scalar_t, index_type, max_mipmap_count, tex_ndim> inputs;
IndexList<index_type, max_mipmap_count> grad_input_memory_span;
TensorInfoList<scalar_t, index_type, max_mipmap_count, tex_ndim> grad_inputs;
for (int i = 0; i < mipmaps; ++i) {
inputs[i] = getTensorInfo<scalar_t, index_type>(input[i]);
grad_inputs[i] = getTensorInfo<scalar_t, index_type>(grad_input[i]);
grad_input_memory_span[i] = grad_input[i].numel();
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bilinear) {
mipmap_aniso_grid_sampler_2d_backward_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bilinear>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_output),
inputs,
mipmaps,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, uv_jacobian_ndim>(vt_dxdy_img),
grad_inputs,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_grid),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad,
grad_input_memory_span);
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bicubic) {
mipmap_aniso_grid_sampler_2d_backward_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bicubic>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_output),
inputs,
mipmaps,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, uv_jacobian_ndim>(vt_dxdy_img),
grad_inputs,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_grid),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad,
grad_input_memory_span);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
TensorInfoList<scalar_t, index_type, max_mipmap_count, tex_ndim> inputs;
IndexList<index_type, max_mipmap_count> grad_input_memory_span;
TensorInfoList<scalar_t, index_type, max_mipmap_count, tex_ndim> grad_inputs;
for (int i = 0; i < mipmaps; ++i) {
inputs[i] = getTensorInfo<scalar_t, index_type>(input[i]);
grad_inputs[i] = getTensorInfo<scalar_t, index_type>(grad_input[i]);
grad_input_memory_span[i] = grad_input[i].numel();
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bilinear) {
mipmap_aniso_grid_sampler_2d_backward_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bilinear>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_output),
inputs,
mipmaps,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, uv_jacobian_ndim>(vt_dxdy_img),
grad_inputs,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_grid),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad,
grad_input_memory_span);
}
if (interpolation_mode == (int64_t)GridSamplerInterpolation::Bicubic) {
mipmap_aniso_grid_sampler_2d_backward_kernel<
scalar_t,
index_type,
GridSamplerInterpolation::Bicubic>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_output),
inputs,
mipmaps,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grid),
getTensorInfoCompact<scalar_t, index_type, uv_jacobian_ndim>(vt_dxdy_img),
grad_inputs,
getTensorInfoCompact<scalar_t, index_type, tex_ndim>(grad_grid),
static_cast<GridSamplerPadding>(padding_mode),
(int)max_aniso,
align_corners,
force_max_ansio,
clip_grad,
grad_input_memory_span);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return std::make_tuple(grad_input, grad_grid);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <torch/torch.h>
torch::Tensor mipmap_aniso_grid_sampler_2d_cuda(
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad);
std::tuple<std::vector<torch::Tensor>, torch::Tensor> mipmap_aniso_grid_sampler_2d_cuda_backward(
const torch::Tensor& grad_output,
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "mipmap_grid_sampler_kernel.h"
// Dispatch function
torch::Tensor mipmap_grid_sampler_2d(
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("mipmap_grid_sampler_ext::mipmap_grid_sampler_2d", "")
.typed<decltype(mipmap_grid_sampler_2d)>();
return op.call(
input,
grid,
vt_dxdy_img,
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad);
}
// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
class MipmapGridSample2DFunction : public torch::autograd::Function<MipmapGridSample2DFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad,
const torch::Tensor& input0,
const c10::optional<torch::Tensor>& input1,
const c10::optional<torch::Tensor>& input2,
const c10::optional<torch::Tensor>& input3,
const c10::optional<torch::Tensor>& input4,
const c10::optional<torch::Tensor>& input5,
const c10::optional<torch::Tensor>& input6,
const c10::optional<torch::Tensor>& input7,
const c10::optional<torch::Tensor>& input8,
const c10::optional<torch::Tensor>& input9,
const c10::optional<torch::Tensor>& input10) {
std::vector<torch::Tensor> input = {input0};
if (input1.has_value())
input.push_back(input1.value());
if (input2.has_value())
input.push_back(input2.value());
if (input3.has_value())
input.push_back(input3.value());
if (input4.has_value())
input.push_back(input4.value());
if (input5.has_value())
input.push_back(input5.value());
if (input6.has_value())
input.push_back(input6.value());
if (input7.has_value())
input.push_back(input7.value());
if (input8.has_value())
input.push_back(input8.value());
if (input9.has_value())
input.push_back(input9.value());
if (input10.has_value())
input.push_back(input10.value());
ctx->set_materialize_grads(false);
std::vector<torch::Tensor> save_list;
for (auto& inp : input) {
save_list.push_back(inp);
}
save_list.push_back(grid);
save_list.push_back(vt_dxdy_img);
ctx->save_for_backward(save_list);
bool requires_grad = false;
for (const auto& inp : input) {
requires_grad = requires_grad || inp.requires_grad();
}
requires_grad = requires_grad || grid.requires_grad();
ctx->saved_data["data"] = std::make_tuple(
(int64_t)input.size(),
requires_grad,
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad);
auto out = mipmap_aniso_grid_sampler_2d_cuda(
input,
grid,
vt_dxdy_img,
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad);
return {out};
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
int64_t mipmaps;
bool requires_grad;
int64_t max_aniso;
int64_t padding_mode;
int64_t interpolation_mode;
bool align_corners;
bool force_max_ansio;
bool clip_grad;
std::tie(
mipmaps,
requires_grad,
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad) =
ctx->saved_data["data"]
.to<std::tuple<int64_t, bool, int64_t, int64_t, int64_t, bool, bool, bool>>();
torch::autograd::tensor_list out;
if (!requires_grad) {
out.resize(mipmaps + 2);
return out;
}
const auto saved = ctx->get_saved_variables();
std::vector<torch::Tensor> input(saved.begin(), saved.begin() + mipmaps);
torch::Tensor grid = saved[mipmaps];
torch::Tensor vt_dxdy_img = saved[mipmaps + 1];
auto grad_out = mipmap_aniso_grid_sampler_2d_cuda_backward(
grad_outputs[0],
input,
grid,
vt_dxdy_img,
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad);
std::vector<torch::Tensor> grads;
grads.push_back(std::get<1>(grad_out));
grads.push_back(torch::Tensor());
grads.push_back(torch::Tensor());
grads.push_back(torch::Tensor());
grads.push_back(torch::Tensor());
grads.push_back(torch::Tensor());
grads.push_back(torch::Tensor());
grads.push_back(torch::Tensor());
for (auto& g : std::get<0>(grad_out)) {
grads.push_back(g);
}
while (grads.size() < 19) {
grads.push_back(torch::Tensor());
}
return grads;
}
};
torch::Tensor mipmap_grid_sampler_2d_autograd(
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad) {
return MipmapGridSample2DFunction::apply(
grid,
vt_dxdy_img,
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad,
input[0],
input.size() > 1 ? input[1] : c10::optional<torch::Tensor>(),
input.size() > 2 ? input[2] : c10::optional<torch::Tensor>(),
input.size() > 3 ? input[3] : c10::optional<torch::Tensor>(),
input.size() > 4 ? input[4] : c10::optional<torch::Tensor>(),
input.size() > 5 ? input[5] : c10::optional<torch::Tensor>(),
input.size() > 6 ? input[6] : c10::optional<torch::Tensor>(),
input.size() > 7 ? input[7] : c10::optional<torch::Tensor>(),
input.size() > 8 ? input[8] : c10::optional<torch::Tensor>(),
input.size() > 9 ? input[9] : c10::optional<torch::Tensor>(),
input.size() > 10 ? input[10] : c10::optional<torch::Tensor>())[0];
}
torch::Tensor mipmap_grid_sampler_2d_autocast(
const torch::TensorList& input,
const torch::Tensor& grid,
const torch::Tensor& vt_dxdy_img,
int64_t max_aniso,
int64_t padding_mode,
int64_t interpolation_mode,
bool align_corners,
bool force_max_ansio,
bool clip_grad) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return mipmap_grid_sampler_2d(
at::autocast::cached_cast(torch::kFloat32, input),
at::autocast::cached_cast(torch::kFloat32, grid),
at::autocast::cached_cast(torch::kFloat32, vt_dxdy_img),
max_aniso,
padding_mode,
interpolation_mode,
align_corners,
force_max_ansio,
clip_grad);
}
#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(mipmap_grid_sampler_ext, m) {}
#endif
TORCH_LIBRARY(mipmap_grid_sampler_ext, m) {
m.def(
"mipmap_grid_sampler_2d(Tensor[] x, Tensor grid, Tensor vt_dxdy_img, int max_aniso, int padding_mode, int interpolation_mode, bool align_corners, bool force_max_ansio, bool clip_grad) -> Tensor");
}
TORCH_LIBRARY_IMPL(mipmap_grid_sampler_ext, Autograd, m) {
m.impl("mipmap_grid_sampler_2d", &mipmap_grid_sampler_2d_autograd);
}
TORCH_LIBRARY_IMPL(mipmap_grid_sampler_ext, Autocast, m) {
m.impl("mipmap_grid_sampler_2d", mipmap_grid_sampler_2d_autocast);
}
TORCH_LIBRARY_IMPL(mipmap_grid_sampler_ext, CUDA, m) {
m.impl("mipmap_grid_sampler_2d", &mipmap_aniso_grid_sampler_2d_cuda);
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/types.h>
#include <cassert>
#include <cuda_math_helper.h>
#include <grid_utils.h>
#include <kernel_utils.h>
using namespace math;
template <typename scalar_t, typename index_t>
__device__ inline typename math::TVec4<scalar_t> msi_sample_bilinear_cubic(
const TensorInfo<scalar_t, index_t>& input,
math::TVec3<scalar_t> uvw) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
typedef typename math::TVec4<scalar_t> scalar4_t;
index_t inp_N = input.sizes[0];
index_t inp_H = input.sizes[2];
index_t inp_W = input.sizes[3];
index_t inp_sN = input.strides[0];
index_t inp_sC = input.strides[1];
index_t inp_sH = input.strides[2];
index_t inp_sW = input.strides[3];
int3 size = {(int)inp_W, (int)inp_H, (int)inp_N};
scalar3_t i_uvw =
((uvw + 1.f) * scalar3_t({(float)size.x, (float)size.y, (float)size.z}) - 1.f) / 2.f;
i_uvw.x = safe_downgrade_to_int_range(clip_coordinates(i_uvw.x, size.x));
i_uvw.y = safe_downgrade_to_int_range(clip_coordinates(i_uvw.y, size.y));
i_uvw.z = safe_downgrade_to_int_range(clip_coordinates(i_uvw.z, size.z));
// get NE, NW, SE, SW pixel values from (x, y)
index_t ix_nw = static_cast<index_t>(::floor(i_uvw.x));
index_t iy_nw = static_cast<index_t>(::floor(i_uvw.y));
index_t iz_nw = static_cast<index_t>(::floor(i_uvw.z));
index_t ix_ne = ix_nw + 1;
index_t iy_ne = iy_nw;
index_t ix_sw = ix_nw;
index_t iy_sw = iy_nw + 1;
index_t ix_se = ix_nw + 1;
index_t iy_se = iy_nw + 1;
const scalar_t tz = i_uvw.z - iz_nw;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - i_uvw.x) * (iy_se - i_uvw.y);
scalar_t ne = (i_uvw.x - ix_sw) * (iy_sw - i_uvw.y);
scalar_t sw = (ix_ne - i_uvw.x) * (i_uvw.y - iy_ne);
scalar_t se = (i_uvw.x - ix_nw) * (i_uvw.y - iy_nw);
scalar4_t coefficients[4];
#pragma unroll 4
for (index_t i = 0; i < 4; ++i) {
scalar_t z = clip_coordinates(iz_nw - 1 + i, size.z);
int iz = static_cast<int>(z);
auto inp_ptr_NC = input.data + iz * inp_sN;
scalar4_t out = {0, 0, 0, 0};
if (within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)) {
auto ptr = inp_ptr_NC + iy_nw * inp_sH + ix_nw * inp_sW;
out = out + load4(ptr, inp_sC) * nw;
}
if (within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)) {
auto ptr = inp_ptr_NC + iy_ne * inp_sH + ix_ne * inp_sW;
out = out + load4(ptr, inp_sC) * ne;
}
if (within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)) {
auto ptr = inp_ptr_NC + iy_sw * inp_sH + ix_sw * inp_sW;
out = out + load4(ptr, inp_sC) * sw;
}
if (within_bounds_2d(iy_se, ix_se, inp_H, inp_W)) {
auto ptr = inp_ptr_NC + iy_se * inp_sH + ix_se * inp_sW;
out = out + load4(ptr, inp_sC) * se;
}
coefficients[i] = out;
}
return cubic_interp1d<scalar_t, 4>(
coefficients[0], coefficients[1], coefficients[2], coefficients[3], tz);
}
template <typename scalar_t, typename index_t>
__device__ inline void msi_sample_bilinear_cubic_backward(
const TensorInfo<scalar_t, index_t>& grad_input,
math::TVec4<scalar_t> grad_output,
math::TVec3<scalar_t> uvw,
index_t grad_input_memory_span) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
typedef typename math::TVec4<scalar_t> scalar4_t;
index_t gInp_sN = grad_input.strides[0];
index_t gInp_sC = grad_input.strides[1];
index_t gInp_sH = grad_input.strides[2];
index_t gInp_sW = grad_input.strides[3];
index_t inp_N = grad_input.sizes[0];
index_t inp_H = grad_input.sizes[2];
index_t inp_W = grad_input.sizes[3];
int3 size = {(int)inp_W, (int)inp_H, (int)inp_N};
scalar3_t i_uvw =
((uvw + 1.f) * scalar3_t({(float)size.x, (float)size.y, (float)size.z}) - 1.f) / 2.f;
i_uvw.x = safe_downgrade_to_int_range(clip_coordinates(i_uvw.x, size.x));
i_uvw.y = safe_downgrade_to_int_range(clip_coordinates(i_uvw.y, size.y));
i_uvw.z = safe_downgrade_to_int_range(clip_coordinates(i_uvw.z, size.z));
// get NE, NW, SE, SW pixel values from (x, y)
index_t ix_nw = static_cast<index_t>(::floor(i_uvw.x));
index_t iy_nw = static_cast<index_t>(::floor(i_uvw.y));
index_t iz_nw = static_cast<index_t>(::floor(i_uvw.z));
index_t ix_ne = ix_nw + 1;
index_t iy_ne = iy_nw;
index_t ix_sw = ix_nw;
index_t iy_sw = iy_nw + 1;
index_t ix_se = ix_nw + 1;
index_t iy_se = iy_nw + 1;
const scalar_t tz = i_uvw.z - iz_nw;
// get surfaces to each neighbor:
scalar_t nw = (ix_se - i_uvw.x) * (iy_se - i_uvw.y);
scalar_t ne = (i_uvw.x - ix_sw) * (iy_sw - i_uvw.y);
scalar_t sw = (ix_ne - i_uvw.x) * (i_uvw.y - iy_ne);
scalar_t se = (i_uvw.x - ix_nw) * (i_uvw.y - iy_nw);
scalar_t coeffs[4];
get_cubic_upsampling_coefficients<scalar_t>(coeffs, tz);
#pragma unroll 4
for (index_t i = 0; i < 4; ++i) {
scalar_t z = clip_coordinates(iz_nw - 1 + i, size.z);
int iz = static_cast<int>(z);
index_t N_offset = iz * gInp_sN;
// calculate and set grad_input. See Note [Passing pointer and offset to
// fastAtomicAdd].
safe_add_2d4(
grad_input.data,
gInp_sC,
iy_nw,
ix_nw,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
nw * grad_output * coeffs[i],
N_offset,
grad_input_memory_span);
safe_add_2d4(
grad_input.data,
gInp_sC,
iy_ne,
ix_ne,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
ne * grad_output * coeffs[i],
N_offset,
grad_input_memory_span);
safe_add_2d4(
grad_input.data,
gInp_sC,
iy_sw,
ix_sw,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
sw * grad_output * coeffs[i],
N_offset,
grad_input_memory_span);
safe_add_2d4(
grad_input.data,
gInp_sC,
iy_se,
ix_se,
gInp_sH,
gInp_sW,
inp_H,
inp_W,
se * grad_output * coeffs[i],
N_offset,
grad_input_memory_span);
}
}
__device__ __host__ __forceinline__ float2 direction_to_equirectangular(float3 d) {
const float longitude = atan2f(d.z, d.x);
const float latitude = atan2f(d.y, math::norm(float2{d.x, d.z}));
constexpr float inv_pi = M_1_PI;
return float2({longitude, 2 * latitude}) * inv_pi;
}
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void msi_forward_kernel(
const index_t nthreads,
TensorInfo<float, index_t> ray_o,
TensorInfo<float, index_t> ray_d,
TensorInfo<scalar_t, index_t> texture,
TensorInfo<scalar_t, index_t> rgba_img,
int sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
typedef typename math::TVec4<scalar_t> scalar4_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
const int n_layers = texture.sizes[0];
const int n_steps = n_layers * sub_step_count;
const index_t ray_o_sN = ray_o.strides[0];
const index_t ray_o_sC = ray_o.strides[1];
const index_t ray_d_sN = ray_d.strides[0];
const index_t ray_d_sC = ray_d.strides[1];
const index_t rgba_img_sN = rgba_img.strides[0];
const index_t rgba_img_sC = rgba_img.strides[1];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
auto rgba_ptr = rgba_img.data + rgba_img_sN * index;
const float3 r_o = {
ray_o.data[ray_o_sN * index + ray_o_sC * 0],
ray_o.data[ray_o_sN * index + ray_o_sC * 1],
ray_o.data[ray_o_sN * index + ray_o_sC * 2]};
const float3 r_d = normalize(float3(
{ray_d.data[ray_d_sN * index + ray_d_sC * 0],
ray_d.data[ray_d_sN * index + ray_d_sC * 1],
ray_d.data[ray_d_sN * index + ray_d_sC * 2]}));
float tc = dot(-r_o, r_d);
float h2 = dot(r_o, r_o) - tc * tc;
float step_size = 1.0f / float(n_steps);
float3 out_v = {0.f, 0.f, 0.f};
float log_transmit = 0.f;
for (int i = 0; i < n_steps; ++i) {
const float a = (float(n_steps - 1 - i) + 0.5f) / float(n_steps);
const float inv_r = (1.0 - a) * max_inv_r + a * min_inv_r;
const float r = 1.0f / inv_r;
float det = r * r - h2;
if (det < 0.0f)
continue;
float t = tc + sqrt(det);
float3 pos = t * r_d + r_o;
const float w = 1.f - a * 2.f;
const float3 uvw = make_float3(direction_to_equirectangular(pos), w);
auto sample = msi_sample_bilinear_cubic(texture, uvw);
scalar3_t rgb = {sample.x, sample.y, sample.z};
float alpha = sample.w;
if (alpha > 0.0f) {
const float pcnt = alpha * step_size;
const float weight = __expf(log_transmit) * (1.f - __expf(-pcnt));
log_transmit -= pcnt;
out_v = out_v + weight * math::max(rgb, {0.f, 0.f, 0.f});
if (__expf(log_transmit) < stop_thresh) {
log_transmit = -1e3f;
break;
}
}
}
rgba_ptr[0 * rgba_img_sC] = out_v.x;
rgba_ptr[1 * rgba_img_sC] = out_v.y;
rgba_ptr[2 * rgba_img_sC] = out_v.z;
rgba_ptr[3 * rgba_img_sC] = log_transmit;
}
}
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void msi_backward_kernel(
const index_t nthreads,
TensorInfo<float, index_t> ray_o,
TensorInfo<float, index_t> ray_d,
TensorInfo<scalar_t, index_t> texture,
TensorInfo<scalar_t, index_t> texture_grad,
index_t texture_grad_memory_span,
TensorInfo<scalar_t, index_t> rgba_img,
TensorInfo<scalar_t, index_t> rgba_img_grad,
int sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
typedef typename math::TVec4<scalar_t> scalar4_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
const int n_layers = texture.sizes[0];
const int n_steps = n_layers * sub_step_count;
const index_t ray_o_sN = ray_o.strides[0];
const index_t ray_o_sC = ray_o.strides[1];
const index_t ray_d_sN = ray_d.strides[0];
const index_t ray_d_sC = ray_d.strides[1];
const index_t rgba_img_sN = rgba_img.strides[0];
const index_t rgba_img_sC = rgba_img.strides[1];
const index_t rgba_img_grad_sN = rgba_img_grad.strides[0];
const index_t rgba_img_grad_sC = rgba_img_grad.strides[1];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
auto rgba_ptr = rgba_img.data + rgba_img_sN * index;
auto rgba_grad_ptr = rgba_img_grad.data + rgba_img_grad_sN * index;
scalar3_t out_v_grad = {
rgba_grad_ptr[0 * rgba_img_grad_sC],
rgba_grad_ptr[1 * rgba_img_grad_sC],
rgba_grad_ptr[2 * rgba_img_grad_sC]};
scalar3_t out_v_acc =
out_v_grad *
scalar3_t(
{rgba_ptr[0 * rgba_img_sC], rgba_ptr[1 * rgba_img_sC], rgba_ptr[2 * rgba_img_sC]});
const float3 r_o = {
ray_o.data[ray_o_sN * index + ray_o_sC * 0],
ray_o.data[ray_o_sN * index + ray_o_sC * 1],
ray_o.data[ray_o_sN * index + ray_o_sC * 2]};
const float3 r_d = normalize(float3(
{ray_d.data[ray_d_sN * index + ray_d_sC * 0],
ray_d.data[ray_d_sN * index + ray_d_sC * 1],
ray_d.data[ray_d_sN * index + ray_d_sC * 2]}));
float tc = dot(-r_o, r_d);
float h2 = dot(r_o, r_o) - tc * tc;
float step_size = 1.0f / float(n_steps);
float log_transmit = 0.f;
for (int i = 0; i < n_steps; ++i) {
const float a = (float(n_steps - 1 - i) + 0.5f) / float(n_steps);
const float inv_r = (1.0 - a) * max_inv_r + a * min_inv_r;
const float r = 1.0f / inv_r;
float det = r * r - h2;
if (det < 0.0f)
continue;
float t = tc + sqrt(det);
float3 pos = t * r_d + r_o;
const float w = 1.f - a * 2.f;
const float3 uvw = make_float3(direction_to_equirectangular(pos), w);
auto sample = msi_sample_bilinear_cubic(texture, uvw);
scalar3_t rgb = {sample.x, sample.y, sample.z};
float alpha = sample.w;
if (alpha > 0.0f) {
const float pcnt = alpha * step_size;
const float weight = __expf(log_transmit) * (1.f - __expf(-pcnt));
log_transmit -= pcnt;
auto rgb_01 = math::max(rgb, {0.f, 0.f, 0.f});
scalar3_t color_in_01 = scalar3_t(
{scalar_t(rgb_01.x == rgb.x),
scalar_t(rgb_01.y == rgb.y),
scalar_t(rgb_01.z == rgb.z)});
scalar3_t color_grad = color_in_01 * weight * out_v_grad;
out_v_acc -= weight * rgb_01 * out_v_grad;
float alpha_grad =
sum(rgb_01 * out_v_grad * __expf(-alpha) * __expf(log_transmit) - out_v_acc);
scalar4_t rgba_grad = make_float4(color_grad, alpha_grad);
msi_sample_bilinear_cubic_backward(texture_grad, rgba_grad, uvw, texture_grad_memory_span);
if (__expf(log_transmit) < stop_thresh) {
log_transmit = -1e3f;
break;
}
}
}
}
}
__host__ torch::Tensor msi_forward_cuda(
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
TORCH_CHECK(sub_step_count > 0, "msi(): expected step_size > 0, but got ", sub_step_count);
TORCH_CHECK(
stop_thresh > 0 && stop_thresh < 1,
"msi(): expected 0 < stop_thresh < 1, but got ",
stop_thresh);
TORCH_CHECK(
min_inv_r > max_inv_r,
"msi(): expected min_inv_r to be greater than max_inv_r, but "
"got min_inv_r:",
min_inv_r,
" and max_inv_r: ",
max_inv_r);
TORCH_CHECK(
ray_o.defined() && ray_d.defined() && texture.defined(),
"msi(): expected all inputs not be undefined, but "
"ray_o is ",
ray_o,
", ray_d is ",
ray_d,
", texture is ",
texture);
auto ray_o_opt = ray_o.options();
auto ray_d_opt = ray_d.options();
auto texture_opt = texture.options();
auto device = ray_o_opt.device();
auto tex_dtype = texture_opt.dtype();
auto ray_dtype = ray_o_opt.dtype();
TORCH_CHECK(
device.is_cuda(), "msi(): expected inputs to be on CUDA device, but got ray_o on ", device);
const at::cuda::OptionalCUDAGuard device_guard(device);
TORCH_CHECK(
device == ray_o_opt.device() && device == ray_d_opt.device() &&
device == texture_opt.device(),
"msi(): expected all inputs to be on same device, but input "
"ray_o is ",
ray_o_opt.device(),
", ray_d is ",
ray_d_opt.device(),
", texture is ",
texture_opt.device());
TORCH_CHECK(
tex_dtype == torch::kFloat64 || tex_dtype == torch::kFloat32 || tex_dtype == torch::kHalf,
"msi(): expected texture to be of type Double, Float or "
"Half, but got type ",
texture_opt.dtype());
TORCH_CHECK(
ray_o_opt.dtype() == torch::kFloat32 && ray_d_opt.dtype() == torch::kFloat32,
"msi(): expected ray_o and ray_d to be of type Float, but "
"input ray_o is ",
ray_o_opt.dtype(),
" and ray_d is ",
ray_d_opt.dtype());
TORCH_CHECK(
torch::kStrided == ray_o_opt.layout() && torch::kStrided == ray_d_opt.layout() &&
torch::kStrided == texture_opt.layout(),
"msi(): expected all inputs to have torch.strided layout, but "
"ray_o has ",
ray_o_opt.layout(),
", ray_d has ",
ray_d_opt.layout(),
", texture has ",
texture_opt.layout());
TORCH_CHECK(
ray_o.dim() == 2 && ray_d.dim() == 2 && texture.dim() == 4,
"msi(): expected ray_o and ray_d to have 2 dimensions, "
"and texture to have 4 dimension, "
"but got ray_o with size ",
ray_o.sizes(),
", ray_d with size ",
ray_d.sizes(),
", texture with size ",
texture.sizes());
TORCH_CHECK(
ray_o.size(1) == 3 && ray_d.size(1) == 3 && texture.size(1) == 4,
"msi(): expected ray_o, ray_d to have size 3 along the dimension 1, "
" and texture to have size 4 along the dimension 1, "
"but got ray_o with size ",
ray_o.sizes(),
", ray_d with size ",
ray_d.sizes(),
", texture with size ",
texture.sizes());
TORCH_CHECK(
ray_o.size(0) == ray_d.size(0),
"msi(): expected ray_o, ray_d to have the same size along "
"the dimension 0, "
"but got ray_o with size ",
ray_o.sizes(),
", ray_d with size ",
ray_d.sizes());
int N = ray_o.size(0);
auto rgba_img = torch::empty({N, 4}, texture.options());
if (N > 0) {
DISPATCH_FLOAT(texture.scalar_type(), "msi_forward_kernel", [&] {
if (at::native::canUse32BitIndexMath(ray_o) && at::native::canUse32BitIndexMath(ray_d) &&
at::native::canUse32BitIndexMath(texture)) {
typedef int index_type;
msi_forward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(N),
getTensorInfo<float, index_type>(ray_o),
getTensorInfo<float, index_type>(ray_d),
getTensorInfo<scalar_t, index_type>(texture),
getTensorInfo<scalar_t, index_type>(rgba_img),
(int)sub_step_count,
min_inv_r,
max_inv_r,
stop_thresh);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
msi_forward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(N),
getTensorInfo<float, index_type>(ray_o),
getTensorInfo<float, index_type>(ray_d),
getTensorInfo<scalar_t, index_type>(texture),
getTensorInfo<scalar_t, index_type>(rgba_img),
(int)sub_step_count,
min_inv_r,
max_inv_r,
stop_thresh);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return rgba_img;
}
torch::Tensor msi_backward_cuda(
const torch::Tensor& rgba_img,
const torch::Tensor& rgba_img_grad,
const torch::Tensor& ray_o,
const torch::Tensor& ray_d,
const torch::Tensor& texture,
int64_t sub_step_count,
double min_inv_r,
double max_inv_r,
double stop_thresh) {
auto ray_o_opt = ray_o.options();
auto ray_d_opt = ray_d.options();
auto texture_opt = texture.options();
auto device = ray_o_opt.device();
const at::cuda::OptionalCUDAGuard device_guard(device);
auto tex_dtype = texture_opt.dtype();
auto ray_dtype = ray_o_opt.dtype();
int N = ray_o.size(0);
auto texture_grad = torch::zeros_like(texture);
if (N > 0) {
DISPATCH_FLOAT(texture.scalar_type(), "msi_forward_kernel", [&] {
if (at::native::canUse32BitIndexMath(ray_o) && at::native::canUse32BitIndexMath(ray_d) &&
at::native::canUse32BitIndexMath(rgba_img) &&
at::native::canUse32BitIndexMath(rgba_img_grad) &&
at::native::canUse32BitIndexMath(texture_grad) &&
at::native::canUse32BitIndexMath(texture)) {
typedef int index_type;
index_type texture_grad_memory_span = texture_grad.numel();
msi_backward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(N),
getTensorInfo<float, index_type>(ray_o),
getTensorInfo<float, index_type>(ray_d),
getTensorInfo<scalar_t, index_type>(texture),
getTensorInfo<scalar_t, index_type>(texture_grad),
texture_grad_memory_span,
getTensorInfo<scalar_t, index_type>(rgba_img),
getTensorInfo<scalar_t, index_type>(rgba_img_grad),
(int)sub_step_count,
min_inv_r,
max_inv_r,
stop_thresh);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
typedef int64_t index_type;
index_type texture_grad_memory_span = texture_grad.numel();
msi_backward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(N, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(N),
getTensorInfo<float, index_type>(ray_o),
getTensorInfo<float, index_type>(ray_d),
getTensorInfo<scalar_t, index_type>(texture),
getTensorInfo<scalar_t, index_type>(texture_grad),
texture_grad_memory_span,
getTensorInfo<scalar_t, index_type>(rgba_img),
getTensorInfo<scalar_t, index_type>(rgba_img_grad),
(int)sub_step_count,
min_inv_r,
max_inv_r,
stop_thresh);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
});
}
return texture_grad;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment