Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
92a70644
"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "d32b88fe7893d4fd05dc5e3400b256c1c8fc37d4"
Commit
92a70644
authored
Mar 01, 2025
by
sangwz
Browse files
宏定义更新,sortKerys更新
parent
83ea9a8d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
3 deletions
+24
-3
graphbolt/src/cuda/neighbor_sampler.hip
graphbolt/src/cuda/neighbor_sampler.hip
+22
-1
src/array/cuda/csr_transpose.cc
src/array/cuda/csr_transpose.cc
+1
-1
src/random/continuous_seed.h
src/random/continuous_seed.h
+1
-1
No files found.
graphbolt/src/cuda/neighbor_sampler.hip
View file @
92a70644
...
...
@@ -210,6 +210,16 @@ struct SegmentEndFunc {
}
};
/// Functor that writes the end offset of a segment into an output array.
///
/// Invoked with an index `i`, it stores `indptr[i] + in_degree[i]` into
/// `segment_end[i]` and returns nothing.
///
/// @tparam indptr_t Element type of the offset arrays.
/// @tparam in_degree_iterator_t Type indexable with `[i]` that yields the
///     value added to `indptr[i]`.
template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc_hip {
  indptr_t* indptr;                // Offsets read at index i.
  in_degree_iterator_t in_degree;  // Per-index amount added to the offset.
  indptr_t* segment_end;  // Device pointer that stores the segment end positions.

  /// Computes the end position for index `i` and writes it directly through
  /// the `segment_end` pointer.
  __host__ __device__ void operator()(int64_t i) {
    const auto segment_begin = indptr[i];
    segment_end[i] = segment_begin + in_degree[i];
  }
};
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> seeds,
...
...
@@ -438,6 +448,17 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// Ensuring sort result still ends up in
// sorted_edge_id_segments
std::swap(edge_id_segments, sorted_edge_id_segments);
// ************* //
// 分配 segment_end 内存
thrust::device_vector<indptr_t> segment_end(num_rows);
auto segment_end_ptr = segment_end.data().get();
// 计算段结束位置
thrust::for_each(
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(num_rows),
SegmentEndFunc_hip<indptr_t, decltype(sampled_degree)>{
sub_indptr.data_ptr<indptr_t>(), sampled_degree, segment_end_ptr});
// ***************** //
auto sampled_segment_end_it = thrust::make_transform_iterator(
iota,
SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
...
...
@@ -446,7 +467,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>()+1);
segment_end_ptr);
// sub_indptr.data_ptr<indptr_t>()+1);
}
...
...
src/array/cuda/csr_transpose.cc
View file @
92a70644
...
...
@@ -24,7 +24,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
template <>
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
#if CUDART_VERSION < 12000
#if DTKRT_VERSION < 12000
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
// allocate cusparse handle if needed
...
...
src/random/continuous_seed.h
View file @
92a70644
...
...
@@ -25,7 +25,7 @@
#include <cmath>
#ifdef __HIPCC__
#ifdef __HIP_DEVICE_COMPILE__
#include <hiprand/hiprand_kernel.h>
#else
#include <random>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment