"src/git@developer.sourcefind.cn:one/TransferBench.git" did not exist on "d32b88fe7893d4fd05dc5e3400b256c1c8fc37d4"
Commit 92a70644 authored by sangwz's avatar sangwz
Browse files

宏定义更新,sortKerys更新

parent 83ea9a8d
......@@ -210,6 +210,16 @@ struct SegmentEndFunc {
}
};
template <typename indptr_t, typename in_degree_iterator_t>
struct SegmentEndFunc_hip {
indptr_t* indptr;
in_degree_iterator_t in_degree;
indptr_t* segment_end; // 存储段结束位置的设备指针
__host__ __device__ void operator()(int64_t i) {
segment_end[i] = indptr[i] + in_degree[i]; // 直接写入设备指针
}
};
c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
torch::Tensor indptr, torch::Tensor indices,
torch::optional<torch::Tensor> seeds,
......@@ -438,6 +448,17 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// Ensuring sort result still ends up in
// sorted_edge_id_segments
std::swap(edge_id_segments, sorted_edge_id_segments);
// ************* //
// 分配 segment_end 内存
thrust::device_vector<indptr_t> segment_end(num_rows);
auto segment_end_ptr = segment_end.data().get();
// 计算段结束位置
thrust::for_each(
thrust::make_counting_iterator<int64_t>(0),
thrust::make_counting_iterator<int64_t>(num_rows),
SegmentEndFunc_hip<indptr_t, decltype(sampled_degree)>{
sub_indptr.data_ptr<indptr_t>(), sampled_degree, segment_end_ptr});
// ***************** //
auto sampled_segment_end_it = thrust::make_transform_iterator(
iota,
SegmentEndFunc<indptr_t, decltype(sampled_degree)>{
......@@ -446,7 +467,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sub_indptr.data_ptr<indptr_t>(),
sub_indptr.data_ptr<indptr_t>()+1);
segment_end_ptr);
// sub_indptr.data_ptr<indptr_t>()+1);
}
......
......@@ -24,7 +24,7 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
template <>
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
#if CUDART_VERSION < 12000
#if DTKRT_VERSION < 12000
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
hipStream_t stream = runtime::getCurrentHIPStreamMasqueradingAsCUDA();
// allocate cusparse handle if needed
......
......@@ -25,7 +25,7 @@
#include <cmath>
#ifdef __HIPCC__
#ifdef __HIP_DEVICE_COMPILE__
#include <hiprand/hiprand_kernel.h>
#else
#include <random>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment