Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
910cec0c
Commit
910cec0c
authored
Oct 15, 2024
by
sangwzh
Browse files
update device pointer getting while using UVA
parent
f7b4c93d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
22 additions
and
14 deletions
+22
-14
graphbolt/src/cuda/neighbor_sampler.hip
graphbolt/src/cuda/neighbor_sampler.hip
+2
-2
src/array/cuda/array_cumsum.hip
src/array/cuda/array_cumsum.hip
+0
-1
src/array/cuda/labor_sampling.hip
src/array/cuda/labor_sampling.hip
+13
-7
src/array/cuda/rowwise_sampling_prob.hip
src/array/cuda/rowwise_sampling_prob.hip
+7
-4
No files found.
graphbolt/src/cuda/neighbor_sampler.hip
View file @
910cec0c
...
@@ -368,8 +368,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
...
@@ -368,8 +368,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
CUB_CALL(
CUB_CALL(
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
DeviceSegmentedSort::SortKeys, edge_id_segments.get(),
sorted_edge_id_segments.get(), picked_eids.size(0),
sorted_edge_id_segments.get(), picked_eids.size(0),
num_rows, s
ampled_segment_end_it
,
num_rows, s
ub_indptr.data_ptr<indptr_t>()
,
s
ampled_segment_end_it
);
s
ub_indptr.data_ptr<indptr_t>()+1
);
}
}
auto input_buffer_it = thrust::make_transform_iterator(
auto input_buffer_it = thrust::make_transform_iterator(
...
...
src/array/cuda/array_cumsum.hip
View file @
910cec0c
...
@@ -48,7 +48,6 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
...
@@ -48,7 +48,6 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
workspace, workspace_size, in_d, out_d, len, stream));
workspace, workspace_size, in_d, out_d, len, stream));
device->FreeWorkspace(array->ctx, workspace);
device->FreeWorkspace(array->ctx, workspace);
std::cout << "cuda ret : " << ret << std::endl;
return ret;
return ret;
}
}
...
...
src/array/cuda/labor_sampling.hip
View file @
910cec0c
...
@@ -529,13 +529,19 @@ __host__ std::pair<COOMatrix, FloatArray> CSRLaborSampling(
...
@@ -529,13 +529,19 @@ __host__ std::pair<COOMatrix, FloatArray> CSRLaborSampling(
auto device = runtime::DeviceAPI::Get(ctx);
auto device = runtime::DeviceAPI::Get(ctx);
const IdType num_rows = rows_arr->shape[0];
const IdType num_rows = rows_arr->shape[0];
IdType* const rows = rows_arr.Ptr<IdType>();
// IdType* const rows = rows_arr.Ptr<IdType>();
IdType* const nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdType>();
IdType* const rows = static_cast<IdType*>(GetDevicePointer(rows_arr));
FloatType* const A = prob_arr.Ptr<FloatType>();
// IdType* const nids = IsNullArray(NIDs) ? nullptr : NIDs.Ptr<IdType>();
IdType* const nids = IsNullArray(NIDs) ? nullptr : static_cast<IdType*>(GetDevicePointer(NIDs));
IdType* const indptr_ = mat.indptr.Ptr<IdType>();
// FloatType* const A = prob_arr.Ptr<FloatType>();
IdType* const indices_ = mat.indices.Ptr<IdType>();
FloatType* const A = static_cast<FloatType*>(GetDevicePointer(prob_arr));;
IdType* const data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
// IdType* const indptr_ = mat.indptr.Ptr<IdType>();
IdType* const indptr_ = static_cast<IdType*>(GetDevicePointer(mat.indptr));
// IdType* const indices_ = mat.indices.Ptr<IdType>();
IdType* const indices_ = static_cast<IdType*>(GetDevicePointer(mat.indices));
// IdType* const data = CSRHasData(mat) ? mat.data.Ptr<IdType>() : nullptr;
IdType* const data = CSRHasData(mat) ? static_cast<IdType*>(GetDevicePointer(mat.data)) : nullptr;
// Read indptr only once in case it is pinned and access is slow.
// Read indptr only once in case it is pinned and access is slow.
auto indptr = allocator.alloc_unique<IdType>(num_rows);
auto indptr = allocator.alloc_unique<IdType>(num_rows);
...
...
src/array/cuda/rowwise_sampling_prob.hip
View file @
910cec0c
...
@@ -402,11 +402,13 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
...
@@ -402,11 +402,13 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
const auto idtype = coo.row->dtype;
const auto idtype = coo.row->dtype;
const auto ctx = coo.row->ctx;
const auto ctx = coo.row->ctx;
const int64_t nnz = coo.row->shape[0];
const int64_t nnz = coo.row->shape[0];
const IdType* row = coo.row.Ptr<IdType>();
// const IdType* row = coo.row.Ptr<IdType>();
const IdType* col = coo.col.Ptr<IdType>();
const IdType* row = static_cast<IdType*>(GetDevicePointer(coo.row));
// const IdType* col = coo.col.Ptr<IdType>();
const IdType* col = static_cast<IdType*>(GetDevicePointer(coo.col));
const IdArray& eid =
const IdArray& eid =
COOHasData(coo) ? coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx);
COOHasData(coo) ? coo.data : Range(0, nnz, sizeof(IdType) * 8, ctx);
const IdType* data =
coo.data.Ptr
<IdType>();
const IdType* data =
static_cast
<IdType
*
>(
GetDevicePointer(coo.data)
);
IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_row = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_col = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_eid = IdArray::Empty({nnz}, idtype, ctx);
IdArray new_eid = IdArray::Empty({nnz}, idtype, ctx);
...
@@ -441,7 +443,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
...
@@ -441,7 +443,8 @@ COOMatrix COOGeneralRemoveIf(const COOMatrix& coo, MaskGen maskgen) {
template <DGLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix _COORemoveIf(
COOMatrix _COORemoveIf(
const COOMatrix& coo, const NDArray& values, DType criteria) {
const COOMatrix& coo, const NDArray& values, DType criteria) {
const DType* val = values.Ptr<DType>();
// const DType* val = values.Ptr<DType>();
const DType* val = static_cast<DType*>(GetDevicePointer(values));
auto maskgen = [val, criteria](
auto maskgen = [val, criteria](
int nb, int nt, hipStream_t stream, int64_t nnz,
int nb, int nt, hipStream_t stream, int64_t nnz,
const IdType* data, int8_t* flags) {
const IdType* data, int8_t* flags) {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment