Commit fb876fd4 authored by Ruilong Li
Browse files

fix long to int64_t for Windows

parent 1bc23a18
......@@ -9,7 +9,7 @@ __global__ void unpack_info_kernel(
const int n_rays,
const int *packed_info,
// output
long *ray_indices)
int64_t *ray_indices)
{
CUDA_GET_THREAD_ID(i, n_rays);
......@@ -97,7 +97,7 @@ torch::Tensor unpack_info(const torch::Tensor packed_info, const int n_samples)
unpack_info_kernel<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
n_rays,
packed_info.data_ptr<int>(),
ray_indices.data_ptr<long>());
ray_indices.data_ptr<int64_t>());
return ray_indices;
}
......
......@@ -97,7 +97,7 @@ __global__ void ray_marching_kernel(
// first round outputs
int *num_steps,
// second round outputs
long *ray_indices,
int64_t *ray_indices,
float *t_starts,
float *t_ends)
{
......@@ -281,7 +281,7 @@ std::vector<torch::Tensor> ray_marching(
packed_info.data_ptr<int>(),
// outputs
nullptr, /* num_steps */
ray_indices.data_ptr<long>(),
ray_indices.data_ptr<int64_t>(),
t_starts.data_ptr<float>(),
t_ends.data_ptr<float>());
......
......@@ -20,7 +20,7 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
inline void exclusive_sum_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveSumByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveSumByKey, keys, input, output,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
......@@ -30,7 +30,7 @@ template <typename KeysInputIteratorT, typename ValuesInputIteratorT, typename V
inline void exclusive_prod_by_key(
KeysInputIteratorT keys, ValuesInputIteratorT input, ValuesOutputIteratorT output, int64_t num_items)
{
TORCH_CHECK(num_items <= std::numeric_limits<long>::max(),
TORCH_CHECK(num_items <= std::numeric_limits<int64_t>::max(),
"cub ExclusiveScanByKey does not support more than LONG_MAX elements");
CUB_WRAPPER(cub::DeviceScan::ExclusiveScanByKey, keys, input, output, Product(), 1.0f,
num_items, cub::Equality(), at::cuda::getCurrentCUDAStream());
......@@ -60,7 +60,7 @@ torch::Tensor transmittance_from_sigma_forward_cub(
torch::Tensor sigmas_dt_cumsum = torch::empty_like(sigmas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
ray_indices.data_ptr<long>(),
ray_indices.data_ptr<int64_t>(),
sigmas_dt.data_ptr<float>(),
sigmas_dt_cumsum.data_ptr<float>(),
n_samples);
......@@ -97,7 +97,7 @@ torch::Tensor transmittance_from_sigma_backward_cub(
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<long>() + n_samples),
thrust::make_reverse_iterator(ray_indices.data_ptr<int64_t>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
......@@ -123,7 +123,7 @@ torch::Tensor transmittance_from_alpha_forward_cub(
torch::Tensor transmittance = torch::empty_like(alphas);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key(
ray_indices.data_ptr<long>(),
ray_indices.data_ptr<int64_t>(),
(1.0f - alphas).data_ptr<float>(),
transmittance.data_ptr<float>(),
n_samples);
......@@ -154,7 +154,7 @@ torch::Tensor transmittance_from_alpha_backward_cub(
torch::Tensor sigmas_dt_grad = torch::empty_like(transmittance_grad);
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(ray_indices.data_ptr<long>() + n_samples),
thrust::make_reverse_iterator(ray_indices.data_ptr<int64_t>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_cumsum_grad.data_ptr<float>() + n_samples),
thrust::make_reverse_iterator(sigmas_dt_grad.data_ptr<float>() + n_samples),
n_samples);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment