Unverified Commit 0be61b2a authored by Ruilong Li(李瑞龙), committed by GitHub

Remove unused code and fix issues when data is empty (#211)

* ndr trial

* remove unused cuda code

* remove unused python code

* remove unused python code

* scan.cu is able to deal with empty inputs

* fix when network input is empty

* fix when network input is empty

* fix when network input is empty

* revert benchmarks submodules

* remove unused mlp model
parent d84cdf3a
"""
Copyright (c) 2022 Ruilong Li, UC Berkeley.
Both COLMAP and nerfstudio appear to be based on OpenCV's camera model.
References:
- nerfstudio: https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/cameras/cameras.py
- opencv:
- https://docs.opencv.org/3.4/da/d54/group__imgproc__transform.html#ga69f2545a8b62a6b0fc2ee060dc30559d
- https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html
- https://docs.opencv.org/4.x/db/d58/group__calib3d__fisheye.html
- https://github.com/opencv/opencv/blob/master/modules/calib3d/src/fisheye.cpp#L321
- https://github.com/opencv/opencv/blob/17234f82d025e3bbfbf611089637e5aa2038e7b8/modules/calib3d/src/distortion_model.hpp
- https://github.com/opencv/opencv/blob/8d0fbc6a1e9f20c822921e8076551a01e58cd632/modules/calib3d/src/undistort.dispatch.cpp#L578
- colmap: https://github.com/colmap/colmap/blob/dev/src/base/camera_models.h
- calcam: https://euratom-software.github.io/calcam/html/intro_theory.html
- blender:
- https://docs.blender.org/manual/en/latest/render/cycles/object_settings/cameras.html#fisheye-lens-polynomial
- https://github.com/blender/blender/blob/03cc3b94c94c38767802bccac4e9384ab704065a/intern/cycles/kernel/kernel_projection.h
- lensfun: https://lensfun.github.io/manual/v0.3.2/annotated.html
- OpenCV and Blender have different fisheye camera models:
- https://stackoverflow.com/questions/73270140/pipeline-for-fisheye-distortion-and-undistortion-with-blender-and-opencv
"""
from typing import Literal, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import Tensor
from . import cuda as _C
def ray_directions_from_uvs(
    uvs: Tensor,  # [..., 2]
    Ks: Tensor,  # [..., 3, 3]
    params: Optional[Tensor] = None,  # [..., M]
) -> Tensor:
    """Create ray directions from UVs and camera parameters in OpenCV format.

    Args:
        uvs: UV coordinates on the image plane (in pixel units).
        Ks: Camera intrinsics.
        params: Camera distortion parameters. See `cv2.undistortPoints` for details.

    Returns:
        Normalized ray directions in camera space.
    """
    u, v = torch.unbind(uvs + 0.5, dim=-1)
    fx, fy = Ks[..., 0, 0], Ks[..., 1, 1]
    cx, cy = Ks[..., 0, 2], Ks[..., 1, 2]

    # undo intrinsics
    xys = torch.stack([(u - cx) / fx, (v - cy) / fy], dim=-1)  # [..., 2]

    # undo lens distortion
    if params is not None:
        M = params.shape[-1]

        if M == 14:  # undo tilt projection
            R, R_inv = opencv_tilt_projection_matrix(params[..., -2:])
            xys_homo = F.pad(xys, (0, 1), value=1.0)  # [..., 3]
            xys_homo = torch.einsum("...ij,...j->...i", R_inv, xys_homo)  # [..., 3]
            xys = xys_homo[..., :2]
            homo = xys_homo[..., 2:]
            xys /= torch.where(homo != 0.0, homo, torch.ones_like(homo))

        xys = opencv_lens_undistortion(xys, params)  # [..., 2]

    # normalized homogeneous coordinates
    dirs = F.pad(xys, (0, 1), value=1.0)  # [..., 3]
    dirs = F.normalize(dirs, dim=-1)  # [..., 3]
    return dirs
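
As a usage illustration (not part of the commit): a minimal sketch exercising the pinhole path with `params=None`; the image size and intrinsics below are made-up values.

import torch

# Hypothetical image size and pinhole intrinsics (made-up values).
H, W = 4, 6
K = torch.tensor([
    [100.0, 0.0, W / 2.0],
    [0.0, 100.0, H / 2.0],
    [0.0, 0.0, 1.0],
])

# Pixel-center UV coordinates for every pixel.
v, u = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing="ij",
)
uvs = torch.stack([u, v], dim=-1).reshape(-1, 2)  # [H*W, 2]

dirs = ray_directions_from_uvs(uvs, K)  # [H*W, 3], unit-length, camera space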
def opencv_lens_undistortion(
    uv: Tensor, params: Tensor, eps: float = 1e-6, iters: int = 10
) -> Tensor:
    """Undistort points under the OpenCV distortion model with parameters {k1, k2, k3, k4, p1, p2}.

    Note:
        This function is not differentiable with respect to any of its inputs.

    Args:
        uv: (..., 2) UV coordinates.
        params: (..., 6) or (6,) OpenCV distortion parameters.

    Returns:
        (..., 2) undistorted UV coordinates.
    """
    assert uv.shape[-1] == 2
    assert params.shape[-1] == 6
    batch_shape = uv.shape[:-1]
    params = torch.broadcast_to(params, batch_shape + (6,))
    return _C.opencv_lens_undistortion(
        uv.contiguous(), params.contiguous(), eps, iters
    )
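
The CUDA routine inverts the distortion iteratively (`iters`, `eps`); the kernel itself is not shown in this diff, so the following is only a sketch of the classic fixed-point scheme such an undistorter typically uses, assuming the {k1, k2, k3, k4, p1, p2} parameter order from the docstring (the real kernel may instead use Newton's method and an `eps`-based early exit).

import torch

def undistort_fixed_point(uv_dist: torch.Tensor, params: torch.Tensor, iters: int = 10) -> torch.Tensor:
    # uv_dist: (..., 2) distorted, normalized image coordinates
    # params: (..., 6), assumed to be (k1, k2, k3, k4, p1, p2)
    k1, k2, k3, k4, p1, p2 = torch.unbind(params, dim=-1)
    xd, yd = torch.unbind(uv_dist, dim=-1)
    x, y = xd, yd
    for _ in range(iters):
        r2 = x * x + y * y
        radial = 1 + r2 * (k1 + r2 * (k2 + r2 * (k3 + r2 * k4)))
        dx = 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
        dy = p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
        x = (xd - dx) / radial
        y = (yd - dy) / radial
    return torch.stack([x, y], dim=-1)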
def opencv_tilt_projection_matrix(tau: Tensor) -> Tuple[Tensor, Tensor]:
    """Create a tilt projection matrix and its inverse.

    Reference:
        https://docs.opencv.org/3.4/d9/d0c/group__calib3d.html

    Args:
        tau: (..., 2) tilt angles.

    Returns:
        Tuple of the (..., 3, 3) tilt projection matrix and its (..., 3, 3) inverse.
    """
    cosx, cosy = torch.unbind(torch.cos(tau), -1)
    sinx, siny = torch.unbind(torch.sin(tau), -1)
    one = torch.ones_like(cosx)
    zero = torch.zeros_like(cosx)

    Rx = torch.stack(
        [one, zero, zero, zero, cosx, sinx, zero, -sinx, cosx], -1
    ).reshape(*tau.shape[:-1], 3, 3)
    Ry = torch.stack(
        [cosy, zero, -siny, zero, one, zero, siny, zero, cosy], -1
    ).reshape(*tau.shape[:-1], 3, 3)
    Rxy = torch.matmul(Ry, Rx)
    Rz = torch.stack(
        [
            Rxy[..., 2, 2],
            zero,
            -Rxy[..., 0, 2],
            zero,
            Rxy[..., 2, 2],
            -Rxy[..., 1, 2],
            zero,
            zero,
            one,
        ],
        -1,
    ).reshape(*tau.shape[:-1], 3, 3)
    R = torch.matmul(Rz, Rxy)

    inv = 1.0 / Rxy[..., 2, 2]
    Rz_inv = torch.stack(
        [
            inv,
            zero,
            inv * Rxy[..., 0, 2],
            zero,
            inv,
            inv * Rxy[..., 1, 2],
            zero,
            zero,
            one,
        ],
        -1,
    ).reshape(*tau.shape[:-1], 3, 3)
    R_inv = torch.matmul(Rxy.transpose(-1, -2), Rz_inv)
    return R, R_inv
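
A quick sanity check, as a sketch rather than library code, that the two returned matrices are in fact inverses of each other:

import torch

tau = torch.tensor([0.01, -0.02])  # made-up tilt angles
R, R_inv = opencv_tilt_projection_matrix(tau)
assert torch.allclose(R @ R_inv, torch.eye(3), atol=1e-5)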
......@@ -15,11 +15,7 @@ def _make_lazy_cuda_func(name: str) -> Callable:
    return call_cuda
is_cub_available = _make_lazy_cuda_func("is_cub_available")
# data specs
MultiScaleGridSpec = _make_lazy_cuda_func("MultiScaleGridSpec")
RaysSpec = _make_lazy_cuda_func("RaysSpec")
RaySegmentsSpec = _make_lazy_cuda_func("RaySegmentsSpec")
# grid
......@@ -27,7 +23,6 @@ ray_aabb_intersect = _make_lazy_cuda_func("ray_aabb_intersect")
traverse_grids = _make_lazy_cuda_func("traverse_grids")
# scan
exclusive_sum_by_key = _make_lazy_cuda_func("exclusive_sum_by_key")
inclusive_sum = _make_lazy_cuda_func("inclusive_sum")
exclusive_sum = _make_lazy_cuda_func("exclusive_sum")
inclusive_prod_forward = _make_lazy_cuda_func("inclusive_prod_forward")
......
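
For context on the wrapper whose tail appears above: `_make_lazy_cuda_func` presumably defers loading of the compiled extension until the first call. A minimal sketch under that assumption (the `._backend` module name here is hypothetical):

from typing import Any, Callable

def _make_lazy_cuda_func(name: str) -> Callable:
    def call_cuda(*args: Any, **kwargs: Any) -> Any:
        # Import inside the call so the CUDA extension is built/loaded on
        # first use rather than at package-import time.
        from ._backend import _C  # hypothetical backend module name

        return getattr(_C, name)(*args, **kwargs)

    return call_cuda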
......@@ -3,44 +3,6 @@
#include <torch/extension.h>
#include "utils_cuda.cuh"
struct MultiScaleGridSpec {
  torch::Tensor data;      // [levels, resx, resy, resz]
  torch::Tensor occupied;  // [levels, resx, resy, resz]
  torch::Tensor base_aabb; // [6,]

  inline void check() {
    CHECK_INPUT(data);
    CHECK_INPUT(occupied);
    CHECK_INPUT(base_aabb);

    TORCH_CHECK(data.ndimension() == 4);
    TORCH_CHECK(occupied.ndimension() == 4);
    TORCH_CHECK(base_aabb.ndimension() == 1);

    TORCH_CHECK(data.numel() == occupied.numel());
    TORCH_CHECK(base_aabb.numel() == 6);
  }
};
struct RaysSpec {
  torch::Tensor origins; // [n_rays, 3]
  torch::Tensor dirs;    // [n_rays, 3]

  inline void check() {
    CHECK_INPUT(origins);
    CHECK_INPUT(dirs);

    TORCH_CHECK(origins.ndimension() == 2);
    TORCH_CHECK(dirs.ndimension() == 2);
    TORCH_CHECK(origins.numel() == dirs.numel());
    TORCH_CHECK(origins.size(1) == 3);
    TORCH_CHECK(dirs.size(1) == 3);
  }
};
struct RaySegmentsSpec {
  torch::Tensor vals; // [n_edges] or [n_rays, n_edges_per_ray]
  // for flattened tensor
......
......@@ -37,34 +37,6 @@ struct PackedRaySegmentsSpec {
  int32_t n_edges_per_ray;
};
struct PackedMultiScaleGridSpec {
  PackedMultiScaleGridSpec(MultiScaleGridSpec& spec) :
    data(spec.data.data_ptr<float>()),
    occupied(spec.occupied.data_ptr<bool>()),
    base_aabb(spec.base_aabb.data_ptr<float>()),
    levels(spec.data.size(0)),
    resolution{
      (int32_t)spec.data.size(1),
      (int32_t)spec.data.size(2),
      (int32_t)spec.data.size(3)}
  { }
  float* data;
  bool* occupied;
  float* base_aabb;
  int32_t levels;
  int3 resolution;
};
struct PackedRaysSpec {
  PackedRaysSpec(RaysSpec& spec) :
    origins(spec.origins.data_ptr<float>()),
    dirs(spec.dirs.data_ptr<float>()),
    N(spec.origins.size(0))
  { }
  float *origins;
  float *dirs;
  int32_t N;
};
struct SingleRaySpec {
  // TODO: check inv_dir if dir is zero.
......@@ -77,23 +49,6 @@ struct SingleRaySpec {
    tmax{tmax}
  { }
  __device__ SingleRaySpec(
    PackedRaysSpec& rays, int32_t id, float tmin, float tmax) :
    origin{
      rays.origins[id * 3],
      rays.origins[id * 3 + 1],
      rays.origins[id * 3 + 2]},
    dir{
      rays.dirs[id * 3],
      rays.dirs[id * 3 + 1],
      rays.dirs[id * 3 + 2]},
    inv_dir{
      1.0f / rays.dirs[id * 3],
      1.0f / rays.dirs[id * 3 + 1],
      1.0f / rays.dirs[id * 3 + 2]},
    tmin{tmin},
    tmax{tmax}
  { }

  float3 origin;
  float3 dir;
  float3 inv_dir;
......
......@@ -7,15 +7,7 @@
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
// #include <ATen/cuda/cub_definitions.cuh>
// CUB support for scan-by-key was added in CUB 1.15
// (https://github.com/NVIDIA/cub/pull/376).
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
......@@ -31,16 +23,6 @@
#define DEVICE_GUARD(_ten) \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(_ten));
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do {                                    \
  size_t temp_storage_bytes = 0;                                       \
  func(nullptr, temp_storage_bytes, __VA_ARGS__);                      \
  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);  \
  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);           \
  AT_CUDA_CHECK(cudaGetLastError());                                   \
} while (false)
template <typename scalar_t>
inline __device__ __host__ scalar_t ceil_div(scalar_t a, scalar_t b)
{
......
......@@ -7,25 +7,6 @@
#include "utils_cuda.cuh"
// CUB support for scan-by-key was added in CUB 1.15
// (https://github.com/NVIDIA/cub/pull/376).
#if CUB_VERSION >= 101500
#define CUB_SUPPORTS_SCAN_BY_KEY() 1
#else
#define CUB_SUPPORTS_SCAN_BY_KEY() 0
#endif
// https://github.com/pytorch/pytorch/blob/233305a852e1cd7f319b15b5137074c9eac455f6/aten/src/ATen/cuda/cub.cuh#L38-L46
#define CUB_WRAPPER(func, ...) do {                                    \
  size_t temp_storage_bytes = 0;                                       \
  func(nullptr, temp_storage_bytes, __VA_ARGS__);                      \
  auto& caching_allocator = *::c10::cuda::CUDACachingAllocator::get(); \
  auto temp_storage = caching_allocator.allocate(temp_storage_bytes);  \
  func(temp_storage.get(), temp_storage_bytes, __VA_ARGS__);           \
  AT_CUDA_CHECK(cudaGetLastError());                                   \
} while (false)
namespace {
namespace device {
......
......@@ -3,16 +3,8 @@
#include <torch/extension.h>
bool is_cub_available() {
  // FIXME: why return false?
  return (bool) CUB_SUPPORTS_SCAN_BY_KEY();
}
// scan
torch::Tensor exclusive_sum_by_key(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward);
torch::Tensor inclusive_sum(
    torch::Tensor chunk_starts,
    torch::Tensor chunk_cnts,
......@@ -103,9 +95,6 @@ torch::Tensor opencv_lens_undistortion_fisheye(
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
#define _REG_FUNC(funname) m.def(#funname, &funname)
  _REG_FUNC(is_cub_available); // TODO: check this function
  _REG_FUNC(exclusive_sum_by_key);
  _REG_FUNC(inclusive_sum);
  _REG_FUNC(exclusive_sum);
  _REG_FUNC(inclusive_prod_forward);
......@@ -118,23 +107,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  _REG_FUNC(searchsorted);
  _REG_FUNC(opencv_lens_undistortion);
  _REG_FUNC(opencv_lens_undistortion_fisheye); // TODO: check this function.
#undef _REG_FUNC
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, torch::Tensor, bool>(&importance_sampling));
m.def("importance_sampling", py::overload_cast<RaySegmentsSpec, torch::Tensor, int64_t, bool>(&importance_sampling));
py::class_<MultiScaleGridSpec>(m, "MultiScaleGridSpec")
.def(py::init<>())
.def_readwrite("data", &MultiScaleGridSpec::data)
.def_readwrite("occupied", &MultiScaleGridSpec::occupied)
.def_readwrite("base_aabb", &MultiScaleGridSpec::base_aabb);
py::class_<RaysSpec>(m, "RaysSpec")
.def(py::init<>())
.def_readwrite("origins", &RaysSpec::origins)
.def_readwrite("dirs", &RaysSpec::dirs);
py::class_<RaySegmentsSpec>(m, "RaySegmentsSpec")
.def(py::init<>())
.def_readwrite("vals", &RaySegmentsSpec::vals)
......
......@@ -4,75 +4,6 @@
#include <thrust/iterator/reverse_iterator.h>
#include "include/utils_scan.cuh"
#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <cub/cub.cuh>
#endif
namespace {
namespace device {
#if CUB_SUPPORTS_SCAN_BY_KEY()
struct Product
{
  template <typename T>
  __host__ __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a * b; }
};

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_sum_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
  TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
              "cub ExclusiveSumByKey does not support more than LONG_MAX elements");
  CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
              num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}

template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename ValuesOutputIteratorT>
inline void exclusive_prod_by_key(
    KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
  TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
              "cub ExclusiveScanByKey does not support more than LONG_MAX elements");
  CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
              num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
}
#endif
} // namespace device
} // namespace
torch::Tensor exclusive_sum_by_key(
    torch::Tensor indices,
    torch::Tensor inputs,
    bool backward)
{
  DEVICE_GUARD(inputs);
  torch::Tensor outputs = torch::empty_like(inputs);
  int64_t n_items = inputs.size(0);
#if CUB_SUPPORTS_SCAN_BY_KEY()
  if (backward)
    device::exclusive_sum_by_key(
        thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_items),
        thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_items),
        thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_items),
        n_items);
  else
    device::exclusive_sum_by_key(
        indices.data_ptr<int64_t>(),
        inputs.data_ptr<float>(),
        outputs.data_ptr<float>(),
        n_items);
#else
  throw std::runtime_error("CUB functions are only supported in CUDA >= 11.6.");
#endif
  cudaGetLastError();
  return outputs;
}
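
The backward branch above runs the same exclusive scan over the reversed sequence. The underlying identity, sketched in plain PyTorch (illustrative only, not the kernel): the adjoint of an exclusive prefix sum is an exclusive suffix sum.

import torch

x = torch.randn(5, requires_grad=True)
out = torch.cumsum(x, 0) - x  # exclusive prefix sum
g = torch.randn(5)
out.backward(g)

# Adjoint = exclusive suffix sum of g, i.e. an exclusive prefix sum
# over the reversed sequence, reversed back.
ref = (torch.cumsum(g.flip(0), 0) - g.flip(0)).flip(0)
assert torch.allclose(x.grad, ref)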
torch::Tensor inclusive_sum(
......@@ -97,13 +28,16 @@ torch::Tensor inclusive_sum(
  uint32_t n_rays = chunk_cnts.size(0);
  int64_t n_edges = inputs.size(0);

  torch::Tensor outputs = torch::empty_like(inputs);
  if (n_edges == 0) {
    return outputs;
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int32_t max_blocks = 65535;
  dim3 threads = dim3(16, 32);
  dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));

  if (backward) {
    chunk_starts = n_edges - (chunk_starts + chunk_cnts);
    device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
......@@ -153,13 +87,16 @@ torch::Tensor exclusive_sum(
  uint32_t n_rays = chunk_cnts.size(0);
  int64_t n_edges = inputs.size(0);

  torch::Tensor outputs = torch::empty_like(inputs);
  if (n_edges == 0) {
    return outputs;
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int32_t max_blocks = 65535;
  dim3 threads = dim3(16, 32);
  dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));

  if (backward) {
    chunk_starts = n_edges - (chunk_starts + chunk_cnts);
    device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
......@@ -205,13 +142,16 @@ torch::Tensor inclusive_prod_forward(
  uint32_t n_rays = chunk_cnts.size(0);
  int64_t n_edges = inputs.size(0);

  torch::Tensor outputs = torch::empty_like(inputs);
  if (n_edges == 0) {
    return outputs;
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int32_t max_blocks = 65535;
  dim3 threads = dim3(16, 32);
  dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));

  device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
      outputs.data_ptr<float>(),
      inputs.data_ptr<float>(),
......@@ -246,13 +186,16 @@ torch::Tensor inclusive_prod_backward(
  uint32_t n_rays = chunk_cnts.size(0);
  int64_t n_edges = inputs.size(0);

  torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
  if (n_edges == 0) {
    return grad_inputs;
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int32_t max_blocks = 65535;
  dim3 threads = dim3(16, 32);
  dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));

  chunk_starts = n_edges - (chunk_starts + chunk_cnts);
  device::inclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
      thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
......@@ -289,13 +232,16 @@ torch::Tensor exclusive_prod_forward(
  uint32_t n_rays = chunk_cnts.size(0);
  int64_t n_edges = inputs.size(0);

  torch::Tensor outputs = torch::empty_like(inputs);
  if (n_edges == 0) {
    return outputs;
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int32_t max_blocks = 65535;
  dim3 threads = dim3(16, 32);
  dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));

  device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
      outputs.data_ptr<float>(),
      inputs.data_ptr<float>(),
......@@ -330,13 +276,16 @@ torch::Tensor exclusive_prod_backward(
  uint32_t n_rays = chunk_cnts.size(0);
  int64_t n_edges = inputs.size(0);

  torch::Tensor grad_inputs = torch::empty_like(grad_outputs);
  if (n_edges == 0) {
    return grad_inputs;
  }

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  int32_t max_blocks = 65535;
  dim3 threads = dim3(16, 32);
  dim3 blocks = dim3(min(max_blocks, ceil_div<int32_t>(n_rays, threads.y)));

  chunk_starts = n_edges - (chunk_starts + chunk_cnts);
  device::exclusive_scan_kernel<float, 16, 32><<<blocks, threads, 0, stream>>>(
      thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
......
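
The Python-level effect of these guards, sketched with a stand-in `cumsum` in place of the real kernel (illustrative only): a zero-length input yields a zero-length output instead of a kernel launch with an empty grid.

import torch

def safe_inclusive_sum(inputs: torch.Tensor) -> torch.Tensor:
    # Mirrors the guard added above: on zero-length input, return an
    # empty output buffer instead of launching a kernel.
    if inputs.shape[0] == 0:
        return torch.empty_like(inputs)
    return torch.cumsum(inputs, dim=0)

assert safe_inclusive_sum(torch.empty(0)).shape == (0,)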
......@@ -184,7 +184,10 @@ class OccGridEstimator(AbstractEstimator):
# Compute visibility of the samples, and filter out invisible samples
if sigma_fn is not None:
    if t_starts.shape[0] != 0:
        sigmas = sigma_fn(t_starts, t_ends, ray_indices)
    else:
        sigmas = torch.empty((0,), device=t_starts.device)
    assert (
        sigmas.shape == t_starts.shape
    ), "sigmas must have shape of (N,)! Got {}".format(sigmas.shape)
......@@ -197,7 +200,10 @@ class OccGridEstimator(AbstractEstimator):
            alpha_thre=alpha_thre,
        )
    elif alpha_fn is not None:
        if t_starts.shape[0] != 0:
            alphas = alpha_fn(t_starts, t_ends, ray_indices)
        else:
            alphas = torch.empty((0,), device=t_starts.device)
        assert (
            alphas.shape == t_starts.shape
        ), "alphas must have shape of (N,)! Got {}".format(alphas.shape)
......
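
With this guard, `sigma_fn` (and `alpha_fn`) are simply skipped on empty batches, so user callbacks need no empty-input handling of their own. A sketch of a typical callback; `rays_o`/`rays_d` are made-up tensors and the density query is a stand-in:

import torch

rays_o = torch.rand(8, 3)  # made-up ray origins
rays_d = torch.rand(8, 3)  # made-up ray directions

def sigma_fn(t_starts, t_ends, ray_indices):
    # After this change the estimator skips the call when there are no
    # samples, so this body never sees a zero-length batch.
    t_mid = (t_starts + t_ends)[:, None] / 2.0
    positions = rays_o[ray_indices] + rays_d[ray_indices] * t_mid
    return positions.norm(dim=-1)  # stand-in for a density-network query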
......@@ -85,7 +85,11 @@ def rendering(
# Query sigma/alpha and color with gradients
if rgb_sigma_fn is not None:
    if t_starts.shape[0] != 0:
        rgbs, sigmas = rgb_sigma_fn(t_starts, t_ends, ray_indices)
    else:
        rgbs = torch.empty((0, 3), device=t_starts.device)
        sigmas = torch.empty((0,), device=t_starts.device)
    assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
        rgbs.shape
    )
......@@ -108,7 +112,11 @@ def rendering(
"rgbs": rgbs,
}
elif rgb_alpha_fn is not None:
if t_starts.shape[0] != 0:
rgbs, alphas = rgb_alpha_fn(t_starts, t_ends, ray_indices)
else:
rgbs = torch.empty((0, 3), device=t_starts.device)
alphas = torch.empty((0,), device=t_starts.device)
assert rgbs.shape[-1] == 3, "rgbs must have 3 channels, got {}".format(
rgbs.shape
)
......
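
Likewise, `rendering` now tolerates a zero-sample input. A sketch of such a call, assuming the nerfacc-0.5-style signature and a (colors, opacities, depths, extras) return, which this diff only partially shows:

import torch

t_starts = torch.empty(0)
t_ends = torch.empty(0)
ray_indices = torch.empty(0, dtype=torch.long)

def rgb_sigma_fn(t_starts, t_ends, ray_indices):
    return torch.empty(0, 3), torch.empty(0)  # skipped on empty input anyway

colors, opacities, depths, extras = rendering(
    t_starts, t_ends, ray_indices, n_rays=4, rgb_sigma_fn=rgb_sigma_fn
)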
......@@ -105,7 +105,11 @@ setup(
    download_url=f"{URL}/archive/{__version__}.tar.gz",
    keywords=[],
    python_requires=">=3.7",
    install_requires=[
        "rich>=12",
        "torch",
        "typing_extensions; python_version<'3.8'",
    ],
    extras_require={
        # dev dependencies. Install them by `pip install nerfacc[dev]`
        "dev": [
......@@ -118,6 +122,7 @@ setup(
"pyyaml==6.0",
"build",
"twine",
"ninja",
],
},
ext_modules=get_extensions() if not BUILD_NO_CUDA else [],
......