Unverified Commit 910488e3 authored by Zhihao Liang's avatar Zhihao Liang Committed by GitHub
Browse files

fix: fix the dimension comments of t_sorted and t_indices (#221)


Co-authored-by: default avatarchihaoliang(chihaoliang) <chihaoliang@tencent.com>
parent 2312ac20
......@@ -55,8 +55,8 @@ std::tuple<RaySegmentsSpec, RaySegmentsSpec, torch::Tensor> traverse_grids(
const torch::Tensor binaries, // [n_grids, resx, resy, resz]
const torch::Tensor aabbs, // [n_grids, 6]
// intersections
const torch::Tensor t_sorted, // [n_rays, n_grids]
const torch::Tensor t_indices, // [n_rays, n_grids]
const torch::Tensor t_sorted, // [n_rays, n_grids * 2]
const torch::Tensor t_indices, // [n_rays, n_grids * 2]
const torch::Tensor hits, // [n_rays, n_grids]
// options
const torch::Tensor near_planes,
......
......@@ -107,8 +107,8 @@ def traverse_grids(
over_allocate: Optional[bool] = False,
rays_mask: Optional[Tensor] = None, # [n_rays]
# pre-compute intersections
t_sorted: Optional[Tensor] = None, # [n_rays, n_grids]
t_indices: Optional[Tensor] = None, # [n_rays, n_grids]
t_sorted: Optional[Tensor] = None, # [n_rays, n_grids * 2]
t_indices: Optional[Tensor] = None, # [n_rays, n_grids * 2]
hits: Optional[Tensor] = None, # [n_rays, n_grids]
) -> Tuple[RayIntervals, RaySamples, Tensor]:
"""Ray Traversal within Multiple Grids.
......@@ -129,8 +129,8 @@ def traverse_grids(
traverse_steps_limit: Optional. Maximum number of samples per ray.
over_allocate: Optional. Whether to over-allocate the memory for the outputs.
rays_mask: Optional. (n_rays,) Skip some rays if given.
t_sorted: Optional. (n_rays, n_grids) Pre-computed sorted t values for each ray-grid pair. Default to None.
t_indices: Optional. (n_rays, n_grids) Pre-computed sorted t indices for each ray-grid pair. Default to None.
t_sorted: Optional. (n_rays, n_grids * 2) Pre-computed sorted t values for each ray-grid pair. Default to None.
t_indices: Optional. (n_rays, n_grids * 2) Pre-computed sorted t indices for each ray-grid pair. Default to None.
hits: Optional. (n_rays, n_grids) Pre-computed hit flags for each ray-grid pair. Default to None.
Returns:
......@@ -171,8 +171,8 @@ def traverse_grids(
binaries.contiguous(), # [m, resx, resy, resz]
aabbs.contiguous(), # [m, 6]
# intersections
t_sorted.contiguous(), # [n_rays, m]
t_indices.contiguous(), # [n_rays, m]
t_sorted.contiguous(), # [n_rays, m * 2]
t_indices.contiguous(), # [n_rays, m * 2]
hits.contiguous(), # [n_rays, m]
# options
near_planes.contiguous(), # [n_rays]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment