"git@developer.sourcefind.cn:OpenDAS/bitsandbytes.git" did not exist on "d870f9c5af13b6ef72a881f4a36dd14bd2305f0b"
Commit 86bb5e1a authored by sangwzh's avatar sangwzh
Browse files

fix error tests/examples/test_sampling_examples.py by using...

fix error tests/examples/test_sampling_examples.py by using getTensorDevicePointer to get the real device pointer.
parent 83d2fa9d
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <thrust/gather.h> #include <thrust/gather.h>
#include "./common.h" #include "./common.h"
#include "./utils.h"
namespace graphbolt { namespace graphbolt {
namespace ops { namespace ops {
...@@ -25,9 +26,9 @@ torch::Tensor Gather( ...@@ -25,9 +26,9 @@ torch::Tensor Gather(
AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] { AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] {
using output_t = scalar_t; using output_t = scalar_t;
THRUST_CALL( THRUST_CALL(
gather, index.data_ptr<index_t>(), gather, cuda::getTensorDevicePointer<index_t>(index),
index.data_ptr<index_t>() + index.size(0), cuda::getTensorDevicePointer<index_t>(index) + index.size(0),
input.data_ptr<input_t>(), output.data_ptr<output_t>()); cuda::getTensorDevicePointer<input_t>(input), output.data_ptr<output_t>());
})); }));
})); }));
})); }));
......
...@@ -509,10 +509,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -509,10 +509,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// Compute row and random number pairs. // Compute row and random number pairs.
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_ComputeRandomsNS, grid, block, 0, num_edges.value(), _ComputeRandomsNS, grid, block, 0, num_edges.value(),
sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(), // sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
output_indptr.data_ptr<indptr_t>(), sliced_indptr_ptr, cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0), // output_indptr.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>()); cuda::getTensorDevicePointer<indptr_t>(output_indptr),
cuda::getTensorDevicePointer<indices_t>(coo_rows), random_seed.get_seed(0),
cuda::getTensorDevicePointer<indptr_t>(picked_eids));
})); }));
picked_eids = picked_eids =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment