"docs/git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "045a9743a9def9001c3f657bd70f12a36a64227e"
Commit 03a8a782 authored by Ruilong Li's avatar Ruilong Li
Browse files

volrend support CUB

parent ea07af8e
...@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple ...@@ -7,6 +7,7 @@ from typing import Callable, Dict, Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from .cuda import is_cub_available
from .pack import pack_info from .pack import pack_info
from .scan import exclusive_prod, exclusive_sum from .scan import exclusive_prod, exclusive_sum
...@@ -197,10 +198,14 @@ def render_transmittance_from_alpha( ...@@ -197,10 +198,14 @@ def render_transmittance_from_alpha(
# FIXME Try not to use exclusive_prod because: # FIXME Try not to use exclusive_prod because:
# 1. torch.cumprod is much slower than torch.cumsum # 1. torch.cumprod is much slower than torch.cumsum
# 2. exclusive_prod gradient on input == 0 is not correct. # 2. exclusive_prod gradient on input == 0 is not correct.
if ray_indices is not None and packed_info is None: if not is_cub_available() and packed_info is None:
# Convert ray indices to packed info
packed_info = pack_info(ray_indices, n_rays) packed_info = pack_info(ray_indices, n_rays)
ray_indices = None
trans = exclusive_prod(1 - alphas, packed_info) trans = exclusive_prod(
1 - alphas, packed_info=packed_info, indices=ray_indices
)
if prefix_trans is not None: if prefix_trans is not None:
trans *= prefix_trans trans *= prefix_trans
return trans return trans
...@@ -253,12 +258,16 @@ def render_transmittance_from_density( ...@@ -253,12 +258,16 @@ def render_transmittance_from_density(
alphas: [0.33, 0.55, 0.095, 0.55, 0.095, 0.00, 0.59] alphas: [0.33, 0.55, 0.095, 0.55, 0.095, 0.00, 0.59]
""" """
if ray_indices is not None and packed_info is None: if not is_cub_available() and packed_info is None:
# Convert ray indices to packed info
packed_info = pack_info(ray_indices, n_rays) packed_info = pack_info(ray_indices, n_rays)
ray_indices = None
sigmas_dt = sigmas * (t_ends - t_starts) sigmas_dt = sigmas * (t_ends - t_starts)
alphas = 1.0 - torch.exp(-sigmas_dt) alphas = 1.0 - torch.exp(-sigmas_dt)
trans = torch.exp(-exclusive_sum(sigmas_dt, packed_info)) trans = torch.exp(
-exclusive_sum(sigmas_dt, packed_info=packed_info, indices=ray_indices)
)
if prefix_trans is not None: if prefix_trans is not None:
trans = trans * prefix_trans trans = trans * prefix_trans
return trans, alphas return trans, alphas
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment