Commit 910cec0c authored by sangwzh
Browse files

Update how device pointers are obtained when using UVA

parent f7b4c93d
...@@ -368,8 +368,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors( ...@@ -368,8 +368,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
CUB_CALL( CUB_CALL(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(), DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0), sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, sampled_segment_end_it, num_rows, sub_indptr.data_ptr<indptr_t>(),
sampled_segment_end_it); sub_indptr.data_ptr<indptr_t>()+1);
} }
auto input_buffer_it = thrust::make_transform_iterator( auto input_buffer_it = thrust::make_transform_iterator(
......
...@@ -48,7 +48,6 @@ IdArray CumSum(IdArray array, bool prepend_zero) { ...@@ -48,7 +48,6 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
workspace, workspace_size, in_d, out_d, len, stream)); workspace, workspace_size, in_d, out_d, len, stream));
device->FreeWorkspace(array->ctx, workspace); device->FreeWorkspace(array->ctx, workspace);
std::cout << "cuda ret : " << ret << std::endl;
return ret; return ret;
} }
......
...@@ -529,13 +529,19 @@ __host__ std::pair<COOMatrix, FloatArray> CSRLaborSampling( ...@@ -529,13 +529,19 @@ __host__ std::pair<COOMatrix, FloatArray> CSRLaborSampling(
auto device = runtime::DeviceAPI::Get(ctx); auto device = runtime::DeviceAPI::Get(ctx);
const IdType num_rows = rows_arr->shape[0]; const IdType num_rows = rows_arr->shape[0];
IdType* const rows = rows_arr.Ptr<IdType>(); // IdType* const rows = rows_arr.Ptr<IdType>();
IdType* const nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdType>(); IdType* const rows = static_cast<IdType*>(GetDevicePointer(rows_arr));
FloatType* const A = prob_arr.Ptr<FloatType>(); // IdType* const nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdType>();
IdType* const nids = IsNullArray(NIDs) ? nullptr : static_cast<IdType*>(GetDevicePointer(NIDs));
IdType* const indptr_ = mat.indptr.Ptr<IdType>(); // FloatType* const A = prob_arr.Ptr<FloatType>();
IdType* const indices_ = mat.indices.Ptr<IdType>(); FloatType* const A = static_cast<FloatType*>(GetDevicePointer(prob_arr));;
IdType* const data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
// IdType* const indptr_ = mat.indptr.Ptr<IdType>();
IdType* const indptr_ = static_cast<IdType*>(GetDevicePointer(mat.indptr));
// IdType* const indices_ = mat.indices.Ptr<IdType>();
IdType* const indices_ = static_cast<IdType*>(GetDevicePointer(mat.indices));
// IdType* const data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
IdType* const data = CSRHasData(mat) ? static_cast<IdType*>(GetDevicePointer(mat.data)) : nullptr;
// Read indptr only once in case it is pinned and access is slow. // Read indptr only once in case it is pinned and access is slow.
auto indptr = allocator.alloc_unique<IdType>(num_rows); auto indptr = allocator.alloc_unique<IdType>(num_rows);
......
...@@ -402,11 +402,13 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) { ...@@ -402,11 +402,13 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
const auto idtype = coo.row->dtype; const auto idtype = coo.row->dtype;
const auto ctx = coo.row->ctx; const auto ctx = coo.row->ctx;
const int64_t nnz = coo.row->shape[0]; const int64_t nnz = coo.row->shape[0];
const IdType* row = coo.row.Ptr<IdType>(); // const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>(); const IdType* row = static_cast<IdType*>(GetDevicePointer(coo.row));
// const IdType* col = coo.col.Ptr<IdType>();
const IdType* col = static_cast<IdType*>(GetDevicePointer(coo.col));
const IdArray& eid = const IdArray& eid =
COOHasData(coo) ? coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx); COOHasData(coo) ? coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx);
const IdType* data = coo.data.Ptr<IdType>(); const IdType* data = static_cast<IdType*>(GetDevicePointer(coo.data));
IdArray new_row = IdArray::Empty({nnz}, idtype, ctx); IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_col = IdArray::Empty({nnz}, idtype, ctx); IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_eid = IdArray::Empty({nnz}, idtype, ctx); IdArray new_eid = IdArray::Empty({nnz}, idtype, ctx);
...@@ -441,7 +443,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) { ...@@ -441,7 +443,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
template <DGLDeviceType XPU, typename IdType, typename DType> template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix _COORemoveIf( COOMatrix _COORemoveIf(
const COOMatrix& coo, const NDArray& values, DType criteria) { const COOMatrix& coo, const NDArray& values, DType criteria) {
const DType* val = values.Ptr<DType>(); // const DType* val = values.Ptr<DType>();
const DType* val = static_cast<DType*>(GetDevicePointer(values));
auto maskgen = [val, criteria]( auto maskgen = [val, criteria](
int nb, int nt, hipStream_t stream, int64_t nnz, int nb, int nt, hipStream_t stream, int64_t nnz,
const IdType* data, int8_t* flags) { const IdType* data, int8_t* flags) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment