Unverified Commit 57ccfa14 authored by Skylar Wurster's avatar Skylar Wurster Committed by GitHub
Browse files

Use int64_t instead of long (#301)

parent 8a668783
Pipeline #2966 canceled with stages
......@@ -87,13 +87,13 @@ torch::Tensor inclusive_sum_cub(
#if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward) {
inclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
n_edges);
} else {
inclusive_sum_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
......@@ -129,13 +129,13 @@ torch::Tensor exclusive_sum_cub(
#if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward) {
exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
n_edges);
} else {
exclusive_sum_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
......@@ -169,7 +169,7 @@ torch::Tensor inclusive_prod_cub_forward(
#if CUB_SUPPORTS_SCAN_BY_KEY()
inclusive_prod_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
......@@ -203,7 +203,7 @@ torch::Tensor inclusive_prod_cub_backward(
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
inclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
n_edges);
......@@ -237,7 +237,7 @@ torch::Tensor exclusive_prod_cub_forward(
}
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key(
indices.data_ptr<long>(),
indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(),
outputs.data_ptr<float>(),
n_edges);
......@@ -272,7 +272,7 @@ torch::Tensor exclusive_prod_cub_backward(
#if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges),
thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
n_edges);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment