Commit 8f11ff9b authored by sangwzh's avatar sangwzh
Browse files

Update device pointer retrieval to handle pinned (page-locked) tensors

parent 5f463f9b
...@@ -160,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices( ...@@ -160,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
_CopyIndicesAlignedKernel, grid, block, 0, _CopyIndicesAlignedKernel, grid, block, 0,
static_cast<indptr_t>(edge_count_aligned_), sliced_indptr, static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr, output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
reinterpret_cast<indices_t*>(indices.data_ptr()), reinterpret_cast<indices_t*>(cuda::getTensorDevicePointer<indptr_t>(indices)),
coo_aligned_rows.data_ptr<coo_rows_t>(), coo_aligned_rows.data_ptr<coo_rows_t>(),
reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm); reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
})); }));
...@@ -180,8 +180,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl( ...@@ -180,8 +180,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
return GRAPHBOLT_DISPATCH_ELEMENT_SIZES( return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] { indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>( return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
indices, num_nodes, in_degree.data_ptr<indptr_t>(), // indices, num_nodes, in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(), indices, num_nodes, cuda::getTensorDevicePointer<indptr_t>(in_degree),
// sliced_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
sorted_idx.data_ptr<int64_t>(), nodes.options(), sorted_idx.data_ptr<int64_t>(), nodes.options(),
sliced_indptr.scalar_type(), output_size); sliced_indptr.scalar_type(), output_size);
})); }));
......
...@@ -325,7 +325,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -325,7 +325,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
.data_ptr<probs_t>(); .data_ptr<probs_t>();
} }
const indices_t* indices_ptr = const indices_t* indices_ptr =
layer ? indices.data_ptr<indices_t>() : nullptr; layer ? cuda::getTensorDevicePointer<indices_t>(indices) : nullptr;
const dim3 block(BLOCK_SIZE); const dim3 block(BLOCK_SIZE);
const dim3 grid( const dim3 grid(
(num_edges.value() + BLOCK_SIZE - 1) / (num_edges.value() + BLOCK_SIZE - 1) /
...@@ -334,8 +334,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -334,8 +334,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
CUDA_KERNEL_CALL( CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0, _ComputeRandoms, grid, block, 0,
num_edges.value(), num_edges.value(),
sliced_indptr.data_ptr<indptr_t>(), // sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(), cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
// sub_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
coo_rows.data_ptr<indices_t>(), sliced_probs_ptr, coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
indices_ptr, random_seed, randoms.get(), indices_ptr, random_seed, randoms.get(),
edge_id_segments.get()); edge_id_segments.get());
...@@ -374,13 +376,13 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -374,13 +376,13 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto input_buffer_it = thrust::make_transform_iterator( auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{ iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_ptr<indptr_t>(), cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
sorted_edge_id_segments.get()}); sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator( auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{ iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<indptr_t>(), cuda::getTensorDevicePointer<indptr_t>(output_indptr),
sliced_indptr.data_ptr<indptr_t>(), cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
picked_eids.data_ptr<indptr_t>()}); cuda::getTensorDevicePointer<indptr_t>(picked_eids)});
constexpr int64_t max_copy_at_once = constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max(); std::numeric_limits<int32_t>::max();
...@@ -404,7 +406,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -404,7 +406,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
THRUST_CALL( THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(), gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0), picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
indices.data_ptr<indices_t>(), cuda::getTensorDevicePointer<indices_t>(indices),
output_indices.data_ptr<indices_t>()); output_indices.data_ptr<indices_t>());
})); }));
......
...@@ -56,7 +56,8 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr( ...@@ -56,7 +56,8 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
THRUST_CALL( THRUST_CALL(
for_each, iota, iota + num_nodes, for_each, iota, iota + num_nodes,
SliceFunc<indptr_t, nodes_t>{ SliceFunc<indptr_t, nodes_t>{
nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(), // nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<nodes_t>(nodes), cuda::getTensorDevicePointer<indptr_t>(indptr),
in_degree.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>()}); sliced_indptr.data_ptr<indptr_t>()});
})); }));
...@@ -72,7 +73,7 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr( ...@@ -72,7 +73,7 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
using indptr_t = scalar_t; using indptr_t = scalar_t;
CUB_CALL( CUB_CALL(
DeviceAdjacentDifference::SubtractLeftCopy, DeviceAdjacentDifference::SubtractLeftCopy,
indptr.data_ptr<indptr_t>(), in_degree.data_ptr<indptr_t>(), cuda::getTensorDevicePointer<indptr_t>(indptr), in_degree.data_ptr<indptr_t>(),
num_nodes + 1, hipcub::Difference{}); num_nodes + 1, hipcub::Difference{});
})); }));
in_degree = in_degree.slice(0, 1); in_degree = in_degree.slice(0, 1);
......
...@@ -101,6 +101,17 @@ __device__ indices_t UpperBound(const indptr_t* A, indices_t n, indptr_t x) { ...@@ -101,6 +101,17 @@ __device__ indices_t UpperBound(const indptr_t* A, indices_t n, indptr_t x) {
return l; return l;
} }
template<typename DType>
inline DType* getTensorDevicePointer(torch::Tensor inputTensor)
{
DType* ret = inputTensor.data_ptr<DType>();
if(inputTensor.is_pinned())
{
CUDA_CALL(hipHostGetDevicePointer((void**)&ret, (void*)ret, 0));
}
return ret;
}
} // namespace cuda } // namespace cuda
} // namespace graphbolt } // namespace graphbolt
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment