"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "969e2af866045417dccbc3980422c80d9736d970"
Commit 86bb5e1a authored by sangwzh's avatar sangwzh
Browse files

fix error tests/examples/test_sampling_examples.py by using...

fix error tests/examples/test_sampling_examples.py by using getTensorDevicePointer to get the real device pointer.
parent 83d2fa9d
......@@ -8,6 +8,7 @@
#include <thrust/gather.h>
#include "./common.h"
#include "./utils.h"
namespace graphbolt {
namespace ops {
......@@ -25,9 +26,9 @@ torch::Tensor Gather(
AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] {
using output_t = scalar_t;
THRUST_CALL(
gather, index.data_ptr<index_t>(),
index.data_ptr<index_t>() + index.size(0),
input.data_ptr<input_t>(), output.data_ptr<output_t>());
gather, cuda::getTensorDevicePointer<index_t>(index),
cuda::getTensorDevicePointer<index_t>(index) + index.size(0),
cuda::getTensorDevicePointer<input_t>(input), output.data_ptr<output_t>());
}));
}));
}));
......
......@@ -509,10 +509,12 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// Compute row and random number pairs.
CUDA_KERNEL_CALL(
_ComputeRandomsNS, grid, block, 0, num_edges.value(),
sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
output_indptr.data_ptr<indptr_t>(),
coo_rows.data_ptr<indices_t>(), random_seed.get_seed(0),
picked_eids.data_ptr<indptr_t>());
// sliced_indptr_ptr, sub_indptr.data_ptr<indptr_t>(),
sliced_indptr_ptr, cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
// output_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(output_indptr),
cuda::getTensorDevicePointer<indices_t>(coo_rows), random_seed.get_seed(0),
cuda::getTensorDevicePointer<indptr_t>(picked_eids));
}));
picked_eids =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment