".github/git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "c80490a95e935c783686401287301a550cc2f5f2"
Commit ecc1df99 authored by facebook-github-bot's avatar facebook-github-bot
Browse files

Initial commit

fbshipit-source-id: afc575e8e7d8e2796a3f77d8b1c6c4fcb999558d
parents
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence, Tuple
import torch as th
from drtk.utils import project_points
def transform(
v: th.Tensor,
campos: Optional[th.Tensor] = None,
camrot: Optional[th.Tensor] = None,
focal: Optional[th.Tensor] = None,
princpt: Optional[th.Tensor] = None,
K: Optional[th.Tensor] = None,
Rt: Optional[th.Tensor] = None,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""
v: Tensor, N x V x 3
Batch of vertex positions for vertices in the mesh.
campos: Tensor, N x 3
Camera position.
camrot: Tensor, N x 3 x 3
Camera rotation matrix.
focal: Tensor, N x 2 x 2
Focal length [[fx, 0],
[0, fy]]
princpt: Tensor, N x 2
Principal point [cx, cy]
K: Tensor, N x 3 x 3
Camera intrinsic calibration matrix. Either this or both (focal,
princpt) must be provided.
Rt: Tensor, N x 3 x 4 or N x 4 x 4
Camera extrinsic matrix. Either this or both (camrot, campos) must be
provided. Camrot is the upper 3x3 of Rt, campos = -R.T @ t.
distortion_mode: List[str]
Names of the distortion modes.
distortion_coeff: Tensor, N x 4
Distortion coefficients.
fov: Tensor, N x 1
Valid field of view of the distortion model.
"""
v_pix, _ = transform_with_v_cam(
v, campos, camrot, focal, princpt, K, Rt, distortion_mode, distortion_coeff, fov
)
return v_pix
def transform_with_v_cam(
v: th.Tensor,
campos: Optional[th.Tensor] = None,
camrot: Optional[th.Tensor] = None,
focal: Optional[th.Tensor] = None,
princpt: Optional[th.Tensor] = None,
K: Optional[th.Tensor] = None,
Rt: Optional[th.Tensor] = None,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
fov: Optional[th.Tensor] = None,
) -> Tuple[th.Tensor, th.Tensor]:
"""
Same as transform, but also returns the camera-space coordinates.
In most cases it is not needed, but renderlayer depends on it
"""
if not ((camrot is not None and campos is not None) ^ (Rt is not None)):
raise ValueError("You must provide exactly one of Rt or (campos, camrot).")
if not ((focal is not None and princpt is not None) ^ (K is not None)):
raise ValueError("You must provide exactly one of K or (focal, princpt).")
if campos is None:
assert Rt is not None
camrot = Rt[:, :3, :3]
campos = -(camrot.transpose(-2, -1) @ Rt[:, :3, 3:4])[..., 0]
if focal is None:
assert K is not None
focal = K[:, :2, :2]
princpt = K[:, :2, 2]
assert camrot is not None
assert princpt is not None
# Compute camera-space 3D coordinates and 2D pixel-space projections.
v_pix, v_cam = project_points(
v, campos, camrot, focal, princpt, distortion_mode, distortion_coeff, fov
)
return v_pix, v_cam
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from drtk.utils.geometry import ( # noqa
face_dpdt,
face_info,
vert_binormals,
vert_normals,
)
from drtk.utils.indexing import index # noqa
from drtk.utils.projection import ( # noqa
DISTORTION_MODES, # noqa
project_points, # noqa
project_points_grad, # noqa
)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple, Union
import torch as th
import torch.nn.functional as thf
from drtk.utils.indexing import index
from torch import Tensor
eps = 1e-8
def face_dpdt(
v: th.Tensor, vt: th.Tensor, vi: th.Tensor, vti: th.Tensor
) -> Tuple[th.Tensor, th.Tensor]:
"""
This function calculates the transposed Jacobian matrix (∂p/∂t)^T for each triangle.
Where:
- p represents the 3D coordinates of a point on the plane of the triangle,
- t denotes the UV coordinates assiciated with the point.
Args:
v: vertex position tensor
N x V x 3
vt: vertex uv tensor
N x T x 2
vi: face vertex position index list tensor
F x 3
vti: face vertex uv index list tensor
F x 3
Jacobian is computed as:
∂p/∂t = ∂p / ∂b * (∂t / ∂b)^-1
Where b - barycentric coordinates
However the implementation computes a transposed Jacobian (purely from
practical perspective - fewer permutations are needed), so the above
becomes:
(∂p/∂t)^T = ((∂t / ∂b)^T)^-1 * (∂p / ∂b)^T
Returns:
dpdt - transposed Jacobian (∂p/∂t)^T. Shape: N x F x 2 x 3
Where ∂p∂t[..., i, j] = ∂p[..., j] / ∂t[..., i]
v012 - vertex positions per triangle. Shape: N x F x 3
where: N - batch size; F - number of triangles
"""
v012 = v[:, vi]
vt012 = vt[:, vti]
dpdb_t = v012[:, :, 1:3] - v012[:, :, 0:1]
dtdb_t = vt012[:, :, 1:3] - vt012[:, :, 0:1]
# (db / dt)^T = ((dt / db)^T)^-1
dbdt_t = th.inverse(dtdb_t)
# (dp / dt)^T = (db / dt)^T) * (dp / db)^T
dpdt_t = dbdt_t @ dpdb_t
return dpdt_t, v012
def face_attribute_to_vert(v: th.Tensor, vi: th.Tensor, attr: th.Tensor) -> Tensor:
"""
For each vertex, computes a summation of the face attributes to which the
vertex belongs.
"""
attr = (
attr[:, :, None]
.expand(-1, -1, 3, -1)
.reshape(attr.shape[0], -1, attr.shape[-1])
)
vi_flat = vi.view(vi.shape[0], -1).expand(v.shape[0], -1)
vattr = th.zeros(v.shape[:-1], dtype=v.dtype, device=v.device)
vattr = th.stack(
[vattr.scatter_add(1, vi_flat, attr[..., i]) for i in range(attr.shape[-1])],
dim=-1,
)
return vattr
def face_info(
v: th.Tensor, vi: th.Tensor, to_compute: Optional[List[str]] = None
) -> Union[th.Tensor, Dict[str, th.Tensor]]:
"""Given a set of vertices ``v`` and indices ``vi`` indexing into ``v``
defining a set of faces, compute face information (normals, edges, face
areas) for each face.
Args:
v: Vertex positions, shape [batch_size, n_vertices, 3]
vi: Vertex indices, shape [n_faces, 3]
to_compute: list of desired information. Any of: {normals, edges, areas}, defaults to all.
Returns:
Dict: Face information in the following format::
{
"normals": shape [batch_size, n_faces, 3]
"edges": shape [batch_size, n_faces, 3, 3]
"areas": shape [batch_size, n_faces, 1]
}
or just one of the above values not in a Dict if only one is
requested.
"""
if to_compute is None:
to_compute = ["normals", "edges", "areas"]
b = v.shape[0]
vi = vi.expand(b, -1, -1)
p0 = th.stack([index(v[i], vi[i, :, 0], 0) for i in range(b)])
p1 = th.stack([index(v[i], vi[i, :, 1], 0) for i in range(b)])
p2 = th.stack([index(v[i], vi[i, :, 2], 0) for i in range(b)])
v0 = p1 - p0
v1 = p0 - p2
need_normals = "normals" in to_compute
need_areas = "areas" in to_compute
need_edges = "edges" in to_compute
output = {}
if need_normals or need_areas:
normals = th.cross(v1, v0, dim=-1)
norm = th.linalg.vector_norm(normals, dim=-1, keepdim=True)
if need_areas:
output["areas"] = 0.5 * norm
if need_normals:
output["normals"] = normals / norm.clamp(min=eps)
if need_edges:
v2 = p2 - p1
output["edges"] = th.stack([v0, v1, v2], dim=2)
if len(to_compute) == 1:
return output[to_compute[0]]
else:
return output
def vert_binormals(v: Tensor, vt: Tensor, vi: Tensor, vti: Tensor) -> Tensor:
# Compute (dp/dt)^T
dpdt_t, vf = face_dpdt(v, vt, vi, vti)
# Take the dp/dt.u part. Produces u vector in 3D world-space which we use for binormal vector
fbnorms = dpdt_t[:, :, 0, :]
vbnorms = face_attribute_to_vert(v, vi, fbnorms)
return thf.normalize(vbnorms, dim=-1)
def vert_normals(
v: th.Tensor, vi: th.Tensor, fnorms: Optional[th.Tensor] = None
) -> th.Tensor:
"""Given a set of vertices ``v`` and indices ``vi`` indexing into ``v``
defining a set of faces, compute normals for each vertex by averaging the
face normals for each face which includes that vertex.
Args:
v: Vertex positions, shape [batch_size, n_vertices, 3]
vi: Vertex indices, shape [batch_size, n_faces, 3]
fnorms: Face normals. Optional, provide them if available, otherwise they will be computed
from `v` and `vi`. Shape [n_faces, 3]
Returns:
th.Tensor: Vertex normals, shape [batch_size, n_vertices, 3]
"""
if fnorms is None:
fnorms = face_info(v, vi, ["normals"])
assert isinstance(fnorms, th.Tensor)
vnorms = face_attribute_to_vert(v, vi, fnorms)
return thf.normalize(vnorms, dim=-1)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch as th
def index(x: th.Tensor, idxs: th.Tensor, dim: int) -> th.Tensor:
"""Index a tensor along a given dimension using an index tensor, replacing
the shape along the given dimension with the shape of the index tensor.
Example:
x: [8, 7306, 3]
idxs: [11000, 3]
y = index(x, idxs, dim=1) -> y: [8, 11000, 3, 3]
with each y[b, i, j, k] = x[b, idxs[i, j], k]
"""
target_shape = [*x.shape]
del target_shape[dim]
target_shape[dim:dim] = [*idxs.shape]
return x.index_select(dim, idxs.view(-1)).reshape(target_shape)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence, Set, Tuple, Union
import numpy as np
import torch as th
DISTORTION_MODES: Set[Optional[str]] = {None, "radial-tangential", "fisheye"}
def project_pinhole(
v_cam: th.Tensor, focal: th.Tensor, princpt: th.Tensor
) -> th.Tensor:
"""Project camera-space points to pixel-space points with camera
intrinsics.
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
"""
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, 0:2] / z
v_pix = (focal[:, None] @ v_proj[..., None])[..., 0] + princpt[:, None]
return v_pix
def project_pinhole_distort_rt(
v_cam: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
D: th.Tensor,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Project camera-space points to distorted pixel-space using the radial
and tangential model (4 parameters).
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
D: N x 4
fov: N x 1
"""
# See https://docs.opencv.org/2.4/doc/tutorials/calib3d/camera_calibration/camera_calibration.html
if fov is None:
with th.no_grad():
fov = estimate_rt_fov(D)
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, :2] / z
r2 = v_proj.pow(2).sum(-1)
# Clamp x, y and r to avoid wrapping behavior of the distortion model.
r2 = r2.clamp(max=fov.pow(2))
v_clamped = v_proj.clamp(min=-fov[..., None], max=fov[..., None])
assert D.shape[1] in [4, 5, 8]
# 4 param: R = (1 + k1 r^2 + k2 r^4)
R = 1 + D[:, 0:1] * r2 + D[:, 1:2] * r2.pow(2)
# 5 param: R = (1 + k1 r^2 + k2 r^4 + k3 r^6)
if D.shape[1] == 5:
R = R + D[:, 4:5] * r2.pow(3)
# 8 param: R = (1 + k1 r^2 + k2 r^4 + k3 r^6) / (1 + k4 r^2 + k5 r^4 + k6 r^6)
if D.shape[1] == 8:
R = R + D[:, 4:5] * r2.pow(3)
R = R / (1 + D[:, 5:6] * r2 + D[:, 6:7] * r2.pow(2) + D[:, 7:8] * r2.pow(3))
# [x' y'] * R
v_proj_dist = v_proj * R[..., None]
# [2 p1 x' y', 2 p2 x' y']
v_proj_dist += (
2
* v_clamped[..., 0:1]
* v_clamped[..., 1:2]
* th.stack((D[:, 2:3], D[:, 3:4]), dim=-1)
)
# [p2 r^2, p1 r^2]
v_proj_dist += r2[..., None] * th.stack((D[:, 3:4], D[:, 2:3]), dim=-1)
# [2 p2 x'^2, 2 p1 y'^2]
v_proj_dist += th.stack(
(
2 * D[:, 3:4] * v_clamped[..., 0].pow(2),
2 * D[:, 2:3] * v_clamped[..., 1].pow(2),
),
dim=-1,
)
v_pix_dist = (focal[:, None] @ v_proj_dist[..., None])[..., 0] + princpt[:, None]
return v_pix_dist
def project_fisheye_distort(
v_cam: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
D: th.Tensor,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Project camera-space points to distort pixel-space points using the
fisheye distortion model.
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
D: N x 4
fov: N x 1
"""
# See https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fisheye.cpp
if fov is None:
with th.no_grad():
fov = estimate_fisheye_fov(D)
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, :2] / z
r = v_proj.pow(2).sum(-1).sqrt()
r = r.clamp(max=fov, min=1e-8 * th.ones_like(fov))
theta = th.atan(r)
theta_d = theta * (
1
+ D[:, 0:1] * theta.pow(2)
+ D[:, 1:2] * theta.pow(4)
+ D[:, 2:3] * theta.pow(6)
+ D[:, 3:4] * theta.pow(8)
)
r = th.where(r < 0, r.clamp(max=-1e-8), r.clamp(min=1e-8))
v_proj_dist = v_proj * (theta_d / r)[..., None]
v_pix_dist = (focal[:, None] @ v_proj_dist[..., None])[..., 0] + princpt[:, None]
return v_pix_dist
def project_fisheye_distort_62(
v_cam: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
D: th.Tensor,
fov: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Project camera-space points to distort pixel-space points using the
OculusVisionFishEye62 distortion model.
v_cam: N x V x 3
focal: N x 2 x 2
princpt: N x 2
D: N x 4
fov: N x 1
"""
# See https://www.internalfb.com/code/fbsource/[188bdaeaad64]/arvr/projects/nimble/prod/pynimble/visualization/shaders.py?lines=103-123
# a more readible version: https://euratom-software.github.io/calcam/html/intro_theory.html
if fov is None:
with th.no_grad():
fov = estimate_fisheye_fov(D)
z = v_cam[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
v_proj = v_cam[:, :, :2] / z
r = v_proj.pow(2).sum(-1).sqrt() # rp
r = r.clamp(max=fov, min=1e-8 * th.ones_like(fov))
theta = th.atan(r)
theta_d = theta * (
1
+ D[:, 0:1] * theta.pow(2)
+ D[:, 1:2] * theta.pow(4)
+ D[:, 2:3] * theta.pow(6)
+ D[:, 3:4] * theta.pow(8)
+ D[:, 4:5] * theta.pow(10)
+ D[:, 5:6] * theta.pow(12)
)
r = th.where(r < 0, r.clamp(max=-1e-8), r.clamp(min=1e-8))
v_proj_dist = v_proj * (theta_d / r)[..., None]
# Tangential Distortion
x = v_proj_dist[:, :, 0]
y = v_proj_dist[:, :, 1]
xtan = D[:, 6:7] * (r.pow(2) + 2 * x.pow(2)) + 2 * D[:, 7:8] * x * y
ytan = 2 * D[:, 6:7] * x * y + D[:, 7:8] * (r.pow(2) + 2 * y.pow(2))
pTangential = th.cat([xtan[..., None], ytan[..., None]], dim=-1)
v_proj_dist = v_proj_dist + pTangential
v_pix_dist = (focal[:, None] @ v_proj_dist[..., None])[..., 0] + princpt[:, None]
return v_pix_dist
def estimate_rt_fov(D: Union[np.ndarray, th.Tensor]) -> th.Tensor:
"""Estimate the maximum field of view based on the assumption that the 5th order
polynomial for fish-eye effect is non-decreasing.
D: N x 4
"""
if th.is_tensor(D):
coefs = D.cpu().numpy()
else:
coefs = D
ones = np.ones_like(coefs[:, 0])
zeros = np.zeros_like(coefs[:, 0])
coefs = np.stack(
[
5 * coefs[:, 1],
zeros,
3 * coefs[:, 0],
zeros,
ones,
],
axis=-1,
)
fov = []
for coef in coefs:
roots = np.roots(coef)
real_valued = roots.real[abs(roots.imag) < 1e-5]
positive_roots = real_valued[real_valued > 0]
if len(positive_roots) == 0:
fov.append(np.inf)
else:
fov.append(positive_roots.min())
fov = np.asarray(fov, dtype=np.float32)[..., None]
if th.is_tensor(D):
fov = th.from_numpy(fov).to(D)
return fov
def estimate_fisheye_fov(D: Union[np.ndarray, th.Tensor]) -> th.Tensor:
"""Estimate the maximum field of view based on the assumption that the 9th order
polynomial is non-decreasing.
D: N x 4
"""
if th.is_tensor(D):
coefs = D.cpu().numpy()
else:
coefs = D
ones = np.ones_like(coefs[:, 0])
zeros = np.zeros_like(coefs[:, 0])
coefs = np.stack(
[
9 * coefs[:, -1],
zeros,
7 * coefs[:, -2],
zeros,
5 * coefs[:, -3],
zeros,
3 * coefs[:, -4],
zeros,
ones,
],
axis=-1,
)
fov = []
for coef in coefs:
roots = np.roots(coef)
real_valued = roots.real[abs(roots.imag) < 1e-5]
positive_roots = real_valued[real_valued > 0]
if len(positive_roots) == 0:
fov.append(np.pi / 2)
else:
fov.append(min(positive_roots.min(), np.pi / 2))
fov = np.asarray(np.tan(fov), dtype=np.float32)[..., None]
if th.is_tensor(D):
fov = th.from_numpy(fov).to(D)
return fov
def project_points(
v: th.Tensor,
campos: th.Tensor,
camrot: th.Tensor,
focal: th.Tensor,
princpt: th.Tensor,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
fov: Optional[th.Tensor] = None,
) -> Tuple[th.Tensor, th.Tensor]:
"""Project 3D world-space vertices to pixel-space, optionally applying a
distortion model with provided coefficients.
Returns v_pix, v_cam, both N x V x 3 since we preserve the camera-space
Z-coordinate for v_pix.
v: N x V x 3
camrot: N x 3 x 3
campos: N x 3
focal: N x 2 x 2
princpt: N x 2
distortion_coeff: N x 4
fov: N x 1
"""
if distortion_mode is not None:
assert distortion_coeff is not None, "Missing distortion coefficients."
v_cam = (camrot[:, None] @ (v - campos[:, None])[..., None])[..., 0]
# Fall back to single distortion mode if all the distortion modes are the same.
if isinstance(distortion_mode, (list, tuple)):
modes = list(set(distortion_mode))
if len(modes) == 0:
distortion_mode = None
elif len(modes) == 1:
distortion_mode = modes[0]
if distortion_mode is None:
v_pix = project_pinhole(v_cam, focal, princpt)
elif isinstance(distortion_mode, str):
assert distortion_coeff is not None
# Single distortion model
if distortion_mode == "radial-tangential":
v_pix = project_pinhole_distort_rt(
v_cam, focal, princpt, distortion_coeff, fov
)
elif distortion_mode == "fisheye":
v_pix = project_fisheye_distort(
v_cam, focal, princpt, distortion_coeff, fov
)
elif distortion_mode == "fisheye62":
v_pix = project_fisheye_distort_62(
v_cam, focal, princpt, distortion_coeff, fov
)
else:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
elif isinstance(distortion_mode, (list, tuple)):
assert distortion_coeff is not None
# A mix of multiple distortion modes
modes = set(distortion_mode)
if not modes <= DISTORTION_MODES:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
v_pix = th.empty_like(v_cam[..., :2])
if None in modes:
idx = th.tensor(
[mode is None for mode in distortion_mode], device=v_pix.device
)
v_pix[idx] = project_pinhole(v_cam[idx], focal[idx], princpt[idx])
if "radial-tangential" in modes:
idx = th.tensor(
[mode == "radial-tangential" for mode in distortion_mode],
device=v_pix.device,
)
v_pix[idx] = project_pinhole_distort_rt(
v_cam[idx],
focal[idx],
princpt[idx],
distortion_coeff[idx],
fov[idx] if fov is not None else None,
)
if "fisheye" in modes:
idx = th.tensor(
[mode == "fisheye" for mode in distortion_mode], device=v_pix.device
)
v_pix[idx] = project_fisheye_distort(
v_cam[idx],
focal[idx],
princpt[idx],
distortion_coeff[idx],
fov[idx] if fov is not None else None,
)
else:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
v_pix = th.cat((v_pix[:, :, 0:2], v_cam[:, :, 2:3]), dim=-1)
return v_pix, v_cam
def project_points_grad(
v_grad: th.Tensor,
v: th.Tensor,
campos: th.Tensor,
camrot: th.Tensor,
focal: th.Tensor,
distortion_mode: Optional[Sequence[str]] = None,
distortion_coeff: Optional[th.Tensor] = None,
) -> th.Tensor:
"""Computes the gradient of projected (pixel-space) vertex positions with
respect to the 3D world-space vertex positions given the gradient of the 3D
world-space vertex positions.
project_points_grad(dv, v) = d project_points(v) / dv * dv
Args:
v_grad: Gradient of 3D world-space vertices. Shape: N x V x 3
v: 3D world-space vertices. Shape: N x V x 3
camrot: Camera rotation. Shape: N x 3 x 3
camrot: Camera position. Shape: N x 3
focal: Focal length. Shape: N x 2 x 2
distortion_mode: Distortion currently not implemented and must be None.
distortion_coeff: Distortion currently not implemented and must be None.
Returns:
Gradient of 2D pixel-space vertices: N x V x 2
"""
if distortion_mode is not None:
assert distortion_coeff is not None, "Missing distortion coefficients."
# d v_cam = d (Rv + T) = Rdv
v_cam_grad = (camrot[:, None] @ v_grad[..., None])[..., 0]
v_cam = (camrot[:, None] @ (v - campos[:, None])[..., None])[..., 0]
if distortion_mode is None:
z = v_cam[:, :, 2:3]
z_grad = v_cam_grad[:, :, 2:3]
z = th.where(z < 0, z.clamp(max=-1e-8), z.clamp(min=1e-8))
# Using quotient rule:
# d (v_cam / z) = (d v_cam * z - v_cam * dz) / z^2
v_proj_grad = (v_cam_grad[:, :, 0:2] * z - v_cam[:, :, 0:2] * z_grad) / z**2.0
# d v_pix = d (Kv + cp) = Kdv
v_pix_grad = (focal[:, None] @ v_proj_grad[..., None])[..., 0]
elif distortion_mode == "radial-tangential":
raise NotImplementedError
elif distortion_mode == "fisheye":
raise NotImplementedError
else:
raise ValueError(
f"Invalid distortion mode: {distortion_mode}. Valid options: {DISTORTION_MODES}."
)
return v_pix_grad
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import platform
import re
import sys
from pkg_resources import DistributionNotFound, get_distribution
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
def main(debug: bool) -> None:
extra_link_args = {
"linux": ["-static-libgcc"] + ([] if debug else ["-flto"]),
"win32": ["/DEBUG"] if debug else [],
}
cxx_args = {
"linux": ["-std=c++17", "-Wall"]
+ (["-O0", "-g3", "-DDEBUG"] if debug else ["-O3", "--fast-math"]),
"win32": (
["/std:c++17", "/MT", "/GR-", "/EHsc", '/D "NOMINMAX"']
+ (
["/Od", '/D "_DEBUG"']
if debug
else ["/O2", "/fp:fast", "/GL", '/D "NDEBUG"']
)
),
}
nvcc_args = [
"-gencode=arch=compute_72,code=sm_72",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_86,code=sm_86",
"-gencode=arch=compute_90,code=sm_90",
] + (["-O0", "-g", "-DDEBUG"] if debug else ["-O3", "--use_fast_math"])
# There is som issue effecting latest NVCC and pytorch 2.3.0 https://github.com/pytorch/pytorch/issues/122169
# The workaround is adding -std=c++20 to NVCC args
nvcc_args.append("-std=c++20")
def get_dist(name):
try:
return get_distribution(name)
except DistributionNotFound:
return None
root_path = os.path.dirname(os.path.abspath(__file__))
package_name = "drtk"
with open(os.path.join(root_path, "drtk", "__init__.py")) as f:
init_file = f.read()
pattern = re.compile(r"__version__\s*=\s*\"(\d*\.\d*.\d*)\"")
groups = pattern.findall(init_file)
assert len(groups) == 1
version = groups[0]
if get_dist("torch") is None:
raise RuntimeError("Setup requires torch package to be installed")
import torch as th
assert th.cuda.is_available()
target_os = "none"
if sys.platform == "darwin":
target_os = "macos"
elif os.name == "posix":
target_os = "linux"
elif platform.system() == "Windows":
target_os = "win32"
if target_os == "none":
raise RuntimeError("Could not detect platform")
if target_os == "macos":
raise RuntimeError("Platform is not supported")
include_dir = [os.path.join(root_path, "src", "include")]
with open("README.md") as f:
readme = f.read()
pillow = "pillow" if get_dist("pillow-simd") is None else "pillow-simd"
setup(
name=package_name,
version=version,
author="Reality Labs, Meta",
description="Differentiable Rendering Toolkit",
long_description=readme,
long_description_content_type="text/markdown",
license="MIT",
install_requires=["numpy", "torch", "torchvision", pillow],
ext_modules=[
CUDAExtension(
name="drtk.rasterize_ext",
sources=[
"src/rasterize/rasterize_module.cpp",
"src/rasterize/rasterize_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
extra_link_args=extra_link_args[target_os],
include_dirs=include_dir,
),
CUDAExtension(
"drtk.render_ext",
sources=["src/render/render_kernel.cu", "src/render/render_module.cpp"],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.edge_grad_ext",
sources=[
"src/edge_grad/edge_grad_module.cpp",
"src/edge_grad/edge_grad_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.mipmap_grid_sampler_ext",
sources=[
"src/mipmap_grid_sampler/mipmap_grid_sampler_module.cpp",
"src/mipmap_grid_sampler/mipmap_grid_sampler_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.msi_ext",
sources=[
"src/msi/msi_module.cpp",
"src/msi/msi_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
CUDAExtension(
"drtk.interpolate_ext",
sources=[
"src/interpolate/interpolate_module.cpp",
"src/interpolate/interpolate_kernel.cu",
],
extra_compile_args={"cxx": cxx_args[target_os], "nvcc": nvcc_args},
include_dirs=include_dir,
),
],
cmdclass={"build_ext": BuildExtension},
packages=["drtk", "drtk.utils"],
)
if __name__ == "__main__":
main(any(x in sys.argv for x in ["debug", "-debug", "--debug"]))
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <c10/cuda/CUDAGuard.h>
#include <cuda_math_helper.h>
#include <torch/types.h>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <kernel_utils.h>
#include "edge_grad_kernel.h"
using namespace math;
using at::native::fastAtomicAdd;
template <typename scalar_t>
struct TriInfo {
typedef typename math::TVec2<scalar_t> scalar2_t;
const scalar2_t p_0;
const scalar2_t p_1;
const scalar2_t v_01;
const scalar2_t v_02;
const scalar2_t v_12;
const scalar_t denominator;
};
template <typename scalar_t>
__device__ bool pix_in_tri(const TriInfo<scalar_t>& tri, const int x, const int y) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
if (tri.denominator != 0.f) {
const scalar2_t p = {(scalar_t)x, (scalar_t)y};
const scalar2_t vp0p = p - tri.p_0;
const scalar2_t vp1p = p - tri.p_1;
scalar3_t bary = scalar3_t({
vp1p.y * tri.v_12.x - vp1p.x * tri.v_12.y,
vp0p.x * tri.v_02.y - vp0p.y * tri.v_02.x,
vp0p.y * tri.v_01.x - vp0p.x * tri.v_01.y,
});
bary *= sign(tri.denominator);
const bool on_edge_or_inside = (bary.x >= 0.f) && (bary.y >= 0.f) && (bary.z >= 0.f);
bool on_edge_0 = bary.x == 0.f;
bool on_edge_1 = bary.y == 0.f;
bool on_edge_2 = bary.z == 0.f;
const bool is_top_left_0 = (tri.denominator > 0)
? (tri.v_12.y < 0.f || tri.v_12.y == 0.0f && tri.v_12.x > 0.f)
: (tri.v_12.y > 0.f || tri.v_12.y == 0.0f && tri.v_12.x < 0.f);
const bool is_top_left_1 = (tri.denominator > 0)
? (tri.v_02.y > 0.f || tri.v_02.y == 0.0f && tri.v_02.x < 0.f)
: (tri.v_02.y < 0.f || tri.v_02.y == 0.0f && tri.v_02.x > 0.f);
const bool is_top_left_2 = (tri.denominator > 0)
? (tri.v_01.y < 0.f || tri.v_01.y == 0.0f && tri.v_01.x > 0.f)
: (tri.v_01.y > 0.f || tri.v_01.y == 0.0f && tri.v_01.x < 0.f);
const bool is_top_left_or_inside = on_edge_or_inside &&
!(on_edge_0 && !is_top_left_0 || on_edge_1 && !is_top_left_1 ||
on_edge_2 && !is_top_left_2);
return is_top_left_or_inside;
}
return false;
}
template <typename scalar_t, typename index_t>
__device__ TriInfo<scalar_t>
get_tri_info(const scalar_t* v_ptr, index_t v_sV, index_t v_sC, int3 vi) {
typedef typename math::TVec2<scalar_t> scalar2_t;
const scalar2_t p_0 = {v_ptr[v_sV * vi.x + v_sC * 0], v_ptr[v_sV * vi.x + v_sC * 1]};
const scalar2_t p_1 = {v_ptr[v_sV * vi.y + v_sC * 0], v_ptr[v_sV * vi.y + v_sC * 1]};
const scalar2_t p_2 = {v_ptr[v_sV * vi.z + v_sC * 0], v_ptr[v_sV * vi.z + v_sC * 1]};
const scalar2_t v_01 = p_1 - p_0;
const scalar2_t v_02 = p_2 - p_0;
const scalar2_t v_12 = p_2 - p_1;
const scalar_t denominator = v_01.x * v_02.y - v_01.y * v_02.x;
return {p_0, p_1, v_01, v_02, v_12, denominator};
}
template <typename scalar_t, typename index_t>
__device__ math::TVec3<scalar_t>
get_tri_normal(const scalar_t* v_ptr, index_t v_sV, index_t v_sC, int3 vi) {
typedef typename math::TVec3<scalar_t> scalar3_t;
const scalar3_t p_0 = {
v_ptr[v_sV * vi.x + v_sC * 0], v_ptr[v_sV * vi.x + v_sC * 1], v_ptr[v_sV * vi.x + v_sC * 2]};
const scalar3_t p_1 = {
v_ptr[v_sV * vi.y + v_sC * 0], v_ptr[v_sV * vi.y + v_sC * 1], v_ptr[v_sV * vi.y + v_sC * 2]};
const scalar3_t p_2 = {
v_ptr[v_sV * vi.z + v_sC * 0], v_ptr[v_sV * vi.z + v_sC * 1], v_ptr[v_sV * vi.z + v_sC * 2]};
return normalize(cross(p_0 - p_2, p_1 - p_0));
}
template <typename scalar_t>
__device__ math::TVec2<scalar_t> get_db_dp(
const math::TVec2<scalar_t>& n_varying_,
const math::TVec2<scalar_t>& n_fixed_) {
/*
Computes derivative of the point position with respect to edge displacement
Args:
- n_varying_: Projection of the normal of the movable triangle onto the plane of
consideration (XZ or YZ) N x 3 x H x W.
- n_fixed_: Projection of the normal of the fixed triangle onto the plane of consideration
(XZ or YZ) N x 3 x H x W.
Please refer to the paper "Rasterized Edge Gradients: Handling Discontinuities Differentiably"
for details.
*/
typedef typename math::TVec2<scalar_t> scalar2_t;
const auto n_varying = normalize(n_varying_);
const auto n_fixed = normalize(n_fixed_);
const scalar2_t b = {-n_fixed.y, n_fixed.x};
const auto b_dot_varyingg = dot(b, n_varying);
return b.x / epsclamp(b_dot_varyingg) * n_varying;
}
template <typename scalar_t, typename index_t>
__device__ math::TVec3<scalar_t> load_vec3_if_valid(
const scalar_t* __restrict ptr,
index_t stride,
bool valid,
const math::TVec3<scalar_t>& def) {
if (valid) {
return {ptr[0 * stride], ptr[1 * stride], ptr[2 * stride]};
}
return def;
}
template <typename scalar_t, typename index_t>
C10_LAUNCH_BOUNDS_1(256)
__global__ void edge_grad_backward_kernel(
const index_t nthreads,
TensorInfo<scalar_t, index_t> v_pix,
TensorInfo<scalar_t, index_t> img,
TensorInfo<int32_t, index_t> index_img,
TensorInfo<int32_t, index_t> vi,
TensorInfo<scalar_t, index_t> grad_output,
TensorInfo<scalar_t, index_t> grad_v_pix_img,
const index_t memory_span) {
typedef typename math::TVec2<scalar_t> scalar2_t;
typedef typename math::TVec3<scalar_t> scalar3_t;
const index_t v_pix_sN = v_pix.strides[0];
const index_t v_pix_sV = v_pix.strides[1];
const index_t v_pix_sC = v_pix.strides[2];
const index_t C = img.sizes[1];
const index_t H = img.sizes[2];
const index_t W = img.sizes[3];
const index_t V = v_pix.sizes[1];
const index_t index_img_sN = index_img.strides[0];
const index_t index_img_sH = index_img.strides[1];
const index_t index_img_sW = index_img.strides[2];
const index_t img_sN = img.strides[0];
const index_t img_sC = img.strides[1];
const index_t img_sH = img.strides[2];
const index_t img_sW = img.strides[3];
const index_t grad_output_sN = grad_output.strides[0];
const index_t grad_output_sC = grad_output.strides[1];
const index_t grad_output_sH = grad_output.strides[2];
const index_t grad_output_sW = grad_output.strides[3];
const index_t grad_v_pix_img_sN = grad_v_pix_img.strides[0];
const index_t grad_v_pix_img_sC = grad_v_pix_img.strides[1];
const index_t grad_v_pix_img_sH = grad_v_pix_img.strides[2];
const index_t grad_v_pix_img_sW = grad_v_pix_img.strides[3];
const index_t vi_sV = vi.strides[0];
const index_t vi_sF = vi.strides[1];
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t x = index % W;
const index_t y = (index / W) % H;
const index_t n = index / (H * W);
if (x < (W - 1) && y < (H - 1)) {
// center-right-down (CRD)
//
// *--------*--------*
// | center | right |
// | (0, 0) | (1, 0) |
// *--------*--------*
// | down |
// | (0, 1) |
// *--------*
// Computing indicator variables
// BEGIN
// triangle indices of CRD pixels
const int32_t* __restrict index_img_ptr = index_img.data + n * index_img_sN;
const int32_t center_index = index_img_ptr[(y + 0) * index_img_sH + (x + 0) * index_img_sW];
const int32_t right_index = index_img_ptr[(y + 0) * index_img_sH + (x + 1) * index_img_sW];
const int32_t down_index = index_img_ptr[(y + 1) * index_img_sH + (x + 0) * index_img_sW];
// valid mask
const bool c_valid = (center_index >= 0);
const bool r_valid = (right_index >= 0);
const bool d_valid = (down_index >= 0);
// vertex indices of triangles of CRD pixels
// 0,0,0 - if not valid
const int3 vi_pt_center = load_vec3_if_valid<int32_t, index_t>(
vi.data + center_index * vi_sV, vi_sF, c_valid, {0, 0, 0});
const int3 vi_pt_right = load_vec3_if_valid<int32_t, index_t>(
vi.data + right_index * vi_sV, vi_sF, r_valid, {0, 0, 0});
const int3 vi_pt_down = load_vec3_if_valid<int32_t, index_t>(
vi.data + down_index * vi_sV, vi_sF, d_valid, {0, 0, 0});
// center <-> right differ
const bool lr_diff = (center_index != right_index);
// center <-> down differ
const bool ud_diff = (center_index != down_index);
// if horizontal pair (vertical edge) composed of two triangles
const bool x_both_valid = c_valid && r_valid;
// if vertical pair (horizontal edge) composed of two triangles
const bool y_both_valid = c_valid && d_valid;
// Get CRD triangle info
const scalar_t* __restrict v_pix_ptr = v_pix.data + n * v_pix_sN;
const auto tri_center = get_tri_info(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_center);
const auto tri_right = get_tri_info(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_right);
const auto tri_down = get_tri_info(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_down);
// Compute indicators of edge type
const bool center_pix_in_right_tri = lr_diff && x_both_valid && pix_in_tri(tri_right, x, y);
const bool right_pix_in_center_tri =
lr_diff && x_both_valid && pix_in_tri(tri_center, x + 1, y);
const bool center_pix_in_down_tri = ud_diff && y_both_valid && pix_in_tri(tri_down, x, y);
const bool down_pix_in_center_tri =
ud_diff && y_both_valid && pix_in_tri(tri_center, x, y + 1);
// Overlap flags
const bool l_over_r = center_pix_in_right_tri && (!right_pix_in_center_tri);
const bool r_over_l = right_pix_in_center_tri && (!center_pix_in_right_tri);
const bool u_over_d = center_pix_in_down_tri && (!down_pix_in_center_tri);
const bool d_over_u = down_pix_in_center_tri && (!center_pix_in_down_tri);
// Intersection flags
const bool horiz_int = center_pix_in_right_tri && right_pix_in_center_tri;
const bool vert_int = center_pix_in_down_tri && down_pix_in_center_tri;
// Intersection flags
const bool horiz_adjacent =
lr_diff && x_both_valid && (!center_pix_in_right_tri && !right_pix_in_center_tri);
const bool vert_adjacent =
ud_diff && y_both_valid && (!center_pix_in_down_tri && !down_pix_in_center_tri);
// END
// Compute image gradient dot output gradient from backward
// This is computed regardless of the edge type as long as there is an edge (lr_diff or
// ud_diff) BEGIN
const scalar_t* __restrict img_ptr = img.data + img_sN * n;
const scalar_t* __restrict grad_output_ptr = grad_output.data + grad_output_sN * n;
scalar_t grad_dot_x = 0.f;
scalar_t grad_dot_y = 0.f;
if (lr_diff) {
const scalar_t* __restrict img_ptr_right = img_ptr + y * img_sH + (x + 1) * img_sW;
const scalar_t* __restrict img_ptr_center = img_ptr + y * img_sH + (x + 0) * img_sW;
const scalar_t* __restrict grad_output_ptr_right =
grad_output_ptr + y * grad_output_sH + (x + 1) * grad_output_sW;
const scalar_t* __restrict grad_output_ptr_center =
grad_output_ptr + y * grad_output_sH + (x + 0) * grad_output_sW;
for (size_t c = 0; c < C; ++c) {
grad_dot_x += (img_ptr_right[c * img_sC] - img_ptr_center[c * img_sC]) *
(0.5f *
(grad_output_ptr_right[c * grad_output_sC] +
grad_output_ptr_center[c * grad_output_sC]));
}
}
if (ud_diff) {
const scalar_t* __restrict img_ptr_down = img_ptr + (y + 1) * img_sH + x * img_sW;
const scalar_t* __restrict img_ptr_center = img_ptr + (y + 0) * img_sH + x * img_sW;
const scalar_t* __restrict grad_output_ptr_down =
grad_output_ptr + (y + 1) * grad_output_sH + x * grad_output_sW;
const scalar_t* __restrict grad_output_ptr_center =
grad_output_ptr + (y + 0) * grad_output_sH + x * grad_output_sW;
for (size_t c = 0; c < C; ++c) {
grad_dot_y += (img_ptr_down[c * img_sC] - img_ptr_center[c * img_sC]) *
(0.5f *
(grad_output_ptr_down[c * grad_output_sC] +
grad_output_ptr_center[c * grad_output_sC]));
}
}
// END
scalar3_t grad_v_pix_center = {0.f, 0.f, 0.f};
scalar3_t grad_v_pix_right = {0.f, 0.f, 0.f};
scalar3_t grad_v_pix_down = {0.f, 0.f, 0.f};
const scalar3_t center_normal = get_tri_normal(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_center);
const scalar3_t right_normal = get_tri_normal(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_right);
const scalar3_t down_normal = get_tri_normal(v_pix_ptr, v_pix_sV, v_pix_sC, vi_pt_down);
if (!horiz_int) {
grad_v_pix_center.x += (!c_valid || r_over_l || horiz_adjacent) ? 0.f : grad_dot_x;
grad_v_pix_right.x += (!r_valid || l_over_r || horiz_adjacent) ? 0.f : grad_dot_x;
} else {
// Center triangle moves, right fixed.
scalar2_t dbx_dp = get_db_dp<scalar_t>(
{center_normal.x, center_normal.z}, {right_normal.x, right_normal.z});
grad_v_pix_center.x += grad_dot_x * dbx_dp.x;
grad_v_pix_center.z += grad_dot_x * dbx_dp.y;
// Center triangle fixed, right moves.
dbx_dp = get_db_dp<scalar_t>(
{right_normal.x, right_normal.z}, {center_normal.x, center_normal.z});
grad_v_pix_right.x += grad_dot_x * dbx_dp.x;
grad_v_pix_right.z += grad_dot_x * dbx_dp.y;
}
if (!vert_int) {
grad_v_pix_center.y += (!c_valid || d_over_u || vert_adjacent) ? 0.f : grad_dot_y;
grad_v_pix_down.y += (!d_valid || u_over_d || vert_adjacent) ? 0.f : grad_dot_y;
} else {
// Center triangle moves, lower fixed.
scalar2_t dby_dp =
get_db_dp<scalar_t>({center_normal.y, center_normal.z}, {down_normal.y, down_normal.z});
grad_v_pix_center.y += grad_dot_y * dby_dp.x;
grad_v_pix_center.z += grad_dot_y * dby_dp.x;
// Center triangle fixed, lower moves.
dby_dp =
get_db_dp<scalar_t>({down_normal.y, down_normal.z}, {center_normal.y, center_normal.z});
grad_v_pix_down.y += grad_dot_y * dby_dp.x;
grad_v_pix_down.z += grad_dot_y * dby_dp.x;
}
// Writing grads out
// BEGIN
scalar_t* __restrict grad_v_pix_img_ptr = grad_v_pix_img.data + grad_v_pix_img_sN * n;
// center
auto* ptr_c = grad_v_pix_img_ptr + (y + 0) * grad_v_pix_img_sH + (x + 0) * grad_v_pix_img_sW;
atomicAdd(ptr_c + 0 * grad_v_pix_img_sC, -grad_v_pix_center.x);
atomicAdd(ptr_c + 1 * grad_v_pix_img_sC, -grad_v_pix_center.y);
atomicAdd(ptr_c + 2 * grad_v_pix_img_sC, -grad_v_pix_center.z);
// right
auto* ptr_r = grad_v_pix_img_ptr + (y + 0) * grad_v_pix_img_sH + (x + 1) * grad_v_pix_img_sW;
atomicAdd(ptr_r + 0 * grad_v_pix_img_sC, -grad_v_pix_right.x);
atomicAdd(ptr_r + 1 * grad_v_pix_img_sC, -grad_v_pix_right.y);
atomicAdd(ptr_r + 2 * grad_v_pix_img_sC, -grad_v_pix_right.z);
// down
auto* ptr_d = grad_v_pix_img_ptr + (y + 1) * grad_v_pix_img_sH + (x + 0) * grad_v_pix_img_sW;
atomicAdd(ptr_d + 0 * grad_v_pix_img_sC, -grad_v_pix_down.x);
atomicAdd(ptr_d + 1 * grad_v_pix_img_sC, -grad_v_pix_down.y);
atomicAdd(ptr_d + 2 * grad_v_pix_img_sC, -grad_v_pix_down.z);
// END
}
}
}
template <typename scalar_t, typename index_type>
void edge_grad_estimator_cuda_backward_(
const int64_t count,
const torch::Tensor& v_pix,
const torch::Tensor& img,
const torch::Tensor& index_img,
const torch::Tensor& vi,
const torch::Tensor& grad_outputs,
const torch::Tensor& grad_v_pix_img) {
edge_grad_backward_kernel<scalar_t, index_type>
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
static_cast<index_type>(count),
getTensorInfo<scalar_t, index_type>(v_pix),
getTensorInfo<scalar_t, index_type>(img),
getTensorInfo<int32_t, index_type>(index_img),
getTensorInfo<int32_t, index_type>(vi),
getTensorInfo<scalar_t, index_type>(grad_outputs),
getTensorInfo<scalar_t, index_type>(grad_v_pix_img),
grad_v_pix_img.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
torch::Tensor edge_grad_estimator_cuda_backward(
const torch::Tensor& v_pix,
const torch::Tensor& img,
const torch::Tensor& index_img,
const torch::Tensor& vi,
const torch::Tensor& grad_outputs) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(img));
const auto N = img.sizes()[0];
const auto C = img.sizes()[1];
const auto H = img.sizes()[2];
const auto W = img.sizes()[3];
const auto V = v_pix.sizes()[1];
const auto count = N * H * W;
auto grad_v_pix_img = torch::zeros({N, 3, H, W}, v_pix.options());
if (count > 0) {
AT_DISPATCH_FLOATING_TYPES(v_pix.scalar_type(), "edge_grad_estimator_kernel", [&] {
if (at::native::canUse32BitIndexMath(v_pix) && at::native::canUse32BitIndexMath(img) &&
at::native::canUse32BitIndexMath(index_img) && at::native::canUse32BitIndexMath(vi) &&
at::native::canUse32BitIndexMath(grad_outputs) &&
at::native::canUse32BitIndexMath(grad_v_pix_img)) {
edge_grad_estimator_cuda_backward_<scalar_t, int>(
count, v_pix, img, index_img, vi, grad_outputs, grad_v_pix_img);
} else {
edge_grad_estimator_cuda_backward_<scalar_t, int64_t>(
count, v_pix, img, index_img, vi, grad_outputs, grad_v_pix_img);
}
});
}
return grad_v_pix_img;
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
torch::Tensor edge_grad_estimator_cuda_backward(
const torch::Tensor& v_pix,
const torch::Tensor& img,
const torch::Tensor& index_img,
const torch::Tensor& vi,
const torch::Tensor& grad_outputs);
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/script.h>
#include <ATen/autocast_mode.h>
#ifndef NO_PYBIND
#include <torch/extension.h>
#endif
#include "edge_grad_kernel.h"
// Dispatch function
torch::Tensor edge_grad_estimator(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
static auto op = torch::Dispatcher::singleton()
.findSchemaOrThrow("edge_grad_ext::edge_grad_estimator", "")
.typed<decltype(edge_grad_estimator)>();
return op.call(v_pix, v_pix_img, vi, img, index_img);
}
torch::Tensor edge_grad_estimator_fwd(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
TORCH_CHECK(
v_pix.defined() && v_pix_img.defined() && vi.defined() && img.defined() &&
index_img.defined(),
"edge_grad_estimator(): expected all inputs to be defined");
TORCH_CHECK(
(v_pix.device() == v_pix_img.device()) && (v_pix.device() == vi.device()) &&
(v_pix.device() == img.device()) && (v_pix.device() == index_img.device()) &&
(v_pix.is_cuda()),
"edge_grad_estimator(): expected all inputs to be on same cuda device");
TORCH_CHECK(
v_pix.is_floating_point() && v_pix_img.is_floating_point() && img.is_floating_point(),
"edge_grad_estimator(): expected v_pix, v_pix_img, and img to have floating point type, but v_pix has ",
v_pix.dtype(),
" v_pix has ",
v_pix_img.dtype(),
" img has ",
img.dtype());
TORCH_CHECK(
vi.dtype() == torch::kInt32,
"edge_grad_estimator(): expected vi to have int32 type, but vi has ",
vi.dtype());
TORCH_CHECK(
index_img.dtype() == torch::kInt32,
"edge_grad_estimator(): expected index_img to have int32 type, but index_img has ",
index_img.dtype());
TORCH_CHECK(
v_pix.layout() == torch::kStrided && v_pix_img.layout() == torch::kStrided &&
vi.layout() == torch::kStrided && img.layout() == torch::kStrided &&
index_img.layout() == torch::kStrided,
"edge_grad_estimator(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(v_pix.dim() == 3) && (v_pix_img.dim() == 4) && (vi.dim() == 2) && (img.dim() == 4) &&
(index_img.dim() == 3),
"edge_grad_estimator(): expected v_pix.ndim == 3, v_pix_img.ndim == 4, vi.ndim == 2, img.ndim == 4, index_img.ndim == 3, "
"but got v_pix with sizes ",
v_pix.sizes(),
" and v_pix_img with sizes ",
v_pix_img.sizes(),
" and vi with sizes ",
vi.sizes(),
" and img with sizes ",
img.sizes(),
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v_pix.size(0) == v_pix_img.size(0) && v_pix.size(0) == img.size(0) &&
v_pix.size(0) == index_img.size(0),
"edge_grad_estimator(): expected v and index_img to have same batch size, "
"but got v_pix with sizes ",
v_pix.sizes(),
", v_pix_img with sizes ",
v_pix_img.sizes(),
", img with sizes ",
img.sizes(),
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v_pix.size(2) == 3 && v_pix_img.size(1) == 3 && vi.size(1) == 3,
"edge_grad_estimator(): expected third dim of v_pix to be of size 3, and second dim of vi to be of size 3, but got ",
v_pix.size(2),
" in the third dim of v_pix, and ",
v_pix_img.size(1),
" in the second dim of v_pix_img, and ",
vi.size(1),
" in the second dim of vi");
TORCH_CHECK(
v_pix_img.size(3) == img.size(3) && v_pix_img.size(3) == index_img.size(2) &&
v_pix_img.size(2) == img.size(2) && v_pix_img.size(2) == index_img.size(1),
"edge_grad_estimator(): expected width and height of v_pix_img, img, and index_img to match, but got size of v_pix_img: ",
v_pix_img.sizes(),
", size of img: ",
img.sizes(),
", size of index_img: ",
index_img.sizes());
return img;
}
// Ideally we would need to turn off autograd handling and re-dispatch, but we just call
// cuda kernels directly
class EdgeGradEstimatorFunction : public torch::autograd::Function<EdgeGradEstimatorFunction> {
public:
static torch::autograd::tensor_list forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
ctx->set_materialize_grads(false);
ctx->save_for_backward({v_pix, img, index_img, vi});
ctx->saved_data["v_pix_img_requires_grad"] = v_pix_img.requires_grad();
return {img};
}
static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
// If v_pix_img doesn't require grad, we don't need to do anything.
if (!ctx->saved_data["v_pix_img_requires_grad"].toBool()) {
return {torch::Tensor(), torch::Tensor(), torch::Tensor(), grad_outputs[0], torch::Tensor()};
}
const auto saved = ctx->get_saved_variables();
const auto& v_pix = saved[0];
const auto& img = saved[1];
const auto& index_img = saved[2];
const auto& vi = saved[3];
auto grad_v_pix_img =
edge_grad_estimator_cuda_backward(v_pix, img, index_img, vi, grad_outputs[0]);
return {torch::Tensor(), grad_v_pix_img, torch::Tensor(), grad_outputs[0], torch::Tensor()};
}
};
torch::Tensor edge_grad_estimator_autograd(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
return EdgeGradEstimatorFunction::apply(v_pix, v_pix_img, vi, img, index_img)[0];
}
torch::Tensor edge_grad_estimator_autocast(
const torch::Tensor& v_pix,
const torch::Tensor& v_pix_img,
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
return edge_grad_estimator(
at::autocast::cached_cast(torch::kFloat32, v_pix),
at::autocast::cached_cast(torch::kFloat32, v_pix_img),
vi,
at::autocast::cached_cast(torch::kFloat32, img),
index_img)[0];
}
#ifndef NO_PYBIND
// Just so that we can import this file as a Python module to get the path and
// import the Torch ops.
PYBIND11_MODULE(edge_grad_ext, m) {}
#endif
TORCH_LIBRARY(edge_grad_ext, m) {
m.def(
"edge_grad_estimator(Tensor v_pix, Tensor v_pix_img, Tensor vi, Tensor img, Tensor index_img) -> Tensor");
}
TORCH_LIBRARY_IMPL(edge_grad_ext, Autograd, m) {
m.impl("edge_grad_estimator", &edge_grad_estimator_autograd);
}
TORCH_LIBRARY_IMPL(edge_grad_ext, Autocast, m) {
m.impl("edge_grad_estimator", edge_grad_estimator_autocast);
}
TORCH_LIBRARY_IMPL(edge_grad_ext, CUDA, m) {
m.impl("edge_grad_estimator", &edge_grad_estimator_fwd);
}
This diff is collapsed.
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/native/cuda/GridSampler.cuh>
#include <ATen/native/cuda/UpSample.cuh>
enum class GridSamplerInterpolation { Bilinear, Nearest, Bicubic };
using at::native::clip_coordinates;
using at::native::cubic_interp1d;
using at::native::fastAtomicAdd;
using at::native::get_cubic_upsampling_coefficients;
using at::native::grid_sampler_compute_source_index;
using at::native::grid_sampler_compute_source_index_set_grad;
using at::native::grid_sampler_unnormalize;
using at::native::grid_sampler_unnormalize_set_grad;
using at::native::reflect_coordinates;
using at::native::safe_downgrade_to_int_range;
using at::native::within_bounds_2d;
using at::native::detail::GridSamplerPadding;
template <typename scalar_t>
static __forceinline__ __device__ scalar_t
area_pixel_compute_source_index(scalar_t scale, int64_t dst_index, bool align_corners) {
if (align_corners) {
return scale * dst_index;
} else {
scalar_t src_idx = scale * (dst_index + 0.5) - 0.5;
return (src_idx < 0) ? scalar_t(0) : src_idx;
}
}
template <typename scalar_t>
static __forceinline__ __device__ scalar_t
area_pixel_compute_scale(int64_t input_size, int64_t output_size, bool align_corners) {
// see Note [area_pixel_compute_scale]
if (align_corners) {
if (output_size > 1) {
return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
} else {
return static_cast<scalar_t>(0);
}
} else {
return static_cast<scalar_t>(input_size) / output_size;
}
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void safe_add_2d(
scalar_t* data,
int h,
int w,
int sH,
int sW,
int H,
int W,
scalar_t delta,
const index_t NC_offset,
const index_t memory_span) {
if (within_bounds_2d(h, w, H, W)) {
fastAtomicAdd(data, NC_offset + h * sH + w * sW, memory_span, delta, true);
}
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void add_2d(
scalar_t* data,
int h,
int w,
int sH,
int sW,
scalar_t delta,
const index_t NC_offset,
const index_t memory_span) {
fastAtomicAdd(data, NC_offset + h * sH + w * sW, memory_span, delta, true);
}
template <typename scalar_t>
static __forceinline__ __device__ scalar_t
compute_coordinates(scalar_t coord, int size, GridSamplerPadding padding_mode, bool align_corners) {
if (padding_mode == GridSamplerPadding::Border) {
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
} else if (padding_mode == GridSamplerPadding::Reflection) {
// reflect coordinates by image borders
if (align_corners) {
coord = reflect_coordinates(coord, 0, 2 * (size - 1));
} else {
coord = reflect_coordinates(coord, -1, 2 * size - 1);
}
// clip coordinates to image borders
coord = clip_coordinates(coord, size);
}
coord = safe_downgrade_to_int_range(coord);
return coord;
}
template <typename scalar_t>
static __forceinline__ __device__ scalar_t get_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int W,
int H,
int sW,
int sH,
GridSamplerPadding padding_mode,
bool align_corners) {
x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);
int ix = static_cast<int>(x);
int iy = static_cast<int>(y);
if (within_bounds_2d(iy, ix, H, W)) {
return data[iy * sH + ix * sW];
}
return static_cast<scalar_t>(0);
}
// Calculate the differential of the cubic convolution, i.e. `d coeff / d x`
template <typename scalar_t>
static __forceinline__ __device__ void get_cubic_coefficients_grad(scalar_t coeffs[4], scalar_t t) {
// Must be the same as forward calculation in
// aten/src/ATen/native/cuda/UpSample.cuh:get_cubic_upsample_coefficients
scalar_t A = -0.75;
scalar_t x;
x = -1 - t; // 1 < x = |-1 - tx| < 2
coeffs[0] = (-3 * A * x - 10 * A) * x - 8 * A;
x = -t; // x = |0 - tx| <= 1
coeffs[1] = (-3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 1 - t; // x = |1 - tx| <= 1
coeffs[2] = (3 * (A + 2) * x - 2 * (A + 3)) * x;
x = 2 - t; // 1 < x = |2 - tx| < 2
coeffs[3] = (3 * A * x - 10 * A) * x + 8 * A;
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void add_value_bounded(
scalar_t* data,
scalar_t x,
scalar_t y,
int W,
int H,
int sW,
int sH,
scalar_t delta,
GridSamplerPadding padding_mode,
bool align_corners,
const index_t NC_offset,
const index_t memory_span) {
x = compute_coordinates(x, W, padding_mode, align_corners);
y = compute_coordinates(y, H, padding_mode, align_corners);
int ix = static_cast<int>(x);
int iy = static_cast<int>(y);
safe_add_2d(data, iy, ix, sH, sW, H, W, delta, NC_offset, memory_span);
}
template <typename scalar_t, int D>
__device__ __forceinline__ static math::TVec<scalar_t, D> cubic_interp1d(
math::TVec<scalar_t, D> x0,
math::TVec<scalar_t, D> x1,
math::TVec<scalar_t, D> x2,
math::TVec<scalar_t, D> x3,
scalar_t t) {
scalar_t coeffs[4];
get_cubic_upsampling_coefficients<scalar_t>(coeffs, t);
using namespace math;
return x0 * coeffs[0] + x1 * coeffs[1] + x2 * coeffs[2] + x3 * coeffs[3];
}
template <typename scalar_t, typename index_t>
inline __device__ typename math::TVec4<scalar_t> load4(scalar_t* ptr, index_t stride) {
return {ptr[0 * stride], ptr[1 * stride], ptr[2 * stride], ptr[3 * stride]};
}
template <typename scalar_t, typename index_t>
static __forceinline__ __device__ void safe_add_2d4(
scalar_t* data,
index_t stride,
int h,
int w,
int sH,
int sW,
int H,
int W,
math::TVec4<scalar_t> delta,
const index_t N_offset,
const index_t memory_span) {
if (within_bounds_2d(h, w, H, W)) {
auto ptr = N_offset + h * sH + w * sW;
fastAtomicAdd(data, ptr + 0 * stride, memory_span, delta.x, true);
fastAtomicAdd(data, ptr + 1 * stride, memory_span, delta.y, true);
fastAtomicAdd(data, ptr + 2 * stride, memory_span, delta.z, true);
fastAtomicAdd(data, ptr + 3 * stride, memory_span, delta.w, true);
}
}
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
using at::cuda::detail::getTensorInfo;
using at::cuda::detail::TensorInfo;
#define C10_CUDA_KERNEL_LAUNCH_CHECK() C10_CUDA_CHECK(cudaGetLastError())
// Use 1024 threads per block, which requires cuda sm_2x or above
constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int64_t N, const int64_t max_threads_per_block = CUDA_NUM_THREADS) {
TORCH_INTERNAL_ASSERT(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
constexpr int64_t max_int = std::numeric_limits<int>::max();
// Round up division for positive number that cannot cause integer overflow
auto block_num = (N - 1) / max_threads_per_block + 1;
TORCH_INTERNAL_ASSERT(block_num <= max_int, "Can't schedule too many blocks on CUDA device");
return static_cast<int>(block_num);
}
// Dispatch macroses are updated in current pytorch.
// Which causes that the same code is compilable on DGX with the older pytorch
// but no longer compilable on prod
// Thus keeping these macroses here.
#undef AT_PRIVATE_CASE_TYPE
#undef AT_DISPATCH_FLOATING_TYPES
#undef AT_DISPATCH_FLOATING_TYPES_AND_HALF
#undef DISPATCH_FLOAT_AND_HALF
#define AT_PRIVATE_CASE_TYPE(enum_type, type, ...) \
case enum_type: { \
using scalar_t = type; \
return __VA_ARGS__(); \
}
// Dispatches for float and double
#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
// Dispatches for float, double, and half
#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \
[&] { \
const auto& the_type = TYPE; \
/* don't use TYPE again in case it is an expensive or side-effect op */ \
at::ScalarType _st = ::detail::scalar_type(the_type); \
switch (_st) { \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Double, double, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Float, float, __VA_ARGS__) \
AT_PRIVATE_CASE_TYPE(at::ScalarType::Half, at::Half, __VA_ARGS__) \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(_st), "'"); \
} \
}()
// Dispatches for float, double, and half
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, __VA_ARGS__)
// A simple stub to match the dispathcing for multimple types structure, but only for
// for float.
#define DISPATCH_FLOAT(NAME, ...) \
[&] { \
using scalar_t = float; \
return __VA_ARGS__(); \
}()
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
#pragma once
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
using at::cuda::detail::getTensorInfo;
using at::cuda::detail::TensorInfo;
// TensorInfoCompact is similar to TensorInfo but has fixed number of dims same as
// PackedTensorAccessor. It is supposed to be used on for CUDA `Tensor`s on the host when default
// constructor, assignment and copy constructors are needed, e.g. using in arrays in order to
// transfer them on the device when calling kernels. TensorInfo has a default, assignment and copy
// constructors, but PackedTensorAccessor does not. However TensorInfo is too large to be
// transferred in arrays when calling kernels. On the device, indexing of multidimensional tensors
// produces `TensorAccessor`s. Using RestrictPtrTraits as a default. If aliasing is possible (likely
// to be a very rare case) please use DefaultPtrTraits. Default constructor, assignment and copy
// constructors are only needed on the host aren't available on the device
template <
typename T,
typename index_t,
int N_DIMS,
template <typename> class PtrTraits = at::RestrictPtrTraits>
struct TensorInfoCompact {
typedef typename PtrTraits<T>::PtrType PtrType;
TensorInfoCompact(){};
__host__ TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& operator=(
const TensorInfoCompact<T, index_t, N_DIMS>& other) {
data = other.data;
for (int i = 0; i < N_DIMS; ++i) {
sizes[i] = other.sizes[i];
strides[i] = other.strides[i];
}
return *this;
};
__host__ TensorInfoCompact(const TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& other)
: data(other.data) {
for (int i = 0; i < N_DIMS; ++i) {
sizes[i] = other.sizes[i];
strides[i] = other.strides[i];
}
};
__host__ TensorInfoCompact(const TensorInfo<T, index_t>& other) : data(other.data) {
for (int i = 0; i < N_DIMS; ++i) {
sizes[i] = other.sizes[i];
strides[i] = other.strides[i];
}
}
__device__ at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t> operator[](index_t i) {
index_t* new_sizes = sizes + 1;
index_t* new_strides = strides + 1;
return at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t>(
data + strides[0] * i, new_sizes, new_strides);
}
__device__ const at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t> operator[](
index_t i) const {
const index_t* new_sizes = sizes + 1;
const index_t* new_strides = strides + 1;
return at::TensorAccessor<T, N_DIMS - 1, PtrTraits, index_t>(
data + strides[0] * i, new_sizes, new_strides);
}
PtrType data;
index_t sizes[N_DIMS];
index_t strides[N_DIMS];
};
template <
typename scalar_t,
typename index_t,
int N_DIMS,
template <typename> class PtrTraits = at::RestrictPtrTraits>
TensorInfoCompact<scalar_t, index_t, N_DIMS, PtrTraits> getTensorInfoCompact(const at::Tensor& x) {
auto out = getTensorInfo<scalar_t, index_t>(x);
assert(out.dims == N_DIMS);
return out;
}
template <
typename T,
typename index_t,
int N,
int N_DIMS,
template <typename> class PtrTraits = at::RestrictPtrTraits>
struct TensorInfoList {
__device__ __host__ TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& operator[](int i) {
return data[i];
}
__device__ __host__ const TensorInfoCompact<T, index_t, N_DIMS, PtrTraits>& operator[](
int i) const {
return data[i];
}
TensorInfoCompact<T, index_t, N_DIMS, PtrTraits> data[N];
};
template <typename IndexType, int N>
struct IndexList {
__device__ __host__ IndexType& operator[](int i) {
return data[i];
}
__device__ __host__ const IndexType& operator[](int i) const {
return data[i];
}
IndexType data[N] = {0};
};
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment