Commit 8f11ff9b authored by sangwzh's avatar sangwzh
Browse files

update device ptr getting when tensor is pinned

parent 5f463f9b
......@@ -160,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
_CopyIndicesAlignedKernel, grid, block, 0,
static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
reinterpret_cast<indices_t*>(indices.data_ptr()),
reinterpret_cast<indices_t*>(cuda::getTensorDevicePointer<indptr_t>(indices)),
coo_aligned_rows.data_ptr<coo_rows_t>(),
reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
}));
......@@ -180,8 +180,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
indices, num_nodes, in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
// indices, num_nodes, in_degree.data_ptr<indptr_t>(),
indices, num_nodes, cuda::getTensorDevicePointer<indptr_t>(in_degree),
// sliced_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
sorted_idx.data_ptr<int64_t>(), nodes.options(),
sliced_indptr.scalar_type(), output_size);
}));
......
......@@ -325,7 +325,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
.data_ptr<probs_t>();
}
const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr;
layer ? cuda::getTensorDevicePointer<indices_t>(indices) : nullptr;
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(num_edges.value() + BLOCK_SIZE - 1) /
......@@ -334,8 +334,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0,
num_edges.value(),
sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(),
// sliced_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
// sub_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
indices_ptr, random_seed, randoms.get(),
edge_id_segments.get());
......@@ -374,13 +376,13 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>()});
cuda::getTensorDevicePointer<indptr_t>(output_indptr),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
cuda::getTensorDevicePointer<indptr_t>(picked_eids)});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();
......@@ -404,7 +406,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
indices.data_ptr<indices_t>(),
cuda::getTensorDevicePointer<indices_t>(indices),
output_indices.data_ptr<indices_t>());
}));
......
......@@ -56,7 +56,8 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
THRUST_CALL(
for_each, iota, iota + num_nodes,
SliceFunc<indptr_t, nodes_t>{
nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
// nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<nodes_t>(nodes), cuda::getTensorDevicePointer<indptr_t>(indptr),
in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>()});
}));
......@@ -72,7 +73,7 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
using indptr_t = scalar_t;
CUB_CALL(
DeviceAdjacentDifference::SubtractLeftCopy,
indptr.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(indptr), in_degree.data_ptr<indptr_t>(),
num_nodes + 1, hipcub::Difference{});
}));
in_degree = in_degree.slice(0, 1);
......
......@@ -101,6 +101,17 @@ __device__ indices_t UpperBound(const indptr_t* A, indices_t n, indptr_t x) {
return l;
}
template<typename DType>
inline DType* getTensorDevicePointer(torch::Tensor inputTensor)
{
DType* ret = inputTensor.data_ptr<DType>();
if(inputTensor.is_pinned())
{
CUDA_CALL(hipHostGetDevicePointer((void**)&ret, (void*)ret, 0));
}
return ret;
}
} // namespace cuda
} // namespace graphbolt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment