Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
8f11ff9b
Commit
8f11ff9b
authored
Oct 19, 2024
by
sangwzh
Browse files
update device ptr getting when tensor is pinned
parent
5f463f9b
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
13 deletions
+29
-13
graphbolt/src/cuda/index_select_csc_impl.hip
graphbolt/src/cuda/index_select_csc_impl.hip
+5
-3
graphbolt/src/cuda/neighbor_sampler.hip
graphbolt/src/cuda/neighbor_sampler.hip
+10
-8
graphbolt/src/cuda/sampling_utils.hip
graphbolt/src/cuda/sampling_utils.hip
+3
-2
graphbolt/src/cuda/utils.h
graphbolt/src/cuda/utils.h
+11
-0
No files found.
graphbolt/src/cuda/index_select_csc_impl.hip
View file @
8f11ff9b
...
...
@@ -160,7 +160,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
_CopyIndicesAlignedKernel, grid, block, 0,
static_cast<indptr_t>(edge_count_aligned_), sliced_indptr,
output_indptr.data_ptr<indptr_t>(), output_indptr_aligned_ptr,
reinterpret_cast<indices_t*>(
indices.data_ptr(
)),
reinterpret_cast<indices_t*>(
cuda::getTensorDevicePointer<indptr_t>(indices
)),
coo_aligned_rows.data_ptr<coo_rows_t>(),
reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
}));
...
...
@@ -180,8 +180,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
return GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
indices.element_size(), "UVAIndexSelectCSCCopyIndices", ([&] {
return UVAIndexSelectCSCCopyIndices<indptr_t, element_size_t>(
indices, num_nodes, in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>(),
// indices, num_nodes, in_degree.data_ptr<indptr_t>(),
indices, num_nodes, cuda::getTensorDevicePointer<indptr_t>(in_degree),
// sliced_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
sorted_idx.data_ptr<int64_t>(), nodes.options(),
sliced_indptr.scalar_type(), output_size);
}));
...
...
graphbolt/src/cuda/neighbor_sampler.hip
View file @
8f11ff9b
...
...
@@ -325,7 +325,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
.data_ptr<probs_t>();
}
const indices_t* indices_ptr =
layer ?
indices.data_pt
r<indices_t>() : nullptr;
layer ?
cuda::getTensorDevicePointe
r<indices_t>(
indices
) : nullptr;
const dim3 block(BLOCK_SIZE);
const dim3 grid(
(num_edges.value() + BLOCK_SIZE - 1) /
...
...
@@ -334,8 +334,10 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
CUDA_KERNEL_CALL(
_ComputeRandoms, grid, block, 0,
num_edges.value(),
sliced_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>(),
// sliced_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr),
// sub_indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<indptr_t>(sub_indptr),
coo_rows.data_ptr<indices_t>(), sliced_probs_ptr,
indices_ptr, random_seed, randoms.get(),
edge_id_segments.get());
...
...
@@ -374,13 +376,13 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
auto input_buffer_it = thrust::make_transform_iterator(
iota, IteratorFunc<indptr_t, edge_id_t>{
sub_indptr.data_pt
r<indptr_t>(),
cuda::getTensorDevicePointe
r<indptr_t>(
sub_indptr
),
sorted_edge_id_segments.get()});
auto output_buffer_it = thrust::make_transform_iterator(
iota, IteratorFuncAddOffset<indptr_t, indptr_t>{
output_indptr.data_ptr<
indptr
_t>(
),
sliced_indptr.data_ptr<indptr_t>(
),
picked_eids.data_pt
r<indptr_t>()});
cuda::getTensorDevicePointer<indptr_t>(output_
indptr),
cuda::getTensorDevicePointer<indptr_t>(sliced_indptr
),
cuda::getTensorDevicePointe
r<indptr_t>(
picked_eids
)});
constexpr int64_t max_copy_at_once =
std::numeric_limits<int32_t>::max();
...
...
@@ -404,7 +406,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
THRUST_CALL(
gather, picked_eids.data_ptr<indptr_t>(),
picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
indices.data_pt
r<indices_t>(),
cuda::getTensorDevicePointe
r<indices_t>(
indices
),
output_indices.data_ptr<indices_t>());
}));
...
...
graphbolt/src/cuda/sampling_utils.hip
View file @
8f11ff9b
...
...
@@ -56,7 +56,8 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
THRUST_CALL(
for_each, iota, iota + num_nodes,
SliceFunc<indptr_t, nodes_t>{
nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
// nodes.data_ptr<nodes_t>(), indptr.data_ptr<indptr_t>(),
cuda::getTensorDevicePointer<nodes_t>(nodes), cuda::getTensorDevicePointer<indptr_t>(indptr),
in_degree.data_ptr<indptr_t>(),
sliced_indptr.data_ptr<indptr_t>()});
}));
...
...
@@ -72,7 +73,7 @@ std::tuple<torch::Tensor, torch::Tensor> SliceCSCIndptr(
using indptr_t = scalar_t;
CUB_CALL(
DeviceAdjacentDifference::SubtractLeftCopy,
indptr.data_pt
r<indptr_t>(), in_degree.data_ptr<indptr_t>(),
cuda::getTensorDevicePointe
r<indptr_t>(
indptr
), in_degree.data_ptr<indptr_t>(),
num_nodes + 1, hipcub::Difference{});
}));
in_degree = in_degree.slice(0, 1);
...
...
graphbolt/src/cuda/utils.h
View file @
8f11ff9b
...
...
@@ -101,6 +101,17 @@ __device__ indices_t UpperBound(const indptr_t* A, indices_t n, indptr_t x) {
return
l
;
}
template
<
typename
DType
>
inline
DType
*
getTensorDevicePointer
(
torch
::
Tensor
inputTensor
)
{
DType
*
ret
=
inputTensor
.
data_ptr
<
DType
>
();
if
(
inputTensor
.
is_pinned
())
{
CUDA_CALL
(
hipHostGetDevicePointer
((
void
**
)
&
ret
,
(
void
*
)
ret
,
0
));
}
return
ret
;
}
}
// namespace cuda
}
// namespace graphbolt
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment