Commit 298ffd02 authored by Ruilong Li

init
# Visual Studio Code configs.
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
# lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.DS_Store
# Direnv config.
.envrc
# line_profiler
*.lprof
# vscode
.vscode/
import math
from typing import Callable, Optional, Tuple
import torch
from .cuda import VolumeRenderer, ray_aabb_intersect, ray_marching
def volumetric_rendering(
query_fn: Callable,
rays_o: torch.Tensor,
rays_d: torch.Tensor,
scene_aabb: torch.Tensor,
scene_occ_binary: torch.Tensor,
scene_resolution: Tuple[int, int, int],
    render_bkgd: Optional[torch.Tensor] = None,
render_n_samples: int = 1024,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""A *fast* version of differentiable volumetric rendering."""
device = rays_o.device
if render_bkgd is None:
render_bkgd = torch.ones(3, device=device)
scene_resolution = torch.tensor(scene_resolution, dtype=torch.int, device=device)
rays_o = rays_o.contiguous()
rays_d = rays_d.contiguous()
scene_aabb = scene_aabb.contiguous()
scene_occ_binary = scene_occ_binary.contiguous()
render_bkgd = render_bkgd.contiguous()
n_rays = rays_o.shape[0]
render_total_samples = n_rays * render_n_samples
render_step_size = (
(scene_aabb[3:] - scene_aabb[:3]).max() * math.sqrt(3) / render_n_samples
)
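    # The step size is chosen so that a ray crossing the entire scene AABB takes roughly
    # `render_n_samples` steps: the longest AABB side times sqrt(3) upper-bounds the AABB
    # diagonal, i.e. the longest possible ray segment inside the box.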
with torch.no_grad():
# TODO: avoid clamp here. kinda stupid
t_min, t_max = ray_aabb_intersect(rays_o, rays_d, scene_aabb)
t_min = torch.clamp(t_min, max=1e10)
t_max = torch.clamp(t_max, max=1e10)
(
packed_info,
frustum_origins,
frustum_dirs,
frustum_starts,
frustum_ends,
) = ray_marching(
# rays
rays_o,
rays_d,
t_min,
t_max,
# density grid
scene_aabb,
scene_resolution,
scene_occ_binary,
# sampling
render_total_samples,
render_n_samples,
render_step_size,
)
# squeeze valid samples
total_samples = max(packed_info[:, -1].sum(), 1)
frustum_origins = frustum_origins[:total_samples]
frustum_dirs = frustum_dirs[:total_samples]
frustum_starts = frustum_starts[:total_samples]
frustum_ends = frustum_ends[:total_samples]
frustum_positions = (
frustum_origins + frustum_dirs * (frustum_starts + frustum_ends) / 2.0
)
query_results = query_fn(frustum_positions, frustum_dirs, **kwargs)
rgbs, densities = query_results[0], query_results[1]
(
accumulated_weight,
accumulated_depth,
accumulated_color,
alive_ray_mask,
) = VolumeRenderer.apply(
packed_info,
frustum_starts,
frustum_ends,
densities.contiguous(),
rgbs.contiguous(),
)
accumulated_depth = torch.clip(accumulated_depth, t_min[:, None], t_max[:, None])
accumulated_color = accumulated_color + render_bkgd * (1.0 - accumulated_weight)
return accumulated_color, accumulated_depth, accumulated_weight, alive_ray_mask
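# A minimal usage sketch (illustrative only; the helper name `_example_usage`, the toy
# `query_fn`, the unit-cube AABB and the fully-occupied grid below are assumptions made
# for the example, not part of the API):
def _example_usage(n_rays: int = 4096, device: str = "cuda"):
    def toy_query_fn(positions, directions):
        # positions / directions: [n_samples, 3]
        rgbs = torch.sigmoid(positions)                # fake colors in (0, 1)
        densities = torch.ones_like(positions[:, :1])  # constant density
        return rgbs, densities

    resolution = (128, 128, 128)
    rays_o = torch.zeros(n_rays, 3, device=device)
    rays_d = torch.nn.functional.normalize(
        torch.randn(n_rays, 3, device=device), dim=-1
    )
    scene_aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=device)
    scene_occ_binary = torch.ones(resolution, dtype=torch.bool, device=device)
    colors, depths, weights, mask = volumetric_rendering(
        toy_query_fn, rays_o, rays_d, scene_aabb, scene_occ_binary, resolution
    )
    return colors, depths, weights, mask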
import torch
from torch.cuda.amp import custom_bwd, custom_fwd
from ._backend import _C
ray_aabb_intersect = _C.ray_aabb_intersect
ray_marching = _C.ray_marching
volumetric_rendering_forward = _C.volumetric_rendering_forward
volumetric_rendering_backward = _C.volumetric_rendering_backward
class VolumeRenderer(torch.autograd.Function):
"""CUDA Volumetirc Renderer"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, packed_info, starts, ends, sigmas, rgbs):
(
accumulated_weight,
accumulated_depth,
accumulated_color,
mask,
) = volumetric_rendering_forward(packed_info, starts, ends, sigmas, rgbs)
ctx.save_for_backward(
accumulated_weight,
accumulated_depth,
accumulated_color,
packed_info,
starts,
ends,
sigmas,
rgbs,
)
return accumulated_weight, accumulated_depth, accumulated_color, mask
@staticmethod
@custom_bwd
def backward(ctx, grad_weight, grad_depth, grad_color, _grad_mask):
(
accumulated_weight,
accumulated_depth,
accumulated_color,
packed_info,
starts,
ends,
sigmas,
rgbs,
) = ctx.saved_tensors
grad_sigmas, grad_rgbs = volumetric_rendering_backward(
accumulated_weight,
accumulated_depth,
accumulated_color,
grad_weight,
grad_depth,
grad_color,
packed_info,
starts,
ends,
sigmas,
rgbs,
)
# corresponds to the input argument list of forward()
return None, None, None, grad_sigmas, grad_rgbs
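# A small autograd sketch (illustrative only; the helper name and the hand-made packed
# samples below are assumptions): calling `VolumeRenderer.apply` directly and
# backpropagating through the custom CUDA forward/backward. Shapes follow the convention
# used above: packed_info [n_rays, 3] = (ray idx, sample start, sample count),
# starts/ends/sigmas [n_samples, 1], rgbs [n_samples, 3].
def _volume_renderer_grad_sketch(device: str = "cuda"):
    n_rays, samples_per_ray = 2, 8
    n_samples = n_rays * samples_per_ray
    ray_ids = torch.arange(n_rays, dtype=torch.int32, device=device)
    packed_info = torch.stack(
        [
            ray_ids,                                    # ray index
            ray_ids * samples_per_ray,                  # sample start index
            torch.full_like(ray_ids, samples_per_ray),  # sample count
        ],
        dim=-1,
    )
    starts = torch.linspace(0.0, 1.0, n_samples, device=device)[:, None]
    ends = starts + 0.1
    sigmas = torch.rand(n_samples, 1, device=device, requires_grad=True)
    rgbs = torch.rand(n_samples, 3, device=device, requires_grad=True)
    weight, depth, color, mask = VolumeRenderer.apply(
        packed_info, starts, ends, sigmas, rgbs
    )
    (color.sum() + depth.sum()).backward()  # populates sigmas.grad and rgbs.grad
    return sigmas.grad, rgbs.grad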
"""Setup cuda backend."""
import glob
import os
from subprocess import DEVNULL, call
from torch.utils.cpp_extension import load
PATH = os.path.dirname(os.path.abspath(__file__))
def cuda_toolkit_available():
"""Check if the nvcc is avaiable on the machine."""
# https://github.com/idiap/fast-transformers/blob/master/setup.py
try:
call(["nvcc"], stdout=DEVNULL, stderr=DEVNULL)
return True
except FileNotFoundError:
return False
if cuda_toolkit_available():
sources = glob.glob(os.path.join(PATH, "csrc/*.cu"))
else:
sources = glob.glob(os.path.join(PATH, "csrc/*.cpp"))
extra_cflags = ["-O3"]
extra_cuda_cflags = ["-O3"]
_C = load(
name="nerfacc_cuda",
sources=sources,
extra_cflags=extra_cflags,
extra_cuda_cflags=extra_cuda_cflags,
)
__all__ = ["_C"]
#pragma once
#ifdef __CUDACC__
#define CUDA_HOSTDEV __host__ __device__
#else
#define CUDA_HOSTDEV
#endif
#include <torch/extension.h>
inline constexpr CUDA_HOSTDEV float __SQRT3() { return 1.73205080757f; }
template <typename scalar_t>
inline CUDA_HOSTDEV void __swap(scalar_t &a, scalar_t &b)
{
scalar_t c = a;
a = b;
b = c;
}
inline CUDA_HOSTDEV float __clamp(float f, float a, float b) { return fmaxf(a, fminf(f, b)); }
inline CUDA_HOSTDEV int __clamp(int f, int a, int b) { return std::max(a, std::min(f, b)); }
inline CUDA_HOSTDEV float __sign(float x) { return copysignf(1.0, x); }
inline CUDA_HOSTDEV uint32_t __expand_bits(uint32_t v)
{
v = (v * 0x00010001u) & 0xFF0000FFu;
v = (v * 0x00000101u) & 0x0F00F00Fu;
v = (v * 0x00000011u) & 0xC30C30C3u;
v = (v * 0x00000005u) & 0x49249249u;
return v;
}
inline CUDA_HOSTDEV uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
{
uint32_t xx = __expand_bits(x);
uint32_t yy = __expand_bits(y);
uint32_t zz = __expand_bits(z);
return xx | (yy << 1) | (zz << 2);
}
inline CUDA_HOSTDEV uint32_t __morton3D_invert(uint32_t x)
{
x = x & 0x49249249;
x = (x | (x >> 2)) & 0xc30c30c3;
x = (x | (x >> 4)) & 0x0f00f00f;
x = (x | (x >> 8)) & 0xff0000ff;
x = (x | (x >> 16)) & 0x0000ffff;
return x;
}
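// Worked example (for reference): __expand_bits() spreads the low 10 bits of its input
// so that two zero bits separate consecutive bits, and __morton3D() interleaves the
// expanded x, y and z words, so bit k of x lands at bit 3k of the code (y at 3k + 1,
// z at 3k + 2). E.g. __morton3D(1, 0, 0) == 1, __morton3D(0, 1, 0) == 2,
// __morton3D(0, 0, 1) == 4 and __morton3D(1, 1, 1) == 7. __morton3D_invert() compacts
// one axis back out of the interleaved code.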
#pragma once
#include "helpers.h"
#include <c10/cuda/CUDAGuard.h>
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CUDA_GET_THREAD_ID(tid, Q) \
const int tid = blockIdx.x * blockDim.x + threadIdx.x; \
if (tid >= Q) \
return
#define CUDA_N_BLOCKS_NEEDED(Q, CUDA_N_THREADS) ((Q - 1) / CUDA_N_THREADS + 1)
#define DEVICE_GUARD(_ten) \
const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));
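// Usage pattern (as in the kernels below): a kernel guards out-of-range threads with
//   CUDA_GET_THREAD_ID(tid, N);
// while the host side computes the launch configuration with
//   const int threads = 256;
//   const int blocks = CUDA_N_BLOCKS_NEEDED(N, threads);
//   kernel<<<blocks, threads>>>(N, ...);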
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
/*
* This file implements common mathematical operations on vector types
* (float3, float4 etc.) since these are not provided as standard by CUDA.
*
* The syntax is modeled on the Cg standard library.
*
* This is part of the Helper library includes
*
* Thanks to Linh Hah for additions and fixes.
*/
#ifndef HELPER_MATH_H
#define HELPER_MATH_H
#include "cuda_runtime.h"
typedef unsigned int uint;
typedef unsigned short ushort;
#ifndef EXIT_WAIVED
#define EXIT_WAIVED 2
#endif
#ifndef __CUDACC__
#include <math.h>
////////////////////////////////////////////////////////////////////////////////
// host implementations of CUDA functions
////////////////////////////////////////////////////////////////////////////////
inline float fminf(float a, float b)
{
return a < b ? a : b;
}
inline float fmaxf(float a, float b)
{
return a > b ? a : b;
}
inline int max(int a, int b)
{
return a > b ? a : b;
}
inline int min(int a, int b)
{
return a < b ? a : b;
}
inline float rsqrtf(float x)
{
return 1.0f / sqrtf(x);
}
#endif
////////////////////////////////////////////////////////////////////////////////
// constructors
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 make_float2(float s)
{
return make_float2(s, s);
}
inline __host__ __device__ float2 make_float2(float3 a)
{
return make_float2(a.x, a.y);
}
inline __host__ __device__ float2 make_float2(int2 a)
{
return make_float2(float(a.x), float(a.y));
}
inline __host__ __device__ float2 make_float2(uint2 a)
{
return make_float2(float(a.x), float(a.y));
}
inline __host__ __device__ int2 make_int2(int s)
{
return make_int2(s, s);
}
inline __host__ __device__ int2 make_int2(int3 a)
{
return make_int2(a.x, a.y);
}
inline __host__ __device__ int2 make_int2(uint2 a)
{
return make_int2(int(a.x), int(a.y));
}
inline __host__ __device__ int2 make_int2(float2 a)
{
return make_int2(int(a.x), int(a.y));
}
inline __host__ __device__ uint2 make_uint2(uint s)
{
return make_uint2(s, s);
}
inline __host__ __device__ uint2 make_uint2(uint3 a)
{
return make_uint2(a.x, a.y);
}
inline __host__ __device__ uint2 make_uint2(int2 a)
{
return make_uint2(uint(a.x), uint(a.y));
}
inline __host__ __device__ float3 make_float3(float s)
{
return make_float3(s, s, s);
}
inline __host__ __device__ float3 make_float3(float2 a)
{
return make_float3(a.x, a.y, 0.0f);
}
inline __host__ __device__ float3 make_float3(float2 a, float s)
{
return make_float3(a.x, a.y, s);
}
inline __host__ __device__ float3 make_float3(float4 a)
{
return make_float3(a.x, a.y, a.z);
}
inline __host__ __device__ float3 make_float3(int3 a)
{
return make_float3(float(a.x), float(a.y), float(a.z));
}
inline __host__ __device__ float3 make_float3(uint3 a)
{
return make_float3(float(a.x), float(a.y), float(a.z));
}
inline __host__ __device__ int3 make_int3(int s)
{
return make_int3(s, s, s);
}
inline __host__ __device__ int3 make_int3(int2 a)
{
return make_int3(a.x, a.y, 0);
}
inline __host__ __device__ int3 make_int3(int2 a, int s)
{
return make_int3(a.x, a.y, s);
}
inline __host__ __device__ int3 make_int3(uint3 a)
{
return make_int3(int(a.x), int(a.y), int(a.z));
}
inline __host__ __device__ int3 make_int3(float3 a)
{
return make_int3(int(a.x), int(a.y), int(a.z));
}
inline __host__ __device__ uint3 make_uint3(uint s)
{
return make_uint3(s, s, s);
}
inline __host__ __device__ uint3 make_uint3(uint2 a)
{
return make_uint3(a.x, a.y, 0);
}
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
{
return make_uint3(a.x, a.y, s);
}
inline __host__ __device__ uint3 make_uint3(uint4 a)
{
return make_uint3(a.x, a.y, a.z);
}
inline __host__ __device__ uint3 make_uint3(int3 a)
{
return make_uint3(uint(a.x), uint(a.y), uint(a.z));
}
inline __host__ __device__ float4 make_float4(float s)
{
return make_float4(s, s, s, s);
}
inline __host__ __device__ float4 make_float4(float3 a)
{
return make_float4(a.x, a.y, a.z, 0.0f);
}
inline __host__ __device__ float4 make_float4(float3 a, float w)
{
return make_float4(a.x, a.y, a.z, w);
}
inline __host__ __device__ float4 make_float4(int4 a)
{
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}
inline __host__ __device__ float4 make_float4(uint4 a)
{
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
}
inline __host__ __device__ int4 make_int4(int s)
{
return make_int4(s, s, s, s);
}
inline __host__ __device__ int4 make_int4(int3 a)
{
return make_int4(a.x, a.y, a.z, 0);
}
inline __host__ __device__ int4 make_int4(int3 a, int w)
{
return make_int4(a.x, a.y, a.z, w);
}
inline __host__ __device__ int4 make_int4(uint4 a)
{
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}
inline __host__ __device__ int4 make_int4(float4 a)
{
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
}
inline __host__ __device__ uint4 make_uint4(uint s)
{
return make_uint4(s, s, s, s);
}
inline __host__ __device__ uint4 make_uint4(uint3 a)
{
return make_uint4(a.x, a.y, a.z, 0);
}
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
{
return make_uint4(a.x, a.y, a.z, w);
}
inline __host__ __device__ uint4 make_uint4(int4 a)
{
return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
}
////////////////////////////////////////////////////////////////////////////////
// negate
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 operator-(float2 &a)
{
return make_float2(-a.x, -a.y);
}
inline __host__ __device__ int2 operator-(int2 &a)
{
return make_int2(-a.x, -a.y);
}
inline __host__ __device__ float3 operator-(float3 &a)
{
return make_float3(-a.x, -a.y, -a.z);
}
inline __host__ __device__ int3 operator-(int3 &a)
{
return make_int3(-a.x, -a.y, -a.z);
}
inline __host__ __device__ float4 operator-(float4 &a)
{
return make_float4(-a.x, -a.y, -a.z, -a.w);
}
inline __host__ __device__ int4 operator-(int4 &a)
{
return make_int4(-a.x, -a.y, -a.z, -a.w);
}
////////////////////////////////////////////////////////////////////////////////
// addition
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 operator+(float2 a, float2 b)
{
return make_float2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(float2 &a, float2 b)
{
a.x += b.x;
a.y += b.y;
}
inline __host__ __device__ float2 operator+(float2 a, float b)
{
return make_float2(a.x + b, a.y + b);
}
inline __host__ __device__ float2 operator+(float b, float2 a)
{
return make_float2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(float2 &a, float b)
{
a.x += b;
a.y += b;
}
inline __host__ __device__ int2 operator+(int2 a, int2 b)
{
return make_int2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(int2 &a, int2 b)
{
a.x += b.x;
a.y += b.y;
}
inline __host__ __device__ int2 operator+(int2 a, int b)
{
return make_int2(a.x + b, a.y + b);
}
inline __host__ __device__ int2 operator+(int b, int2 a)
{
return make_int2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(int2 &a, int b)
{
a.x += b;
a.y += b;
}
inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
{
return make_uint2(a.x + b.x, a.y + b.y);
}
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
{
a.x += b.x;
a.y += b.y;
}
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
{
return make_uint2(a.x + b, a.y + b);
}
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
{
return make_uint2(a.x + b, a.y + b);
}
inline __host__ __device__ void operator+=(uint2 &a, uint b)
{
a.x += b;
a.y += b;
}
inline __host__ __device__ float3 operator+(float3 a, float3 b)
{
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(float3 &a, float3 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
}
inline __host__ __device__ float3 operator+(float3 a, float b)
{
return make_float3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(float3 &a, float b)
{
a.x += b;
a.y += b;
a.z += b;
}
inline __host__ __device__ int3 operator+(int3 a, int3 b)
{
return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(int3 &a, int3 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
}
inline __host__ __device__ int3 operator+(int3 a, int b)
{
return make_int3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(int3 &a, int b)
{
a.x += b;
a.y += b;
a.z += b;
}
inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
{
return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
}
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
}
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
{
return make_uint3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ void operator+=(uint3 &a, uint b)
{
a.x += b;
a.y += b;
a.z += b;
}
inline __host__ __device__ int3 operator+(int b, int3 a)
{
return make_int3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
{
return make_uint3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ float3 operator+(float b, float3 a)
{
return make_float3(a.x + b, a.y + b, a.z + b);
}
inline __host__ __device__ float4 operator+(float4 a, float4 b)
{
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(float4 &a, float4 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
a.w += b.w;
}
inline __host__ __device__ float4 operator+(float4 a, float b)
{
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ float4 operator+(float b, float4 a)
{
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(float4 &a, float b)
{
a.x += b;
a.y += b;
a.z += b;
a.w += b;
}
inline __host__ __device__ int4 operator+(int4 a, int4 b)
{
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(int4 &a, int4 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
a.w += b.w;
}
inline __host__ __device__ int4 operator+(int4 a, int b)
{
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ int4 operator+(int b, int4 a)
{
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(int4 &a, int b)
{
a.x += b;
a.y += b;
a.z += b;
a.w += b;
}
inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
{
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
}
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
{
a.x += b.x;
a.y += b.y;
a.z += b.z;
a.w += b.w;
}
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
{
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
{
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
}
inline __host__ __device__ void operator+=(uint4 &a, uint b)
{
a.x += b;
a.y += b;
a.z += b;
a.w += b;
}
////////////////////////////////////////////////////////////////////////////////
// subtract
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 operator-(float2 a, float2 b)
{
return make_float2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(float2 &a, float2 b)
{
a.x -= b.x;
a.y -= b.y;
}
inline __host__ __device__ float2 operator-(float2 a, float b)
{
return make_float2(a.x - b, a.y - b);
}
inline __host__ __device__ float2 operator-(float b, float2 a)
{
return make_float2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(float2 &a, float b)
{
a.x -= b;
a.y -= b;
}
inline __host__ __device__ int2 operator-(int2 a, int2 b)
{
return make_int2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(int2 &a, int2 b)
{
a.x -= b.x;
a.y -= b.y;
}
inline __host__ __device__ int2 operator-(int2 a, int b)
{
return make_int2(a.x - b, a.y - b);
}
inline __host__ __device__ int2 operator-(int b, int2 a)
{
return make_int2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(int2 &a, int b)
{
a.x -= b;
a.y -= b;
}
inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
{
return make_uint2(a.x - b.x, a.y - b.y);
}
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
{
a.x -= b.x;
a.y -= b.y;
}
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
{
return make_uint2(a.x - b, a.y - b);
}
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
{
return make_uint2(b - a.x, b - a.y);
}
inline __host__ __device__ void operator-=(uint2 &a, uint b)
{
a.x -= b;
a.y -= b;
}
inline __host__ __device__ float3 operator-(float3 a, float3 b)
{
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(float3 &a, float3 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
}
inline __host__ __device__ float3 operator-(float3 a, float b)
{
return make_float3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ float3 operator-(float b, float3 a)
{
return make_float3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(float3 &a, float b)
{
a.x -= b;
a.y -= b;
a.z -= b;
}
inline __host__ __device__ int3 operator-(int3 a, int3 b)
{
return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(int3 &a, int3 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
}
inline __host__ __device__ int3 operator-(int3 a, int b)
{
return make_int3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ int3 operator-(int b, int3 a)
{
return make_int3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(int3 &a, int b)
{
a.x -= b;
a.y -= b;
a.z -= b;
}
inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
{
return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
}
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
{
return make_uint3(a.x - b, a.y - b, a.z - b);
}
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
{
return make_uint3(b - a.x, b - a.y, b - a.z);
}
inline __host__ __device__ void operator-=(uint3 &a, uint b)
{
a.x -= b;
a.y -= b;
a.z -= b;
}
inline __host__ __device__ float4 operator-(float4 a, float4 b)
{
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(float4 &a, float4 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
a.w -= b.w;
}
inline __host__ __device__ float4 operator-(float4 a, float b)
{
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ void operator-=(float4 &a, float b)
{
a.x -= b;
a.y -= b;
a.z -= b;
a.w -= b;
}
inline __host__ __device__ int4 operator-(int4 a, int4 b)
{
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(int4 &a, int4 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
a.w -= b.w;
}
inline __host__ __device__ int4 operator-(int4 a, int b)
{
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ int4 operator-(int b, int4 a)
{
return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
}
inline __host__ __device__ void operator-=(int4 &a, int b)
{
a.x -= b;
a.y -= b;
a.z -= b;
a.w -= b;
}
inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
{
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
{
a.x -= b.x;
a.y -= b.y;
a.z -= b.z;
a.w -= b.w;
}
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
{
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
}
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
{
return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
}
inline __host__ __device__ void operator-=(uint4 &a, uint b)
{
a.x -= b;
a.y -= b;
a.z -= b;
a.w -= b;
}
////////////////////////////////////////////////////////////////////////////////
// multiply
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 operator*(float2 a, float2 b)
{
return make_float2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ void operator*=(float2 &a, float2 b)
{
a.x *= b.x;
a.y *= b.y;
}
inline __host__ __device__ float2 operator*(float2 a, float b)
{
return make_float2(a.x * b, a.y * b);
}
inline __host__ __device__ float2 operator*(float b, float2 a)
{
return make_float2(b * a.x, b * a.y);
}
inline __host__ __device__ void operator*=(float2 &a, float b)
{
a.x *= b;
a.y *= b;
}
inline __host__ __device__ int2 operator*(int2 a, int2 b)
{
return make_int2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ void operator*=(int2 &a, int2 b)
{
a.x *= b.x;
a.y *= b.y;
}
inline __host__ __device__ int2 operator*(int2 a, int b)
{
return make_int2(a.x * b, a.y * b);
}
inline __host__ __device__ int2 operator*(int b, int2 a)
{
return make_int2(b * a.x, b * a.y);
}
inline __host__ __device__ void operator*=(int2 &a, int b)
{
a.x *= b;
a.y *= b;
}
inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
{
return make_uint2(a.x * b.x, a.y * b.y);
}
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
{
a.x *= b.x;
a.y *= b.y;
}
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
{
return make_uint2(a.x * b, a.y * b);
}
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
{
return make_uint2(b * a.x, b * a.y);
}
inline __host__ __device__ void operator*=(uint2 &a, uint b)
{
a.x *= b;
a.y *= b;
}
inline __host__ __device__ float3 operator*(float3 a, float3 b)
{
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ void operator*=(float3 &a, float3 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
}
inline __host__ __device__ float3 operator*(float3 a, float b)
{
return make_float3(a.x * b, a.y * b, a.z * b);
}
inline __host__ __device__ float3 operator*(float b, float3 a)
{
return make_float3(b * a.x, b * a.y, b * a.z);
}
inline __host__ __device__ void operator*=(float3 &a, float b)
{
a.x *= b;
a.y *= b;
a.z *= b;
}
inline __host__ __device__ int3 operator*(int3 a, int3 b)
{
return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ void operator*=(int3 &a, int3 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
}
inline __host__ __device__ int3 operator*(int3 a, int b)
{
return make_int3(a.x * b, a.y * b, a.z * b);
}
inline __host__ __device__ int3 operator*(int b, int3 a)
{
return make_int3(b * a.x, b * a.y, b * a.z);
}
inline __host__ __device__ void operator*=(int3 &a, int b)
{
a.x *= b;
a.y *= b;
a.z *= b;
}
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
{
return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
}
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
}
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
{
return make_uint3(a.x * b, a.y * b, a.z * b);
}
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
{
return make_uint3(b * a.x, b * a.y, b * a.z);
}
inline __host__ __device__ void operator*=(uint3 &a, uint b)
{
a.x *= b;
a.y *= b;
a.z *= b;
}
inline __host__ __device__ float4 operator*(float4 a, float4 b)
{
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ void operator*=(float4 &a, float4 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
a.w *= b.w;
}
inline __host__ __device__ float4 operator*(float4 a, float b)
{
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
}
inline __host__ __device__ float4 operator*(float b, float4 a)
{
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
}
inline __host__ __device__ void operator*=(float4 &a, float b)
{
a.x *= b;
a.y *= b;
a.z *= b;
a.w *= b;
}
inline __host__ __device__ int4 operator*(int4 a, int4 b)
{
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ void operator*=(int4 &a, int4 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
a.w *= b.w;
}
inline __host__ __device__ int4 operator*(int4 a, int b)
{
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
}
inline __host__ __device__ int4 operator*(int b, int4 a)
{
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
}
inline __host__ __device__ void operator*=(int4 &a, int b)
{
a.x *= b;
a.y *= b;
a.z *= b;
a.w *= b;
}
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
{
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
}
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
{
a.x *= b.x;
a.y *= b.y;
a.z *= b.z;
a.w *= b.w;
}
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
{
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
}
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
{
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
}
inline __host__ __device__ void operator*=(uint4 &a, uint b)
{
a.x *= b;
a.y *= b;
a.z *= b;
a.w *= b;
}
////////////////////////////////////////////////////////////////////////////////
// divide
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 operator/(float2 a, float2 b)
{
return make_float2(a.x / b.x, a.y / b.y);
}
inline __host__ __device__ void operator/=(float2 &a, float2 b)
{
a.x /= b.x;
a.y /= b.y;
}
inline __host__ __device__ float2 operator/(float2 a, float b)
{
return make_float2(a.x / b, a.y / b);
}
inline __host__ __device__ void operator/=(float2 &a, float b)
{
a.x /= b;
a.y /= b;
}
inline __host__ __device__ float2 operator/(float b, float2 a)
{
return make_float2(b / a.x, b / a.y);
}
inline __host__ __device__ float3 operator/(float3 a, float3 b)
{
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
}
inline __host__ __device__ void operator/=(float3 &a, float3 b)
{
a.x /= b.x;
a.y /= b.y;
a.z /= b.z;
}
inline __host__ __device__ float3 operator/(float3 a, float b)
{
return make_float3(a.x / b, a.y / b, a.z / b);
}
inline __host__ __device__ void operator/=(float3 &a, float b)
{
a.x /= b;
a.y /= b;
a.z /= b;
}
inline __host__ __device__ float3 operator/(float b, float3 a)
{
return make_float3(b / a.x, b / a.y, b / a.z);
}
inline __host__ __device__ float4 operator/(float4 a, float4 b)
{
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
}
inline __host__ __device__ void operator/=(float4 &a, float4 b)
{
a.x /= b.x;
a.y /= b.y;
a.z /= b.z;
a.w /= b.w;
}
inline __host__ __device__ float4 operator/(float4 a, float b)
{
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
}
inline __host__ __device__ void operator/=(float4 &a, float b)
{
a.x /= b;
a.y /= b;
a.z /= b;
a.w /= b;
}
inline __host__ __device__ float4 operator/(float b, float4 a)
{
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
}
////////////////////////////////////////////////////////////////////////////////
// min
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 fminf(float2 a, float2 b)
{
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
}
inline __host__ __device__ float3 fminf(float3 a, float3 b)
{
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
}
inline __host__ __device__ float4 fminf(float4 a, float4 b)
{
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
}
inline __host__ __device__ int2 min(int2 a, int2 b)
{
return make_int2(min(a.x,b.x), min(a.y,b.y));
}
inline __host__ __device__ int3 min(int3 a, int3 b)
{
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
}
inline __host__ __device__ int4 min(int4 a, int4 b)
{
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
}
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
{
return make_uint2(min(a.x,b.x), min(a.y,b.y));
}
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
{
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
}
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
{
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
}
////////////////////////////////////////////////////////////////////////////////
// max
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
{
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
}
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
{
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
}
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
{
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
}
inline __host__ __device__ int2 max(int2 a, int2 b)
{
return make_int2(max(a.x,b.x), max(a.y,b.y));
}
inline __host__ __device__ int3 max(int3 a, int3 b)
{
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
}
inline __host__ __device__ int4 max(int4 a, int4 b)
{
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
}
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
{
return make_uint2(max(a.x,b.x), max(a.y,b.y));
}
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
{
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
}
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
{
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
}
////////////////////////////////////////////////////////////////////////////////
// lerp
// - linear interpolation between a and b, based on value t in [0, 1] range
////////////////////////////////////////////////////////////////////////////////
inline __device__ __host__ float lerp(float a, float b, float t)
{
return a + t*(b-a);
}
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
{
return a + t*(b-a);
}
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
{
return a + t*(b-a);
}
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
{
return a + t*(b-a);
}
////////////////////////////////////////////////////////////////////////////////
// clamp
// - clamp the value v to be in the range [a, b]
////////////////////////////////////////////////////////////////////////////////
inline __device__ __host__ float clamp(float f, float a, float b)
{
return fmaxf(a, fminf(f, b));
}
inline __device__ __host__ int clamp(int f, int a, int b)
{
return max(a, min(f, b));
}
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
{
return max(a, min(f, b));
}
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
{
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
{
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
{
return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
{
return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
{
return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
{
return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}
inline __device__ __host__ int2 clamp(int2 v, int a, int b)
{
return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
{
return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
{
return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
{
return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
{
return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
{
return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}
inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
{
return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
}
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
{
return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
}
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
{
return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
}
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
{
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
}
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
{
return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
}
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
{
return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
}
////////////////////////////////////////////////////////////////////////////////
// dot product
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float dot(float2 a, float2 b)
{
return a.x * b.x + a.y * b.y;
}
inline __host__ __device__ float dot(float3 a, float3 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z;
}
inline __host__ __device__ float dot(float4 a, float4 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
inline __host__ __device__ int dot(int2 a, int2 b)
{
return a.x * b.x + a.y * b.y;
}
inline __host__ __device__ int dot(int3 a, int3 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z;
}
inline __host__ __device__ int dot(int4 a, int4 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
inline __host__ __device__ uint dot(uint2 a, uint2 b)
{
return a.x * b.x + a.y * b.y;
}
inline __host__ __device__ uint dot(uint3 a, uint3 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z;
}
inline __host__ __device__ uint dot(uint4 a, uint4 b)
{
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
}
////////////////////////////////////////////////////////////////////////////////
// length
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float length(float2 v)
{
return sqrtf(dot(v, v));
}
inline __host__ __device__ float length(float3 v)
{
return sqrtf(dot(v, v));
}
inline __host__ __device__ float length(float4 v)
{
return sqrtf(dot(v, v));
}
////////////////////////////////////////////////////////////////////////////////
// normalize
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 normalize(float2 v)
{
float invLen = rsqrtf(dot(v, v));
return v * invLen;
}
inline __host__ __device__ float3 normalize(float3 v)
{
float invLen = rsqrtf(dot(v, v));
return v * invLen;
}
inline __host__ __device__ float4 normalize(float4 v)
{
float invLen = rsqrtf(dot(v, v));
return v * invLen;
}
////////////////////////////////////////////////////////////////////////////////
// floor
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 floorf(float2 v)
{
return make_float2(floorf(v.x), floorf(v.y));
}
inline __host__ __device__ float3 floorf(float3 v)
{
return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
}
inline __host__ __device__ float4 floorf(float4 v)
{
return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
}
////////////////////////////////////////////////////////////////////////////////
// frac - returns the fractional portion of a scalar or each vector component
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float fracf(float v)
{
return v - floorf(v);
}
inline __host__ __device__ float2 fracf(float2 v)
{
return make_float2(fracf(v.x), fracf(v.y));
}
inline __host__ __device__ float3 fracf(float3 v)
{
return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
}
inline __host__ __device__ float4 fracf(float4 v)
{
return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
}
////////////////////////////////////////////////////////////////////////////////
// fmod
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 fmodf(float2 a, float2 b)
{
return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
}
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
{
return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
}
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
{
return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
}
////////////////////////////////////////////////////////////////////////////////
// absolute value
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float2 fabs(float2 v)
{
return make_float2(fabs(v.x), fabs(v.y));
}
inline __host__ __device__ float3 fabs(float3 v)
{
return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
}
inline __host__ __device__ float4 fabs(float4 v)
{
return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
}
inline __host__ __device__ int2 abs(int2 v)
{
return make_int2(abs(v.x), abs(v.y));
}
inline __host__ __device__ int3 abs(int3 v)
{
return make_int3(abs(v.x), abs(v.y), abs(v.z));
}
inline __host__ __device__ int4 abs(int4 v)
{
return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
}
////////////////////////////////////////////////////////////////////////////////
// reflect
// - returns reflection of incident ray I around surface normal N
// - N should be normalized, reflected vector's length is equal to length of I
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float3 reflect(float3 i, float3 n)
{
return i - 2.0f * n * dot(n,i);
}
////////////////////////////////////////////////////////////////////////////////
// cross product
////////////////////////////////////////////////////////////////////////////////
inline __host__ __device__ float3 cross(float3 a, float3 b)
{
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
}
////////////////////////////////////////////////////////////////////////////////
// smoothstep
// - returns 0 if x < a
// - returns 1 if x > b
// - otherwise returns smooth interpolation between 0 and 1 based on x
////////////////////////////////////////////////////////////////////////////////
inline __device__ __host__ float smoothstep(float a, float b, float x)
{
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(3.0f - (2.0f*y)));
}
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
{
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
}
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
{
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
}
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
{
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
}
#endif
#include "include/helpers_cuda.h"
template <typename scalar_t>
inline __host__ __device__ void _ray_aabb_intersect(
const scalar_t* rays_o,
const scalar_t* rays_d,
const scalar_t* aabb,
scalar_t* near,
scalar_t* far
) {
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
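    // Slab method: intersect the ray with the pair of axis-aligned planes of each slab
    // (x, then y, then z) while maintaining the running [near, far] interval; if the
    // interval becomes empty the ray misses the box and both outputs are set to
    // numeric_limits::max().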
scalar_t tmin = (aabb[0] - rays_o[0]) / rays_d[0];
scalar_t tmax = (aabb[3] - rays_o[0]) / rays_d[0];
if (tmin > tmax) __swap(tmin, tmax);
scalar_t tymin = (aabb[1] - rays_o[1]) / rays_d[1];
scalar_t tymax = (aabb[4] - rays_o[1]) / rays_d[1];
if (tymin > tymax) __swap(tymin, tymax);
if (tmin > tymax || tymin > tmax){
*near = std::numeric_limits<scalar_t>::max();
*far = std::numeric_limits<scalar_t>::max();
return;
}
if (tymin > tmin) tmin = tymin;
if (tymax < tmax) tmax = tymax;
scalar_t tzmin = (aabb[2] - rays_o[2]) / rays_d[2];
scalar_t tzmax = (aabb[5] - rays_o[2]) / rays_d[2];
if (tzmin > tzmax) __swap(tzmin, tzmax);
if (tmin > tzmax || tzmin > tmax){
*near = std::numeric_limits<scalar_t>::max();
*far = std::numeric_limits<scalar_t>::max();
return;
}
if (tzmin > tmin) tmin = tzmin;
if (tzmax < tmax) tmax = tzmax;
*near = tmin;
*far = tmax;
return;
}
template <typename scalar_t>
__global__ void kernel_ray_aabb_intersect(
const int N,
const scalar_t* rays_o,
const scalar_t* rays_d,
const scalar_t* aabb,
scalar_t* t_min,
scalar_t* t_max
){
// aabb is [xmin, ymin, zmin, xmax, ymax, zmax]
CUDA_GET_THREAD_ID(thread_id, N);
// locate
rays_o += thread_id * 3;
rays_d += thread_id * 3;
t_min += thread_id;
t_max += thread_id;
_ray_aabb_intersect<scalar_t>(rays_o, rays_d, aabb, t_min, t_max);
scalar_t zero = static_cast<scalar_t>(0.f);
*t_min = *t_min > zero ? *t_min : zero;
return;
}
/**
* @brief Ray AABB Test
*
* @param rays_o Ray origins. Tensor with shape [N, 3].
* @param rays_d Normalized ray directions. Tensor with shape [N, 3].
* @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Tensor with shape [6].
* @return std::vector<torch::Tensor>
 *     Ray-AABB intersections {t_min, t_max}, each with shape [N]. Note that t_min is
 *     clamped to a minimum of zero.
*/
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o, const torch::Tensor rays_d, const torch::Tensor aabb
) {
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(aabb);
    TORCH_CHECK(rays_o.ndimension() == 2 && rays_o.size(1) == 3);
    TORCH_CHECK(rays_d.ndimension() == 2 && rays_d.size(1) == 3);
    TORCH_CHECK(aabb.ndimension() == 1 && aabb.size(0) == 6);
const int N = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(N, threads);
torch::Tensor t_min = torch::empty({N}, rays_o.options());
torch::Tensor t_max = torch::empty({N}, rays_o.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_o.scalar_type(), "ray_aabb_intersect",
([&] {
kernel_ray_aabb_intersect<scalar_t><<<blocks, threads>>>(
N,
rays_o.data_ptr<scalar_t>(),
rays_d.data_ptr<scalar_t>(),
aabb.data_ptr<scalar_t>(),
t_min.data_ptr<scalar_t>(),
t_max.data_ptr<scalar_t>()
);
})
);
return {t_min, t_max};
}
#include "include/helpers_cuda.h"
std::vector<torch::Tensor> ray_aabb_intersect(
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor aabb
);
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const torch::Tensor resolution,
const torch::Tensor occ_binary,
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt
);
std::vector<torch::Tensor> volumetric_rendering_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
torch::Tensor rgbs
);
std::vector<torch::Tensor> volumetric_rendering_backward(
torch::Tensor accumulated_weight,
torch::Tensor accumulated_depth,
torch::Tensor accumulated_color,
torch::Tensor grad_weight,
torch::Tensor grad_depth,
torch::Tensor grad_color,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
torch::Tensor rgbs
);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ray_aabb_intersect", &ray_aabb_intersect);
m.def("ray_marching", &ray_marching);
m.def("volumetric_rendering_forward", &volumetric_rendering_forward);
m.def("volumetric_rendering_backward", &volumetric_rendering_backward);
}
#include "include/helpers_cuda.h"
inline __device__ int cascaded_grid_idx_at(
const float x, const float y, const float z,
const int* resolution, const float* aabb
) {
    // TODO(ruilongli): if (x, y, z) falls outside the aabb it gets clamped into the grid;
    // we should just return false instead.
int ix = (int)(((x - aabb[0]) / (aabb[3] - aabb[0])) * resolution[0]);
int iy = (int)(((y - aabb[1]) / (aabb[4] - aabb[1])) * resolution[1]);
int iz = (int)(((z - aabb[2]) / (aabb[5] - aabb[2])) * resolution[2]);
ix = __clamp(ix, 0, resolution[0]-1);
iy = __clamp(iy, 0, resolution[1]-1);
iz = __clamp(iz, 0, resolution[2]-1);
int idx = ix * resolution[1] * resolution[2] + iy * resolution[2] + iz;
return idx;
}
inline __device__ bool grid_occupied_at(
const float x, const float y, const float z,
const int* resolution, const float* aabb, const bool* occ_binary
) {
int idx = cascaded_grid_idx_at(x, y, z, resolution, aabb);
return occ_binary[idx];
}
inline __device__ float distance_to_next_voxel(
float x, float y, float z,
float dir_x, float dir_y, float dir_z,
float idir_x, float idir_y, float idir_z,
const int* resolution
) { // dda like step
    // Scale the position into voxel-grid units so the step lands on voxel boundaries.
    x = resolution[0] * x;
    y = resolution[1] * y;
    z = resolution[2] * z;
float tx = ((floorf(x + 0.5f + 0.5f * __sign(dir_x)) - x) * idir_x) / resolution[0];
float ty = ((floorf(y + 0.5f + 0.5f * __sign(dir_y)) - y) * idir_y) / resolution[1];
float tz = ((floorf(z + 0.5f + 0.5f * __sign(dir_z)) - z) * idir_z) / resolution[2];
float t = min(min(tx, ty), tz);
return fmaxf(t, 0.0f);
}
inline __device__ float advance_to_next_voxel(
float t,
float x, float y, float z,
float dir_x, float dir_y, float dir_z,
float idir_x, float idir_y, float idir_z,
const int* resolution, float dt_min) {
// Regular stepping (may be slower but matches non-empty space)
float t_target = t + distance_to_next_voxel(
x, y, z, dir_x, dir_y, dir_z, idir_x, idir_y, idir_z, resolution
);
do {
t += dt_min;
} while (t < t_target);
return t;
}
__global__ void kernel_raymarching(
// rays info
const uint32_t n_rays,
const float* rays_o, // shape (n_rays, 3)
const float* rays_d, // shape (n_rays, 3)
const float* t_min, // shape (n_rays,)
const float* t_max, // shape (n_rays,)
// density grid
    const float* aabb, // [min_x, min_y, min_z, max_x, max_y, max_z]
const int* resolution, // [reso_x, reso_y, reso_z]
const bool* occ_binary, // shape (reso_x, reso_y, reso_z)
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt,
// writable helpers
int* steps_counter,
int* rays_counter,
    // frustum outputs
int* packed_info,
float* frustum_origins,
float* frustum_dirs,
float* frustum_starts,
float* frustum_ends
) {
CUDA_GET_THREAD_ID(i, n_rays);
// locate
rays_o += i * 3;
rays_d += i * 3;
t_min += i;
t_max += i;
const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float near = t_min[0], far = t_max[0];
uint32_t ray_idx, base, marching_samples;
uint32_t j;
float t0, t1, t_mid;
// first pass to compute an accurate number of steps
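    // Two-pass strategy: this first pass only counts how many occupied-voxel samples the
    // ray produces; the count is then used with atomicAdd on `steps_counter` to reserve a
    // contiguous slot in the packed output buffers, and the second pass re-marches the
    // ray and writes its samples into that slot.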
j = 0;
t0 = near; // TODO(ruilongli): perturb `near` as in ngp_pl?
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
while (t_mid < far && j < max_per_ray_samples) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resolution, aabb, occ_binary)) {
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resolution, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
if (j == 0) return;
marching_samples = j;
base = atomicAdd(steps_counter, marching_samples);
if (base + marching_samples > max_total_samples) return;
ray_idx = atomicAdd(rays_counter, 1);
// locate
frustum_origins += base * 3;
frustum_dirs += base * 3;
frustum_starts += base;
frustum_ends += base;
// Second round
j = 0;
t0 = near;
t1 = t0 + dt;
    t_mid = (t0 + t1) * 0.5f;
while (t_mid < far && j < marching_samples) {
// current center
const float x = ox + t_mid * dx;
const float y = oy + t_mid * dy;
const float z = oz + t_mid * dz;
if (grid_occupied_at(x, y, z, resolution, aabb, occ_binary)) {
frustum_origins[j * 3 + 0] = ox;
frustum_origins[j * 3 + 1] = oy;
frustum_origins[j * 3 + 2] = oz;
frustum_dirs[j * 3 + 0] = dx;
frustum_dirs[j * 3 + 1] = dy;
frustum_dirs[j * 3 + 2] = dz;
frustum_starts[j] = t0;
frustum_ends[j] = t1;
++j;
// march to next sample
t0 = t1;
t1 = t0 + dt;
t_mid = (t0 + t1) * 0.5f;
}
else {
// march to next sample
t_mid = advance_to_next_voxel(
t_mid, x, y, z, dx, dy, dz, rdx, rdy, rdz, resolution, dt
);
t0 = t_mid - dt * 0.5f;
t1 = t_mid + dt * 0.5f;
}
}
packed_info[ray_idx * 3 + 0] = i; // ray idx in {rays_o, rays_d}
packed_info[ray_idx * 3 + 1] = base; // point idx start.
packed_info[ray_idx * 3 + 2] = j; // point idx shift (actual marching samples).
return;
}
/**
* @brief Sample points by ray marching.
*
 * @param rays_o Ray origins. Shape of [n_rays, 3].
 * @param rays_d Normalized ray directions. Shape of [n_rays, 3].
 * @param t_min Near planes of rays. Shape of [n_rays].
 * @param t_max Far planes of rays. Shape of [n_rays].
 * @param aabb Scene AABB [xmin, ymin, zmin, xmax, ymax, zmax]. Shape of [6].
 * @param resolution Occupancy grid resolution [reso_x, reso_y, reso_z]. Shape of [3].
 * @param occ_binary Occupancy grid binary values. Shape of [reso_x, reso_y, reso_z].
 * @param max_total_samples Maximum total number of samples in this batch.
 * @param max_per_ray_samples Maximum number of samples per ray.
 * @param dt Marching step size along the ray.
 * @return std::vector<torch::Tensor>
 * - packed_info: Stores how to index the ray samples from the returned values.
 *   Shape of [n_rays, 3]. First value is the ray index. Second value is the sample
 *   start index in the results for this ray. Third value is the number of samples for
 *   this ray. Rays with zero samples are skipped, so `packed_info` may have zero
 *   padding at the end.
 * - origins: Ray origins for those samples. [max_total_samples, 3]
 * - dirs: Ray directions for those samples. [max_total_samples, 3]
 * - starts: Where the frustum-shaped sample starts along a ray. [max_total_samples, 1]
 * - ends: Where the frustum-shaped sample ends along a ray. [max_total_samples, 1]
*/
std::vector<torch::Tensor> ray_marching(
// rays
const torch::Tensor rays_o,
const torch::Tensor rays_d,
const torch::Tensor t_min,
const torch::Tensor t_max,
// density grid
const torch::Tensor aabb,
const torch::Tensor resolution,
const torch::Tensor occ_binary,
// sampling
const int max_total_samples,
const int max_per_ray_samples,
const float dt
) {
DEVICE_GUARD(rays_o);
CHECK_INPUT(rays_o);
CHECK_INPUT(rays_d);
CHECK_INPUT(t_min);
CHECK_INPUT(t_max);
CHECK_INPUT(aabb);
CHECK_INPUT(resolution);
CHECK_INPUT(occ_binary);
const int n_rays = rays_o.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// helper counter
torch::Tensor steps_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
torch::Tensor rays_counter = torch::zeros(
{1}, rays_o.options().dtype(torch::kInt32));
// output frustum samples
torch::Tensor packed_info = torch::zeros(
{n_rays, 3}, rays_o.options().dtype(torch::kInt32)); // ray_id, sample_id, num_samples
torch::Tensor frustum_origins = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_dirs = torch::zeros({max_total_samples, 3}, rays_o.options());
torch::Tensor frustum_starts = torch::zeros({max_total_samples, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({max_total_samples, 1}, rays_o.options());
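    // The output buffers are allocated at their maximum possible size and zero-initialized;
    // entries beyond the number of samples actually generated (tracked by `steps_counter`)
    // stay zero and are expected to be trimmed by the caller.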
kernel_raymarching<<<blocks, threads>>>(
// rays
n_rays,
rays_o.data_ptr<float>(),
rays_d.data_ptr<float>(),
t_min.data_ptr<float>(),
t_max.data_ptr<float>(),
// density grid
aabb.data_ptr<float>(),
resolution.data_ptr<int>(),
occ_binary.data_ptr<bool>(),
// sampling
max_total_samples,
max_per_ray_samples,
dt,
// writable helpers
steps_counter.data_ptr<int>(), // total samples.
rays_counter.data_ptr<int>(), // total rays.
packed_info.data_ptr<int>(),
frustum_origins.data_ptr<float>(),
frustum_dirs.data_ptr<float>(),
frustum_starts.data_ptr<float>(),
frustum_ends.data_ptr<float>()
);
return {packed_info, frustum_origins, frustum_dirs, frustum_starts, frustum_ends};
}
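/* ------------------------------------------------------------------------
 * Illustration only (not part of the original file): a minimal host-side
 * sketch of how `packed_info` addresses the flat sample buffers returned by
 * `ray_marching`. `example_slice_ray_starts` is a hypothetical helper added
 * purely to make the packing scheme concrete; the same pattern applies to
 * `frustum_origins`, `frustum_dirs` and `frustum_ends`.
 * ------------------------------------------------------------------------ */
torch::Tensor example_slice_ray_starts(
    const torch::Tensor packed_info,     // [n_rays, 3] int32, documented above
    const torch::Tensor frustum_starts,  // [max_total_samples, 1]
    int64_t row                          // which row of packed_info to unpack
) {
    torch::Tensor info = packed_info[row].cpu();
    const int base = info[1].item<int>();  // first sample index for this ray
    const int cnt  = info[2].item<int>();  // number of samples for this ray
    if (cnt == 0) {                        // ray was skipped during marching
        return torch::empty({0, 1}, frustum_starts.options());
    }
    return frustum_starts.narrow(0, base, cnt);  // [cnt, 1]
}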
#include "include/helpers_cuda.h"
template <typename scalar_t>
__global__ void volumetric_rendering_forward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
const scalar_t* rgbs, // input rgb after activation
// should be all-zero initialized
scalar_t* accumulated_weight, // output
scalar_t* accumulated_depth, // output
scalar_t* accumulated_color, // output
bool* mask // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
starts += base;
ends += base;
sigmas += base;
rgbs += base * 3;
accumulated_weight += i;
accumulated_depth += i;
accumulated_color += i * 3;
mask += i;
// accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
int j = 0;
for (; j < numsteps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t t = (ends[j] + starts[j]) * 0.5f;
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
const scalar_t weight = alpha * T;
accumulated_weight[0] += weight;
accumulated_depth[0] += weight * t;
accumulated_color[0] += weight * rgbs[j * 3 + 0];
accumulated_color[1] += weight * rgbs[j * 3 + 1];
accumulated_color[2] += weight * rgbs[j * 3 + 2];
T *= (1.f - alpha);
}
mask[0] = true;
}
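/* ------------------------------------------------------------------------
 * Reference sketch (added for illustration; not part of the original file):
 * a plain single-ray CPU version of the accumulation performed by the kernel
 * above, handy for checking the CUDA path on tiny inputs. It follows the same
 * recurrence: alpha_j = 1 - exp(-sigma_j * delta_j), w_j = alpha_j * T_j with
 * T_0 = 1 and T_{j+1} = T_j * (1 - alpha_j). `composite_one_ray_reference` is
 * a hypothetical name.
 * ------------------------------------------------------------------------ */
#include <cmath>
inline void composite_one_ray_reference(
    int numsteps,
    const float* starts, const float* ends,  // per-sample segment, [numsteps] each
    const float* sigmas, const float* rgbs,  // [numsteps], [numsteps * 3]
    float* out_weight, float* out_depth, float* out_color  // [1], [1], [3]
) {
    float T = 1.f;
    out_weight[0] = 0.f;
    out_depth[0] = 0.f;
    out_color[0] = out_color[1] = out_color[2] = 0.f;
    for (int j = 0; j < numsteps && T >= 1e-4f; ++j) {
        const float delta = ends[j] - starts[j];
        const float t = (ends[j] + starts[j]) * 0.5f;
        const float alpha = 1.f - std::exp(-sigmas[j] * delta);
        const float weight = alpha * T;
        out_weight[0] += weight;                    // ray opacity
        out_depth[0]  += weight * t;                // expected depth
        out_color[0]  += weight * rgbs[j * 3 + 0];  // accumulated color
        out_color[1]  += weight * rgbs[j * 3 + 1];
        out_color[2]  += weight * rgbs[j * 3 + 2];
        T *= (1.f - alpha);                         // remaining transmittance
    }
}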
template <typename scalar_t>
__global__ void volumetric_rendering_backward_kernel(
const uint32_t n_rays,
const int* packed_info, // input ray & point indices.
const scalar_t* starts, // input start t
const scalar_t* ends, // input end t
const scalar_t* sigmas, // input density after activation
const scalar_t* rgbs, // input rgb after activation
const scalar_t* accumulated_weight, // forward output
const scalar_t* accumulated_depth, // forward output
const scalar_t* accumulated_color, // forward output
const scalar_t* grad_weight, // input
const scalar_t* grad_depth, // input
const scalar_t* grad_color, // input
scalar_t* grad_sigmas, // output
scalar_t* grad_rgbs // output
) {
CUDA_GET_THREAD_ID(thread_id, n_rays);
// locate
const int i = packed_info[thread_id * 3 + 0]; // ray idx in {rays_o, rays_d}
const int base = packed_info[thread_id * 3 + 1]; // point idx start.
const int numsteps = packed_info[thread_id * 3 + 2]; // point idx shift.
if (numsteps == 0) return;
starts += base;
ends += base;
sigmas += base;
rgbs += base * 3;
grad_sigmas += base;
grad_rgbs += base * 3;
accumulated_weight += i;
accumulated_depth += i;
accumulated_color += i * 3;
grad_weight += i;
grad_depth += i;
grad_color += i * 3;
// backward of accumulated rendering
scalar_t T = 1.f;
scalar_t EPSILON = 1e-4f;
int j = 0;
scalar_t r = 0, g = 0, b = 0, d = 0;
for (; j < numsteps; ++j) {
if (T < EPSILON) {
break;
}
const scalar_t delta = ends[j] - starts[j];
const scalar_t t = (ends[j] + starts[j]) * 0.5f;
const scalar_t alpha = 1.f - __expf(-sigmas[j] * delta);
const scalar_t weight = alpha * T;
r += weight * rgbs[j * 3 + 0];
g += weight * rgbs[j * 3 + 1];
b += weight * rgbs[j * 3 + 2];
d += weight * t;
T *= (1.f - alpha);
grad_rgbs[j * 3 + 0] = grad_color[0] * weight;
grad_rgbs[j * 3 + 1] = grad_color[1] * weight;
grad_rgbs[j * 3 + 2] = grad_color[2] * weight;
grad_sigmas[j] = delta * (
grad_color[0] * (T * rgbs[j * 3 + 0] - (accumulated_color[0] - r)) +
grad_color[1] * (T * rgbs[j * 3 + 1] - (accumulated_color[1] - g)) +
grad_color[2] * (T * rgbs[j * 3 + 2] - (accumulated_color[2] - b)) +
grad_weight[0] * (1.f - accumulated_weight[0]) +
grad_depth[0] * (t * T - (accumulated_depth[0] - d))
);
}
}
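/* ------------------------------------------------------------------------
 * Derivation note (added for exposition; not part of the original file).
 * With delta_j = end_j - start_j, alpha_j = 1 - exp(-sigma_j * delta_j),
 * T_j = prod_{k<j} (1 - alpha_k) and w_j = alpha_j * T_j, the forward outputs
 * are W = sum_j w_j, D = sum_j w_j * t_j and C_c = sum_j w_j * rgb_{j,c}.
 * Since d(alpha_j)/d(sigma_j) = delta_j * (1 - alpha_j) and, for k > j,
 * d(w_k)/d(sigma_j) = -delta_j * w_k (through T_k), we get
 *   d(C_c)/d(sigma_j) = delta_j * (T_{j+1} * rgb_{j,c} - sum_{k>j} w_k * rgb_{k,c}).
 * In the loop above, `T` has already been updated to T_{j+1} when
 * grad_sigmas[j] is written, and the suffix sums are recovered from the
 * forward outputs minus the running prefixes r/g/b and d, e.g.
 *   sum_{k>j} w_k * rgb_{k,c} = accumulated_color[c] - (r, g, b)_c.
 * For the opacity term, w_k = T_k - T_{k+1} telescopes to
 * sum_{k>j} w_k = T_{j+1} - (1 - W), so
 *   d(W)/d(sigma_j) = delta_j * (T_{j+1} - sum_{k>j} w_k) = delta_j * (1 - W),
 * which is exactly the grad_weight * (1 - accumulated_weight) term; the depth
 * term follows the color pattern with rgb_{k,c} replaced by t_k.
 * ------------------------------------------------------------------------ */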
/**
* @brief Volumetric Rendering: Accumulating samples in the forward pass.
* The inputs, except for `sigmas` and `rgbs`, are the outputs of the CUDA
* ray marching function in `raymarching.cu`.
*
* @param packed_info Stores how to index the ray samples from the returned values.
* Shape of [n_rays, 3]. First value is the ray index. Second value is the sample
* start index in the results for this ray. Third value is the number of samples for
* this ray. Note that rays with zero samples are simply skipped, so `packed_info`
* may contain zero padding at the end.
* @param starts: Where the frustum-shape sample starts along a ray. [total_samples, 1]
* @param ends: Where the frustum-shape sample ends along a ray. [total_samples, 1]
* @param sigmas Densities at those samples. [total_samples, 1]
* @param rgbs RGBs at those samples. [total_samples, 3]
* @return std::vector<torch::Tensor>
* - accumulated_weight: Ray opacity. [n_rays, 1]
* - accumulated_depth: Ray depth. [n_rays, 1]
* - accumulated_color: Ray color. [n_rays, 3]
* - mask: Boolean flag indicating whether this ray has valid samples in `packed_info`. [n_rays]
*/
std::vector<torch::Tensor> volumetric_rendering_forward(
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
torch::Tensor rgbs
) {
DEVICE_GUARD(packed_info);
CHECK_INPUT(packed_info);
CHECK_INPUT(starts);
CHECK_INPUT(ends);
CHECK_INPUT(sigmas);
CHECK_INPUT(rgbs);
TORCH_CHECK(packed_info.ndimension() == 2 && packed_info.size(1) == 3);
TORCH_CHECK(starts.ndimension() == 2 && starts.size(1) == 1);
TORCH_CHECK(ends.ndimension() == 2 && ends.size(1) == 1);
TORCH_CHECK(sigmas.ndimension() == 2 && sigmas.size(1) == 1);
TORCH_CHECK(rgbs.ndimension() == 2 && rgbs.size(1) == 3);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor accumulated_weight = torch::zeros({n_rays, 1}, sigmas.options());
torch::Tensor accumulated_depth = torch::zeros({n_rays, 1}, sigmas.options());
torch::Tensor accumulated_color = torch::zeros({n_rays, 3}, sigmas.options());
// The rays that are not skipped during sampling.
torch::Tensor mask = torch::zeros({n_rays}, sigmas.options().dtype(torch::kBool));
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_rendering_forward",
([&]
{ volumetric_rendering_forward_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
rgbs.data_ptr<scalar_t>(),
accumulated_weight.data_ptr<scalar_t>(),
accumulated_depth.data_ptr<scalar_t>(),
accumulated_color.data_ptr<scalar_t>(),
mask.data_ptr<bool>()
);
}));
return {accumulated_weight, accumulated_depth, accumulated_color, mask};
}
/**
* @brief Volumetric Rendering: Accumulating samples in the backward pass.
*/
std::vector<torch::Tensor> volumetric_rendering_backward(
torch::Tensor accumulated_weight,
torch::Tensor accumulated_depth,
torch::Tensor accumulated_color,
torch::Tensor grad_weight,
torch::Tensor grad_depth,
torch::Tensor grad_color,
torch::Tensor packed_info,
torch::Tensor starts,
torch::Tensor ends,
torch::Tensor sigmas,
torch::Tensor rgbs
) {
DEVICE_GUARD(packed_info);
const uint32_t n_rays = packed_info.size(0);
const int threads = 256;
const int blocks = CUDA_N_BLOCKS_NEEDED(n_rays, threads);
// outputs
torch::Tensor grad_sigmas = torch::zeros(sigmas.sizes(), sigmas.options());
torch::Tensor grad_rgbs = torch::zeros(rgbs.sizes(), rgbs.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
sigmas.scalar_type(),
"volumetric_rendering_backward",
([&]
{ volumetric_rendering_backward_kernel<scalar_t><<<blocks, threads>>>(
n_rays,
packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(),
ends.data_ptr<scalar_t>(),
sigmas.data_ptr<scalar_t>(),
rgbs.data_ptr<scalar_t>(),
accumulated_weight.data_ptr<scalar_t>(),
accumulated_depth.data_ptr<scalar_t>(),
accumulated_color.data_ptr<scalar_t>(),
grad_weight.data_ptr<scalar_t>(),
grad_depth.data_ptr<scalar_t>(),
grad_color.data_ptr<scalar_t>(),
grad_sigmas.data_ptr<scalar_t>(),
grad_rgbs.data_ptr<scalar_t>()
);
}));
return {grad_sigmas, grad_rgbs};
}
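/* ------------------------------------------------------------------------
 * Binding sketch (assumption; the actual Python binding code is not shown in
 * this file). With a standard PyTorch C++/CUDA extension build, the two host
 * functions above could be exposed to Python roughly as below. The module
 * name is supplied by the build via TORCH_EXTENSION_NAME, not by this file.
 * ------------------------------------------------------------------------ */
#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("volumetric_rendering_forward", &volumetric_rendering_forward,
          "Accumulate ray samples along each ray (forward pass)");
    m.def("volumetric_rendering_backward", &volumetric_rendering_backward,
          "Gradients of the accumulation w.r.t. sigmas and rgbs (backward pass)");
}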
ninja
pybind11
--extra-index-url https://download.pytorch.org/whl/cu116
torch==1.12.1
-e .
from setuptools import find_packages, setup
setup(
name="nerfacc",
description="NeRF accelerated rendering",
version="0.0.2",
python_requires=">=3.9",
packages=find_packages(exclude=("tests*",)),
)
conda create -n nerfacc python=3.9 -y
conda activate nerfacc
pip install -r requirements.txt