Commit 92a70644 authored by sangwz's avatar sangwz
Browse files

宏定义更新,sortKerys更新

parent 83ea9a8d
...@@ -210,6 +210,16 @@ struct SegmentEndFunc { ...@@ -210,6 +210,16 @@ struct SegmentEndFunc {
} }
}; };
template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc_hip {
indptr_t* indptr;
in_degree_iterator_t in_degree;
indptr_t* segment_end; // 存储段结束位置的设备指针
__host__ __device__ void operator()(int64_t i) {
segment_end[i] = indptr[i] + in_degree[i]; // 直接写入设备指针
}
};
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> seeds, torch::optional<torch::Tensor> seeds,
...@@ -438,6 +448,17 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -438,6 +448,17 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// Ensuring sort result still ends up in // Ensuring sort result still ends up in
// sorted_edge_id_segments // sorted_edge_id_segments
std::swap(edge_id_segments, sorted_edge_id_segments); std::swap(edge_id_segments, sorted_edge_id_segments);
// ************* //
// 分配 segment_end 内存
thrust::device_vector<indptr_t> segment_end(num_rows);
auto segment_end_ptr = segment_end.data().get();
// 计算段结束位置
thrust::for_each(
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(num_rows),
SegmentEndFunc_hip<indptr_t, decltype(sampled_degree)>{
sub_indptr.data_ptr<indptr_t>(), sampled_degree, segment_end_ptr});
// ***************** //
auto sampled_segment_end_it = thrust::make_transform_iterator( auto sampled_segment_end_it = thrust::make_transform_iterator(
iota, iota,
SegmentEndFunc<indptr_t, decltype(sampled_degree)>{ SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
...@@ -446,7 +467,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -446,7 +467,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(), DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0), sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sub_indptr.data_ptr<indptr_t>(), num_rows, sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>()+1); segment_end_ptr);
// sub_indptr.data_ptr<indptr_t>()+1); // sub_indptr.data_ptr<indptr_t>()+1);
} }
......
...@@ -24,7 +24,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) { ...@@ -24,7 +24,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
template <> template <>
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) { CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
#if CUDART_VERSION < 12000 #if DTKRT_VERSION < 12000
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal(); auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA(); hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
// allocate cusparse handle if needed // allocate cusparse handle if needed
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#include <cmath> #include <cmath>
#ifdef __HIPCC__ #ifdef __HIP_DEVICE_COMPILE__
#include <hiprand/hiprand_kernel.h> #include <hiprand/hiprand_kernel.h>
#else #else
#include <random> #include <random>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment