Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
d4a6f8a0
Unverified
Commit
d4a6f8a0
authored
Apr 06, 2024
by
Muhammed Fatih BALIN
Committed by
GitHub
Apr 06, 2024
Browse files
[GraphBolt][CUDA] Refactor `Gather` operation. (#7269)
parent
62aca92d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
38 deletions
+53
-38
graphbolt/include/graphbolt/cuda_ops.h
graphbolt/include/graphbolt/cuda_ops.h
+14
-0
graphbolt/src/cuda/gather.cu
graphbolt/src/cuda/gather.cu
+36
-0
graphbolt/src/cuda/neighbor_sampler.cu
graphbolt/src/cuda/neighbor_sampler.cu
+3
-38
No files found.
graphbolt/include/graphbolt/cuda_ops.h
View file @
d4a6f8a0
...
...
@@ -149,6 +149,20 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> SliceCSCIndptrHetero(
*/
torch
::
Tensor
ExclusiveCumSum
(
torch
::
Tensor
input
);
/**
 * @brief Gathers elements of `input` at positions given by `index`, optionally
 * casting the result to a requested scalar type.
 *
 * @param input The input tensor.
 * @param index The index tensor.
 * @param dtype The optional output dtype. If not given, inferred from the input
 * tensor.
 *
 * @return The result of the input.gather(0, index).to(dtype) operation.
 */
torch::Tensor Gather(
    torch::Tensor input, torch::Tensor index,
    torch::optional<torch::ScalarType> dtype = torch::nullopt);
/**
* @brief Select rows from input tensor according to index tensor.
*
...
...
graphbolt/src/cuda/gather.cu
0 → 100644
View file @
d4a6f8a0
/**
 * Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
 * @file cuda/gather.cu
 * @brief Gather operators implementation on CUDA.
 */
#include <thrust/gather.h>

#include "./common.h"

namespace graphbolt {
namespace ops {

/**
 * @brief Gathers input values at the given indices into a freshly allocated
 * tensor, optionally casting to the requested scalar type.
 *
 * Equivalent to input.gather(0, index).to(dtype); reimplemented with
 * thrust::gather so it also works when tensors live on pinned memory.
 */
torch::Tensor Gather(
    torch::Tensor input, torch::Tensor index,
    torch::optional<torch::ScalarType> dtype) {
  // When no output dtype was requested, keep the input's dtype.
  if (!dtype.has_value()) dtype = input.scalar_type();
  // Output has the index tensor's shape, allocated with the chosen dtype.
  auto result = torch::empty(index.sizes(), index.options().dtype(*dtype));
  AT_DISPATCH_INDEX_TYPES(
      index.scalar_type(), "GatherIndexType", ([&] {
        AT_DISPATCH_INTEGRAL_TYPES(
            input.scalar_type(), "GatherInputType", ([&] {
              using input_t = scalar_t;
              AT_DISPATCH_INTEGRAL_TYPES(*dtype, "GatherOutputType", ([&] {
                using output_t = scalar_t;
                // thrust::gather copies input[index[i]] into result[i];
                // the input_t -> output_t conversion happens element-wise
                // during the copy.
                // NOTE(review): iterates index.size(0) elements — assumes a
                // 1-D index tensor; confirm callers never pass multi-dim
                // indices.
                THRUST_CALL(
                    gather, index.data_ptr<index_t>(),
                    index.data_ptr<index_t>() + index.size(0),
                    input.data_ptr<input_t>(),
                    result.data_ptr<output_t>());
              }));
            }));
      }));
  return result;
}

}  // namespace ops
}  // namespace graphbolt
graphbolt/src/cuda/neighbor_sampler.cu
View file @
d4a6f8a0
...
...
@@ -500,43 +500,8 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
}
}
output_indices
=
torch
::
empty
(
picked_eids
.
size
(
0
),
picked_eids
.
options
().
dtype
(
indices
.
scalar_type
()));
// Compute: output_indices = indices.gather(0, picked_eids);
AT_DISPATCH_INDEX_TYPES
(
indices
.
scalar_type
(),
"SampleNeighborsOutputIndices"
,
([
&
]
{
using
indices_t
=
index_t
;
THRUST_CALL
(
gather
,
picked_eids
.
data_ptr
<
indptr_t
>
(),
picked_eids
.
data_ptr
<
indptr_t
>
()
+
picked_eids
.
size
(
0
),
indices
.
data_ptr
<
indices_t
>
(),
output_indices
.
data_ptr
<
indices_t
>
());
}));
}));
auto
index_type_per_edge_for_sampled_edges
=
[
&
]
{
// The code behaves same as:
// output_type_per_edge = type_per_edge.gather(0, picked_eids);
// The reimplementation is required due to the torch equivalent does
// not work when type_per_edge is on pinned memory
auto
types
=
type_per_edge
.
value
();
auto
output
=
torch
::
empty
(
picked_eids
.
size
(
0
),
picked_eids
.
options
().
dtype
(
types
.
scalar_type
()));
AT_DISPATCH_INDEX_TYPES
(
indptr
.
scalar_type
(),
"SampleNeighborsIndptr"
,
([
&
]
{
using
indptr_t
=
index_t
;
AT_DISPATCH_INTEGRAL_TYPES
(
types
.
scalar_type
(),
"SampleNeighborsOutputTypePerEdge"
,
([
&
]
{
THRUST_CALL
(
gather
,
picked_eids
.
data_ptr
<
indptr_t
>
(),
picked_eids
.
data_ptr
<
indptr_t
>
()
+
picked_eids
.
size
(
0
),
types
.
data_ptr
<
scalar_t
>
(),
output
.
data_ptr
<
scalar_t
>
());
}));
output_indices
=
Gather
(
indices
,
picked_eids
);
}));
return
output
;
};
torch
::
optional
<
torch
::
Tensor
>
output_type_per_edge
;
torch
::
optional
<
torch
::
Tensor
>
edge_offsets
;
...
...
@@ -547,7 +512,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
// type_per_edge of sampled edges and determine the offsets of different
// sampled etypes and convert to fused hetero indptr representation.
if
(
fanouts
.
size
()
==
1
)
{
output_type_per_edge
=
index_
type_per_edge
_for_sampled_edges
(
);
output_type_per_edge
=
Gather
(
*
type_per_edge
,
picked_eids
);
torch
::
Tensor
output_in_degree
,
sliced_output_indptr
;
sliced_output_indptr
=
output_indptr
.
slice
(
0
,
0
,
output_indptr
.
size
(
0
)
-
1
);
...
...
@@ -652,7 +617,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
output_indptr
=
output_indptr
.
slice
(
0
,
0
,
output_indptr
.
size
(
0
),
fanouts
.
size
());
if
(
type_per_edge
)
output_type_per_edge
=
index_
type_per_edge
_for_sampled_edges
(
);
output_type_per_edge
=
Gather
(
*
type_per_edge
,
picked_eids
);
}
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment