You need to sign in or sign up before continuing.
Unverified Commit 57ccfa14 authored by Skylar Wurster's avatar Skylar Wurster Committed by GitHub
Browse files

Use int64_t instead of long (#301)

parent 8a668783
Pipeline #2966 canceled with stages
...@@ -87,13 +87,13 @@ torch::Tensor inclusive_sum_cub( ...@@ -87,13 +87,13 @@ torch::Tensor inclusive_sum_cub(
#if CUB_SUPPORTS_SCAN_BY_KEY() #if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward) { if (backward) {
inclusive_sum_by_key( inclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges), thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges), thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges), thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
n_edges); n_edges);
} else { } else {
inclusive_sum_by_key( inclusive_sum_by_key(
indices.data_ptr<long>(), indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(), inputs.data_ptr<float>(),
outputs.data_ptr<float>(), outputs.data_ptr<float>(),
n_edges); n_edges);
...@@ -129,13 +129,13 @@ torch::Tensor exclusive_sum_cub( ...@@ -129,13 +129,13 @@ torch::Tensor exclusive_sum_cub(
#if CUB_SUPPORTS_SCAN_BY_KEY() #if CUB_SUPPORTS_SCAN_BY_KEY()
if (backward) { if (backward) {
exclusive_sum_by_key( exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges), thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges), thrust::make_reverse_iterator(inputs.data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges), thrust::make_reverse_iterator(outputs.data_ptr<float>() + n_edges),
n_edges); n_edges);
} else { } else {
exclusive_sum_by_key( exclusive_sum_by_key(
indices.data_ptr<long>(), indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(), inputs.data_ptr<float>(),
outputs.data_ptr<float>(), outputs.data_ptr<float>(),
n_edges); n_edges);
...@@ -169,7 +169,7 @@ torch::Tensor inclusive_prod_cub_forward( ...@@ -169,7 +169,7 @@ torch::Tensor inclusive_prod_cub_forward(
#if CUB_SUPPORTS_SCAN_BY_KEY() #if CUB_SUPPORTS_SCAN_BY_KEY()
inclusive_prod_by_key( inclusive_prod_by_key(
indices.data_ptr<long>(), indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(), inputs.data_ptr<float>(),
outputs.data_ptr<float>(), outputs.data_ptr<float>(),
n_edges); n_edges);
...@@ -203,7 +203,7 @@ torch::Tensor inclusive_prod_cub_backward( ...@@ -203,7 +203,7 @@ torch::Tensor inclusive_prod_cub_backward(
} }
#if CUB_SUPPORTS_SCAN_BY_KEY() #if CUB_SUPPORTS_SCAN_BY_KEY()
inclusive_sum_by_key( inclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges), thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges), thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges), thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
n_edges); n_edges);
...@@ -237,7 +237,7 @@ torch::Tensor exclusive_prod_cub_forward( ...@@ -237,7 +237,7 @@ torch::Tensor exclusive_prod_cub_forward(
} }
#if CUB_SUPPORTS_SCAN_BY_KEY() #if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_prod_by_key( exclusive_prod_by_key(
indices.data_ptr<long>(), indices.data_ptr<int64_t>(),
inputs.data_ptr<float>(), inputs.data_ptr<float>(),
outputs.data_ptr<float>(), outputs.data_ptr<float>(),
n_edges); n_edges);
...@@ -272,7 +272,7 @@ torch::Tensor exclusive_prod_cub_backward( ...@@ -272,7 +272,7 @@ torch::Tensor exclusive_prod_cub_backward(
#if CUB_SUPPORTS_SCAN_BY_KEY() #if CUB_SUPPORTS_SCAN_BY_KEY()
exclusive_sum_by_key( exclusive_sum_by_key(
thrust::make_reverse_iterator(indices.data_ptr<long>() + n_edges), thrust::make_reverse_iterator(indices.data_ptr<int64_t>() + n_edges),
thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges), thrust::make_reverse_iterator((grad_outputs * outputs).data_ptr<float>() + n_edges),
thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges), thrust::make_reverse_iterator(grad_inputs.data_ptr<float>() + n_edges),
n_edges); n_edges);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment