[GraphBolt][CUDA] Use Current Stream and introduce `CopyScalar` class (#6796)

Commit a259767b (parent 4d8ae71b), authored Dec 25, 2023 by Muhammed Fatih BALIN, committed by GitHub on Dec 25, 2023.
Showing 4 changed files with 61 additions and 28 deletions (+61 −28):
- graphbolt/src/cuda/common.h (+39 −0)
- graphbolt/src/cuda/index_select_csc_impl.cu (+20 −26)
- graphbolt/src/cuda/index_select_impl.cu (+1 −1)
- graphbolt/src/cuda/sort_impl.cu (+1 −1)
graphbolt/src/cuda/common.h
@@ -7,6 +7,7 @@
 #ifndef GRAPHBOLT_CUDA_COMMON_H_
 #define GRAPHBOLT_CUDA_COMMON_H_

+#include <ATen/cuda/CUDAEvent.h>
 #include <c10/cuda/CUDACachingAllocator.h>
 #include <c10/cuda/CUDAException.h>
 #include <cuda_runtime.h>
@@ -90,6 +91,44 @@ inline bool is_zero<dim3>(dim3 size) {
     }                                                      \
   }

+/**
+ * @brief This class is designed to handle the copy operation of a single
+ * scalar_t item from a given CUDA device pointer. Later, if the object is cast
+ * into scalar_t, the value can be read.
+ *
+ * auto num_edges = cuda::CopyScalar(indptr.data_ptr<scalar_t>() +
+ *     indptr.size(0) - 1);
+ * // Perform many operations here, they will run as normal.
+ * // We finally need to read num_edges.
+ * auto indices = torch::empty(static_cast<scalar_t>(num_edges));
+ */
+template <typename scalar_t>
+struct CopyScalar {
+  CopyScalar(const scalar_t* device_ptr) : is_ready_(false) {
+    pinned_scalar_ = torch::empty(
+        sizeof(scalar_t),
+        c10::TensorOptions().dtype(torch::kBool).pinned_memory(true));
+    auto stream = GetCurrentStream();
+    CUDA_CALL(cudaMemcpyAsync(
+        reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr()), device_ptr,
+        sizeof(scalar_t), cudaMemcpyDeviceToHost, stream));
+    copy_event_.record(stream);
+  }
+
+  operator scalar_t() {
+    if (!is_ready_) {
+      copy_event_.synchronize();
+      is_ready_ = true;
+    }
+    return reinterpret_cast<scalar_t*>(pinned_scalar_.data_ptr())[0];
+  }
+
+ private:
+  torch::Tensor pinned_scalar_;
+  at::cuda::CUDAEvent copy_event_;
+  bool is_ready_;
+};
+
 // This includes all integer, float and boolean types.
 #define GRAPHBOLT_DISPATCH_CASE_ALL_TYPES(...) \
   AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)      \
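To make the intended lifecycle concrete, here is a minimal usage sketch expanding on the class's own doc comment. It is illustrative, not part of the commit; the include path, the `graphbolt::cuda` namespace qualification, and the `AllocateFromLastIndptr` helper are assumed names.

// Illustrative sketch only. Assumes cuda::CopyScalar from the hunk above
// and an indptr tensor already resident on a CUDA device.
#include <torch/torch.h>

#include "cuda/common.h"  // assumed include path for CopyScalar

template <typename scalar_t>
torch::Tensor AllocateFromLastIndptr(
    torch::Tensor indptr, torch::TensorOptions options) {
  // Constructor: enqueues cudaMemcpyAsync into pinned host memory on the
  // current stream and records a CUDAEvent. The host does not block here.
  auto num_edges = graphbolt::cuda::CopyScalar{
      indptr.data_ptr<scalar_t>() + indptr.size(0) - 1};

  // Arbitrary other GPU work can be queued at this point; it overlaps with
  // the device-to-host copy.

  // The first cast invokes operator scalar_t(): it synchronizes on the
  // recorded event once, then reads the value out of pinned memory.
  return torch::empty(static_cast<scalar_t>(num_edges), options);
}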
graphbolt/src/cuda/index_select_csc_impl.cu
@@ -145,8 +145,6 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
   auto output_indptr = torch::empty(
       num_nodes + 1, nodes_options.dtype(indptr_scalar_type));
-  // Actual and modified number of edges.
-  indptr_t edge_count, edge_count_aligned;
   auto output_indptr_aligned =
       allocator.AllocateStorage<indptr_t>(num_nodes + 1);
@@ -170,29 +168,28 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
         output_indptr_pair, PairSum{}, zero_value, num_nodes + 1, stream));
   }
-  // Copy the modified number of edges.
-  CUDA_CALL(cudaMemcpyAsync(
-      &edge_count_aligned, output_indptr_aligned.get() + num_nodes,
-      sizeof(edge_count_aligned), cudaMemcpyDeviceToHost, stream));
   // Copy the actual total number of edges.
-  CUDA_CALL(cudaMemcpyAsync(
-      &edge_count, output_indptr.data_ptr<indptr_t>() + num_nodes,
-      sizeof(edge_count), cudaMemcpyDeviceToHost, stream));
-  // synchronizes here, we can read edge_count and edge_count_aligned
-  CUDA_CALL(cudaStreamSynchronize(stream));
+  auto edge_count =
+      cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
+  // Copy the modified number of edges.
+  auto edge_count_aligned =
+      cuda::CopyScalar{output_indptr_aligned.get() + num_nodes};
   // Allocate output array with actual number of edges.
   torch::Tensor output_indices = torch::empty(
-      edge_count, nodes_options.dtype(indices.scalar_type()));
+      static_cast<indptr_t>(edge_count),
+      nodes_options.dtype(indices.scalar_type()));
   const dim3 block(BLOCK_SIZE);
-  const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);
+  const dim3 grid(
+      (static_cast<indptr_t>(edge_count_aligned) + BLOCK_SIZE - 1) /
+      BLOCK_SIZE);
   // Perform the actual copying, of the indices array into
   // output_indices in an aligned manner.
   CUDA_KERNEL_CALL(
-      _CopyIndicesAlignedKernel, grid, block, 0, stream, edge_count_aligned,
-      num_nodes, sliced_indptr, output_indptr.data_ptr<indptr_t>(),
-      output_indptr_aligned.get(),
+      _CopyIndicesAlignedKernel, grid, block, 0, stream,
+      static_cast<indptr_t>(edge_count_aligned), num_nodes, sliced_indptr,
+      output_indptr.data_ptr<indptr_t>(), output_indptr_aligned.get(),
       reinterpret_cast<indices_t*>(indices.data_ptr()),
       reinterpret_cast<indices_t*>(output_indices.data_ptr()), perm);
   return {output_indptr, output_indices};
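The effect of this hunk: the old code issued two `cudaMemcpyAsync` calls and then a blocking `cudaStreamSynchronize(stream)` before either count could be read; with `CopyScalar`, each count carries its own completion event, and the host blocks only at the first `static_cast`. A standalone plain-CUDA sketch of the same defer-until-read pattern (all names here are illustrative, not from the commit):

// Enqueue an async D2H copy into pinned memory, record an event, keep
// queuing work, and block on the event only when the value is needed.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  int* d_count;
  cudaMalloc(&d_count, sizeof(int));
  cudaMemsetAsync(d_count, 0, sizeof(int), stream);

  int* h_count;
  cudaMallocHost(&h_count, sizeof(int));  // pinned, so the copy is truly async

  cudaEvent_t copy_done;
  cudaEventCreate(&copy_done);

  // Enqueue the copy and mark its completion with an event.
  cudaMemcpyAsync(h_count, d_count, sizeof(int), cudaMemcpyDeviceToHost, stream);
  cudaEventRecord(copy_done, stream);

  // More kernels could be enqueued on `stream` here; nothing blocks yet.

  // Block only at the point of first use, like CopyScalar's operator scalar_t.
  cudaEventSynchronize(copy_done);
  printf("count = %d\n", *h_count);

  cudaEventDestroy(copy_done);
  cudaFreeHost(h_count);
  cudaFree(d_count);
  cudaStreamDestroy(stream);
  return 0;
}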
@@ -203,7 +200,7 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCImpl(
   // Sorting nodes so that accesses over PCI-e are more regular.
   const auto sorted_idx =
       Sort(nodes, cuda::NumberOfBits(indptr.size(0) - 1)).second;
-  auto stream = c10::cuda::getDefaultCUDAStream();
+  auto stream = cuda::GetCurrentStream();
   const int64_t num_nodes = nodes.size(0);
   auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
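`cuda::GetCurrentStream` itself is not shown in this diff. A plausible definition, assuming it simply forwards to PyTorch's per-thread current-stream lookup (hedged; the actual helper lives in common.h):

// Hedged sketch of GetCurrentStream (assumption: it wraps PyTorch's
// current-stream lookup rather than the device default stream).
#include <c10/cuda/CUDAStream.h>

namespace graphbolt {
namespace cuda {
inline c10::cuda::CUDAStream GetCurrentStream() {
  // Respects torch.cuda.stream(...) contexts on the Python side, unlike
  // c10::cuda::getDefaultCUDAStream(), which always returns the default.
  return c10::cuda::getCurrentCUDAStream();
}
}  // namespace cuda
}  // namespace graphbolt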
@@ -272,7 +269,7 @@ void IndexSelectCSCCopyIndices(
 std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
     torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
-  auto stream = c10::cuda::getDefaultCUDAStream();
+  auto stream = cuda::GetCurrentStream();
   const int64_t num_nodes = nodes.size(0);
   auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
   return AT_DISPATCH_INTEGRAL_TYPES(
@@ -299,15 +296,12 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSCImpl(
   }
   // Number of edges being copied.
-  indptr_t edge_count;
-  CUDA_CALL(cudaMemcpyAsync(
-      &edge_count, output_indptr.data_ptr<indptr_t>() + num_nodes,
-      sizeof(edge_count), cudaMemcpyDeviceToHost, stream));
-  // blocking read of edge_count
-  CUDA_CALL(cudaStreamSynchronize(stream));
+  auto edge_count =
+      cuda::CopyScalar{output_indptr.data_ptr<indptr_t>() + num_nodes};
   // Allocate output array of size number of copied edges.
   torch::Tensor output_indices = torch::empty(
-      edge_count, nodes.options().dtype(indices.scalar_type()));
+      static_cast<indptr_t>(edge_count),
+      nodes.options().dtype(indices.scalar_type()));
   GRAPHBOLT_DISPATCH_ELEMENT_SIZES(
       indices.element_size(), "IndexSelectCSCCopyIndices", ([&] {
         using indices_t = element_size_t;
graphbolt/src/cuda/index_select_impl.cu
@@ -124,7 +124,7 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
   const IdType* index_sorted_ptr = sorted_index.data_ptr<IdType>();
   const int64_t* permutation_ptr = permutation.data_ptr<int64_t>();
-  cudaStream_t stream = c10::cuda::getDefaultCUDAStream();
+  cudaStream_t stream = cuda::GetCurrentStream();
   if (aligned_feature_size == 1) {
     // Use a single thread to process each output row to avoid wasting threads.
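Why this one-line swap matters here and in the other files: kernels launched on `getCurrentCUDAStream()` join whatever stream the caller has made current, instead of being pinned to the device's default stream. A hedged caller-side sketch:

// Illustrative only: shows how a caller-side stream guard changes which
// stream these functions launch on after this commit.
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void LaunchOnSideStream() {
  // Take a stream from PyTorch's pool and make it current for this scope.
  auto side_stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(side_stream);
  // Any GraphBolt op invoked here now queries getCurrentCUDAStream() and
  // therefore enqueues its kernels and copies on side_stream, not the
  // default stream, so they can overlap with work on other streams.
}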
graphbolt/src/cuda/sort_impl.cu
@@ -24,7 +24,7 @@ std::pair<torch::Tensor, torch::Tensor> Sort(
   auto sorted_array = torch::empty_like(input);
   auto sorted_idx = torch::empty_like(original_idx);
   auto allocator = cuda::GetAllocator();
-  auto stream = c10::cuda::getDefaultCUDAStream();
+  auto stream = cuda::GetCurrentStream();
   AT_DISPATCH_INDEX_TYPES(
       input.scalar_type(), "SortImpl", ([&] {
         const auto input_keys = input.data_ptr<index_t>();