Commit 9bddde48 authored by Ruilong Li's avatar Ruilong Li
Browse files

steam

parent ba049483
...@@ -244,7 +244,7 @@ std::vector<torch::Tensor> volumetric_marching( ...@@ -244,7 +244,7 @@ std::vector<torch::Tensor> volumetric_marching(
{n_rays}, rays_o.options().dtype(torch::kInt32)); {n_rays}, rays_o.options().dtype(torch::kInt32));
// count number of samples per ray // count number of samples per ray
marching_steps_kernel<<<blocks, threads>>>( marching_steps_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays // rays
n_rays, n_rays,
rays_o.data_ptr<float>(), rays_o.data_ptr<float>(),
...@@ -276,7 +276,7 @@ std::vector<torch::Tensor> volumetric_marching( ...@@ -276,7 +276,7 @@ std::vector<torch::Tensor> volumetric_marching(
torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options()); torch::Tensor frustum_starts = torch::zeros({total_steps, 1}, rays_o.options());
torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options()); torch::Tensor frustum_ends = torch::zeros({total_steps, 1}, rays_o.options());
marching_forward_kernel<<<blocks, threads>>>( marching_forward_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
// rays // rays
n_rays, n_rays,
rays_o.data_ptr<float>(), rays_o.data_ptr<float>(),
......
...@@ -166,7 +166,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps( ...@@ -166,7 +166,7 @@ std::vector<torch::Tensor> volumetric_rendering_steps(
sigmas.scalar_type(), sigmas.scalar_type(),
"volumetric_marching_steps", "volumetric_marching_steps",
([&] ([&]
{ volumetric_rendering_steps_kernel<scalar_t><<<blocks, threads>>>( { volumetric_rendering_steps_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays, n_rays,
packed_info.data_ptr<int>(), packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(), starts.data_ptr<scalar_t>(),
...@@ -214,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward( ...@@ -214,7 +214,7 @@ std::vector<torch::Tensor> volumetric_rendering_weights_forward(
sigmas.scalar_type(), sigmas.scalar_type(),
"volumetric_rendering_weights_forward", "volumetric_rendering_weights_forward",
([&] ([&]
{ volumetric_rendering_weights_forward_kernel<scalar_t><<<blocks, threads>>>( { volumetric_rendering_weights_forward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays, n_rays,
packed_info.data_ptr<int>(), packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(), starts.data_ptr<scalar_t>(),
...@@ -251,7 +251,7 @@ torch::Tensor volumetric_rendering_weights_backward( ...@@ -251,7 +251,7 @@ torch::Tensor volumetric_rendering_weights_backward(
sigmas.scalar_type(), sigmas.scalar_type(),
"volumetric_rendering_weights_backward", "volumetric_rendering_weights_backward",
([&] ([&]
{ volumetric_rendering_weights_backward_kernel<scalar_t><<<blocks, threads>>>( { volumetric_rendering_weights_backward_kernel<scalar_t><<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays, n_rays,
packed_info.data_ptr<int>(), packed_info.data_ptr<int>(),
starts.data_ptr<scalar_t>(), starts.data_ptr<scalar_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment