OpenDAS / dgl

Commit d3483fe1 (unverified), authored Apr 03, 2024 by Muhammed Fatih BALIN, committed via GitHub on Apr 03, 2024

[GraphBolt][CUDA] Optimize hetero sampling. (#7223)

parent 2a00cd3d

Showing 11 changed files with 372 additions and 153 deletions (+372 −153)
graphbolt/include/graphbolt/cuda_sampling_ops.h                +17   −3
graphbolt/include/graphbolt/fused_csc_sampling_graph.h          +7   −4
graphbolt/include/graphbolt/fused_sampled_subgraph.h           +30   −9
graphbolt/src/cuda/neighbor_sampler.cu                        +155  −34
graphbolt/src/fused_csc_sampling_graph.cc                      +15  −13
graphbolt/src/python_binding.cc                                 +2   −1
graphbolt/src/utils.h                                          +11   −0
python/dgl/distributed/graph_services.py                        +1   −1
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py         +112  −52
python/dgl/graphbolt/impl/neighbor_sampler.py                  +16  −30
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py    +6   −6
graphbolt/include/graphbolt/cuda_sampling_ops.h  (+17 −3)

@@ -19,8 +19,10 @@ namespace ops {
  *
  * @param indptr Index pointer array of the CSC.
  * @param indices Indices array of the CSC.
- * @param nodes The nodes from which to sample neighbors. If not provided,
+ * @param seeds The nodes from which to sample neighbors. If not provided,
  * assumed to be equal to torch.arange(indptr.size(0) - 1).
+ * @param seed_offsets The offsets of the given seeds,
+ * seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.
  * @param fanouts The number of edges to be sampled for each node with or
  * without considering edge types.
  * - When the length is 1, it indicates that the fanout applies to all

@@ -45,6 +47,12 @@ namespace ops {
  * @param probs_or_mask An optional tensor with (unnormalized) probabilities
  * corresponding to each neighboring edge of a node. It must be
  * a 1D tensor, with the number of elements equaling the total number of edges.
+ * @param node_type_to_id A dictionary mapping node type names to type IDs. The
+ * length of it is equal to the number of node types. The key is the node type
+ * name, and the value is the corresponding type ID.
+ * @param edge_type_to_id A dictionary mapping edge type names to type IDs. The
+ * length of it is equal to the number of edge types. The key is the edge type
+ * name, and the value is the corresponding type ID.
  * @param random_seed The random seed for the sampler for layer=True.
  * @param seed2_contribution The contribution of the second random seed, [0, 1)
  * for layer=True.

@@ -54,10 +62,16 @@ namespace ops {
  */
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
     torch::Tensor indptr, torch::Tensor indices,
-    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
-    bool replace, bool layer, bool return_eids,
+    torch::optional<torch::Tensor> seeds,
+    torch::optional<std::vector<int64_t>> seed_offsets,
+    const std::vector<int64_t>& fanouts, bool replace, bool layer,
+    bool return_eids,
     torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
     torch::optional<torch::Tensor> probs_or_mask = torch::nullopt,
+    torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id =
+        torch::nullopt,
+    torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id =
+        torch::nullopt,
     torch::optional<torch::Tensor> random_seed = torch::nullopt,
     float seed2_contribution = .0f);
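To make the new seeds/seed_offsets contract concrete, here is a minimal sketch (not part of the commit; the type names, node-type ids, and node ids are made up) of how a per-type seed dictionary is flattened into the single seeds tensor plus the seed_offsets list documented above.

import torch

# Assumed node type ids: {"user": 0, "item": 1}; the ids below are illustrative.
seeds_per_type = {
    "user": torch.tensor([0, 3]),
    "item": torch.tensor([1, 2, 4]),
}

# Concatenate per-type seeds in node-type-id order and record the boundaries.
seeds = torch.cat([seeds_per_type["user"], seeds_per_type["item"]])
seed_offsets = [0, 2, 5]  # seeds[seed_offsets[i]:seed_offsets[i + 1]] has node type i.

assert seeds[seed_offsets[0]:seed_offsets[1]].equal(seeds_per_type["user"])
assert seeds[seed_offsets[1]:seed_offsets[2]].equal(seeds_per_type["item"])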
graphbolt/include/graphbolt/fused_csc_sampling_graph.h  (+7 −4)

@@ -298,8 +298,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
  * @brief Sample neighboring edges of the given nodes and return the induced
  * subgraph.
  *
- * @param nodes The nodes from which to sample neighbors. If not provided,
+ * @param seeds The nodes from which to sample neighbors. If not provided,
  * assumed to be equal to torch.arange(NumNodes()).
+ * @param seed_offsets The offsets of the given seeds,
+ * seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type id i.
  * @param fanouts The number of edges to be sampled for each node with or
  * without considering edge types.
  * - When the length is 1, it indicates that the fanout applies to all

@@ -333,9 +335,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
  * the sampled graph's information.
  */
 c10::intrusive_ptr<FusedSampledSubgraph> SampleNeighbors(
-    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
-    bool replace, bool layer, bool return_eids,
-    torch::optional<std::string> probs_name,
+    torch::optional<torch::Tensor> seeds,
+    torch::optional<std::vector<int64_t>> seed_offsets,
+    const std::vector<int64_t>& fanouts, bool replace, bool layer,
+    bool return_eids, torch::optional<std::string> probs_name,
     torch::optional<torch::Tensor> random_seed,
     double seed2_contribution) const;
graphbolt/include/graphbolt/fused_sampled_subgraph.h  (+30 −9)

@@ -51,33 +51,39 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
  * graph.
  * @param original_edge_ids Reverse edge ids in the original graph.
  * @param type_per_edge Type id of each edge.
+ * @param etype_offsets Edge offsets for the sampled edges for the sampled
+ * edges that are sorted w.r.t. edge types.
  */
 FusedSampledSubgraph(
     torch::Tensor indptr, torch::Tensor indices,
-    torch::Tensor original_column_node_ids,
+    torch::optional<torch::Tensor> original_column_node_ids,
     torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
     torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
-    torch::optional<torch::Tensor> type_per_edge = torch::nullopt)
+    torch::optional<torch::Tensor> type_per_edge = torch::nullopt,
+    torch::optional<torch::Tensor> etype_offsets = torch::nullopt)
     : indptr(indptr),
       indices(indices),
       original_column_node_ids(original_column_node_ids),
       original_row_node_ids(original_row_node_ids),
       original_edge_ids(original_edge_ids),
-      type_per_edge(type_per_edge) {}
+      type_per_edge(type_per_edge),
+      etype_offsets(etype_offsets) {}

 FusedSampledSubgraph() = default;

 /**
  * @brief CSC format index pointer array, where the implicit node ids are
  * already compacted. And the original ids are stored in the
- * `original_column_node_ids` field.
+ * `original_column_node_ids` field. Its length is equal to:
+ * 1 + \sum_{etype} #seeds with dst_node_type(etype)
  */
 torch::Tensor indptr;

 /**
  * @brief CSC format index array, where the node ids can be compacted ids or
  * original ids. If compacted, the original ids are stored in the
- * `original_row_node_ids` field.
+ * `original_row_node_ids` field. The indices are sorted w.r.t. their edge
+ * types for the heterogenous case.
  */
 torch::Tensor indices;

@@ -86,10 +92,11 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
  * can be treated as a coordinated row and column pair, and this is the the
  * mapped ids of the column.
  *
- * @note This is required and the mapping relations can be inconsistent with
- * column's.
+ * @note This is optional and the mapping relations can be inconsistent with
+ * column's. It can be missing when the sampling algorithm is called via a
+ * sliced sampled subgraph with missing seeds argument.
  */
-torch::Tensor original_column_node_ids;
+torch::optional<torch::Tensor> original_column_node_ids;

 /**
  * @brief Row's reverse node ids in the original graph. A graph structure

@@ -104,7 +111,8 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
 /**
  * @brief Reverse edge ids in the original graph, the edge with id
  * `original_edge_ids[i]` in the original graph is mapped to `i` in this
- * subgraph. This is useful when edge features are needed.
+ * subgraph. This is useful when edge features are needed. The edges are
+ * sorted w.r.t. their edge types for the heterogenous case.
  */
 torch::optional<torch::Tensor> original_edge_ids;

@@ -112,8 +120,21 @@ struct FusedSampledSubgraph : torch::CustomClassHolder {
  * @brief Type id of each edge, where type id is the corresponding index of
  * edge types. The length of it is equal to the number of edges in the
  * subgraph.
+ *
+ * @note This output is not created by the CUDA implementation as the edges
+ * are sorted w.r.t edge types, one has to use etype_offsets to infer the edge
+ * type information. This field is going to be deprecated. It can be generated
+ * when needed by computing gb.expand_indptr(etype_offsets).
  */
 torch::optional<torch::Tensor> type_per_edge;

+/**
+ * @brief Offsets of each etype,
+ * type_per_edge[etype_offsets[i]: etype_offsets[i + 1]] == i
+ * It has length equal to (1 + #etype), and the edges are guaranteed to be
+ * sorted w.r.t. their edge types.
+ */
+torch::optional<torch::Tensor> etype_offsets;
 };
 }  // namespace sampling
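As an illustration of the new etype_offsets field: the doc comment above notes that the deprecated type_per_edge can be regenerated via gb.expand_indptr(etype_offsets). The sketch below assumes dgl.graphbolt.expand_indptr performs the usual indptr-to-COO expansion referenced there; the offset values are made up.

import torch
import dgl.graphbolt as gb

# Sampled edges sorted by edge type: 4 edges of type 0, then 6 edges of type 1.
etype_offsets = torch.tensor([0, 4, 10])

# type_per_edge[etype_offsets[i]:etype_offsets[i + 1]] == i, so the per-edge
# type ids can be recovered on demand instead of being materialized.
type_per_edge = gb.expand_indptr(etype_offsets, dtype=torch.int64)
print(type_per_edge)  # tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])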
graphbolt/src/cuda/neighbor_sampler.cu  (+155 −34)

@@ -25,6 +25,7 @@
 #include <type_traits>

 #include "../random.h"
+#include "../utils.h"
 #include "./common.h"
 #include "./utils.h"

@@ -183,19 +184,26 @@ struct SegmentEndFunc {
 c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
     torch::Tensor indptr, torch::Tensor indices,
-    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
-    bool replace, bool layer, bool return_eids,
+    torch::optional<torch::Tensor> seeds,
+    torch::optional<std::vector<int64_t>> seed_offsets,
+    const std::vector<int64_t>& fanouts, bool replace, bool layer,
+    bool return_eids,
     torch::optional<torch::Tensor> type_per_edge,
     torch::optional<torch::Tensor> probs_or_mask,
+    torch::optional<torch::Dict<std::string, int64_t>> node_type_to_id,
+    torch::optional<torch::Dict<std::string, int64_t>> edge_type_to_id,
     torch::optional<torch::Tensor> random_seed_tensor,
     float seed2_contribution) {
+  // When seed_offsets.has_value() in the hetero case, we compute the output of
+  // sample_neighbors _convert_to_sampled_subgraph in a fused manner so that
+  // _convert_to_sampled_subgraph only has to perform slices over the returned
+  // indptr and indices tensors to form CSC outputs for each edge type.
   TORCH_CHECK(!replace, "Sampling with replacement is not supported yet!");
-  // Assume that indptr, indices, nodes, type_per_edge and probs_or_mask
+  // Assume that indptr, indices, seeds, type_per_edge and probs_or_mask
   // are all resident on the GPU. If not, it is better to first extract them
   // before calling this function.
   auto allocator = cuda::GetAllocator();
   auto num_rows =
-      nodes.has_value() ? nodes.value().size(0) : indptr.size(0) - 1;
+      seeds.has_value() ? seeds.value().size(0) : indptr.size(0) - 1;
   auto fanouts_pinned = torch::empty(
       fanouts.size(),
       c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));

@@ -210,7 +218,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
       fanouts_device.get(), fanouts_pinned_ptr,
       sizeof(int64_t) * fanouts.size(), cudaMemcpyHostToDevice,
       cuda::GetCurrentStream()));
-  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, nodes);
+  auto in_degree_and_sliced_indptr = SliceCSCIndptr(indptr, seeds);
   auto in_degree = std::get<0>(in_degree_and_sliced_indptr);
   auto sliced_indptr = std::get<1>(in_degree_and_sliced_indptr);
   auto max_in_degree = torch::empty(

@@ -227,16 +235,16 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   max_in_degree_event.record();
   torch::optional<int64_t> num_edges;
   torch::Tensor sub_indptr;
-  if (!nodes.has_value()) {
+  if (!seeds.has_value()) {
     num_edges = indices.size(0);
     sub_indptr = indptr;
   }
   torch::optional<torch::Tensor> sliced_probs_or_mask;
   if (probs_or_mask.has_value()) {
-    if (nodes.has_value()) {
+    if (seeds.has_value()) {
       torch::Tensor sliced_probs_or_mask_tensor;
       std::tie(sub_indptr, sliced_probs_or_mask_tensor) = IndexSelectCSCImpl(
-          in_degree, sliced_indptr, probs_or_mask.value(), nodes.value(),
+          in_degree, sliced_indptr, probs_or_mask.value(), seeds.value(),
           indptr.size(0) - 2, num_edges);
       sliced_probs_or_mask = sliced_probs_or_mask_tensor;
       num_edges = sliced_probs_or_mask_tensor.size(0);

@@ -246,9 +254,9 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   }
   if (fanouts.size() > 1) {
     torch::Tensor sliced_type_per_edge;
-    if (nodes.has_value()) {
+    if (seeds.has_value()) {
       std::tie(sub_indptr, sliced_type_per_edge) = IndexSelectCSCImpl(
-          in_degree, sliced_indptr, type_per_edge.value(), nodes.value(),
+          in_degree, sliced_indptr, type_per_edge.value(), seeds.value(),
           indptr.size(0) - 2, num_edges);
     } else {
       sliced_type_per_edge = type_per_edge.value();

@@ -259,7 +267,7 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
     num_edges = sliced_type_per_edge.size(0);
   }
   // If sub_indptr was not computed in the two code blocks above:
-  if (nodes.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
+  if (seeds.has_value() && !probs_or_mask.has_value() && fanouts.size() <= 1) {
     sub_indptr = ExclusiveCumSum(in_degree);
   }
   auto coo_rows = ExpandIndptrImpl(

@@ -276,7 +284,6 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
   auto output_indptr = torch::empty_like(sub_indptr);
   torch::Tensor picked_eids;
   torch::Tensor output_indices;
-  torch::optional<torch::Tensor> output_type_per_edge;
   AT_DISPATCH_INDEX_TYPES(
       indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {

@@ -507,39 +514,153 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> SampleNeighbors(
                 indices.data_ptr<indices_t>(),
                 output_indices.data_ptr<indices_t>());
           }));
-        if (type_per_edge) {
-          // output_type_per_edge = type_per_edge.gather(0, picked_eids);
-          // The commented out torch equivalent above does not work when
-          // type_per_edge is on pinned memory. That is why, we have to
-          // reimplement it, similar to the indices gather operation above.
-          auto types = type_per_edge.value();
-          output_type_per_edge = torch::empty(
-              picked_eids.size(0),
-              picked_eids.options().dtype(types.scalar_type()));
-          AT_DISPATCH_INTEGRAL_TYPES(
-              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
-                THRUST_CALL(
-                    gather, picked_eids.data_ptr<indptr_t>(),
-                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
-                    types.data_ptr<scalar_t>(),
-                    output_type_per_edge.value().data_ptr<scalar_t>());
-              }));
-        }
       }));
+  auto index_type_per_edge_for_sampled_edges = [&] {
+    // The code behaves same as:
+    // output_type_per_edge = type_per_edge.gather(0, picked_eids);
+    // The reimplementation is required due to the torch equivalent does
+    // not work when type_per_edge is on pinned memory
+    auto types = type_per_edge.value();
+    auto output = torch::empty(
+        picked_eids.size(0),
+        picked_eids.options().dtype(types.scalar_type()));
+    AT_DISPATCH_INDEX_TYPES(
+        indptr.scalar_type(), "SampleNeighborsIndptr", ([&] {
+          using indptr_t = index_t;
+          AT_DISPATCH_INTEGRAL_TYPES(
+              types.scalar_type(), "SampleNeighborsOutputTypePerEdge", ([&] {
+                THRUST_CALL(
+                    gather, picked_eids.data_ptr<indptr_t>(),
+                    picked_eids.data_ptr<indptr_t>() + picked_eids.size(0),
+                    types.data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
+              }));
+        }));
+    return output;
+  };
+  torch::optional<torch::Tensor> output_type_per_edge;
+  torch::optional<torch::Tensor> edge_offsets;
+  if (type_per_edge && seed_offsets) {
+    const int64_t num_etypes =
+        edge_type_to_id.has_value() ? edge_type_to_id->size() : 1;
+    // If we performed homogenous sampling on hetero graph, we have to look at
+    // type_per_edge of sampled edges and determine the offsets of different
+    // sampled etypes and convert to fused hetero indptr representation.
+    if (fanouts.size() == 1) {
+      output_type_per_edge = index_type_per_edge_for_sampled_edges();
+      torch::Tensor output_in_degree, sliced_output_indptr;
+      sliced_output_indptr =
+          output_indptr.slice(0, 0, output_indptr.size(0) - 1);
+      std::tie(output_indptr, output_in_degree, sliced_output_indptr) =
+          SliceCSCIndptrHetero(
+              output_indptr, output_type_per_edge.value(),
+              sliced_output_indptr, num_etypes);
+      // We use num_rows to hold num_seeds * num_etypes. So, it needs to be
+      // updated when sampling with a single fanout value when the graph is
+      // heterogenous.
+      num_rows = sliced_output_indptr.size(0);
+    }
+    // Here, we check what are the dst node types for the given seeds so that
+    // we can compute the output indptr space later.
+    std::vector<int64_t> etype_id_to_dst_ntype_id(num_etypes);
+    for (auto& etype_and_id : edge_type_to_id.value()) {
+      auto etype = etype_and_id.key();
+      auto id = etype_and_id.value();
+      auto dst_type = utils::parse_dst_ntype_from_etype(etype);
+      etype_id_to_dst_ntype_id[id] = node_type_to_id->at(dst_type);
+    }
+    // For each edge type, we compute the start and end offsets to index into
+    // indptr to form the final output_indptr.
+    auto indptr_offsets = torch::empty(
+        num_etypes * 2,
+        c10::TensorOptions().dtype(torch::kLong).pinned_memory(true));
+    auto indptr_offsets_ptr = indptr_offsets.data_ptr<int64_t>();
+    // We compute the indptr offsets here, right now, output_indptr is of size
+    // # seeds * num_etypes + 1. We can simply take slices to get correct output
+    // indptr. The final output_indptr is same as current indptr except that
+    // some intermediate values are removed to change the node ids space from
+    // all of the seed vertices to the node id space of the dst node type of
+    // each edge type.
+    for (int i = 0; i < num_etypes; i++) {
+      indptr_offsets_ptr[2 * i] =
+          num_rows / num_etypes * i +
+          seed_offsets->at(etype_id_to_dst_ntype_id[i]);
+      indptr_offsets_ptr[2 * i + 1] =
+          num_rows / num_etypes * i +
+          seed_offsets->at(etype_id_to_dst_ntype_id[i] + 1);
+    }
+    auto permutation = torch::arange(
+        0, num_rows * num_etypes, num_etypes, output_indptr.options());
+    permutation =
+        permutation.remainder(num_rows) + permutation.div(num_rows, "floor");
+    // This permutation, when applied sorts the sampled edges with respect to
+    // edge types.
+    auto [output_in_degree, sliced_output_indptr] =
+        SliceCSCIndptr(output_indptr, permutation);
+    std::tie(output_indptr, picked_eids) = IndexSelectCSCImpl(
+        output_in_degree, sliced_output_indptr, picked_eids, permutation,
+        num_rows - 1, picked_eids.size(0));
+    edge_offsets = torch::empty(
+        num_etypes * 2, c10::TensorOptions()
+                            .dtype(output_indptr.scalar_type())
+                            .pinned_memory(true));
+    at::cuda::CUDAEvent edge_offsets_event;
+    AT_DISPATCH_INDEX_TYPES(
+        indptr.scalar_type(), "SampleNeighborsEdgeOffsets", ([&] {
+          THRUST_CALL(
+              gather, indptr_offsets_ptr,
+              indptr_offsets_ptr + indptr_offsets.size(0),
+              output_indptr.data_ptr<index_t>(),
+              edge_offsets->data_ptr<index_t>());
+        }));
+    edge_offsets_event.record();
+    // The output_indices is permuted here.
+    std::tie(output_indptr, output_indices) = IndexSelectCSCImpl(
+        output_in_degree, sliced_output_indptr, output_indices, permutation,
+        num_rows - 1, output_indices.size(0));
+    std::vector<torch::Tensor> indptr_list;
+    for (int i = 0; i < num_etypes; i++) {
+      indptr_list.push_back(output_indptr.slice(
+          0, indptr_offsets_ptr[2 * i],
+          indptr_offsets_ptr[2 * i + 1] + (i == num_etypes - 1)));
+    }
+    // We form the final output indptr by concatenating pieces for different
+    // edge types.
+    output_indptr = torch::cat(indptr_list);
+    edge_offsets_event.synchronize();
+    // We read the edge_offsets here, they are in pairs but we don't need it to
+    // be in pairs. So we remove the duplicate information from it and turn it
+    // into a real offsets array.
+    AT_DISPATCH_INDEX_TYPES(
+        indptr.scalar_type(), "SampleNeighborsEdgeOffsetsCheck", ([&] {
+          auto edge_offsets_ptr = edge_offsets->data_ptr<index_t>();
+          TORCH_CHECK(edge_offsets_ptr[0] == 0, "edge_offsets is incorrect.");
+          for (int i = 1; i < num_etypes; i++) {
+            TORCH_CHECK(
+                edge_offsets_ptr[2 * i - 1] == edge_offsets_ptr[2 * i],
                "edge_offsets is incorrect.");
+          }
+          TORCH_CHECK(
+              edge_offsets_ptr[2 * num_etypes - 1] == picked_eids.size(0),
+              "edge_offsets is incorrect.");
+          for (int i = 0; i < num_etypes; i++) {
+            edge_offsets_ptr[i + 1] = edge_offsets_ptr[2 * i + 1];
+          }
+        }));
+    edge_offsets = edge_offsets->slice(0, 0, num_etypes + 1);
+  } else {
+    // Convert output_indptr back to homo by discarding intermediate offsets.
+    output_indptr =
+        output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
+    if (type_per_edge)
+      output_type_per_edge = index_type_per_edge_for_sampled_edges();
+  }
-  // Convert output_indptr back to homo by discarding intermediate offsets.
-  output_indptr =
-      output_indptr.slice(0, 0, output_indptr.size(0), fanouts.size());
   torch::optional<torch::Tensor> subgraph_reverse_edge_ids = torch::nullopt;
   if (return_eids) subgraph_reverse_edge_ids = std::move(picked_eids);
-  if (!nodes.has_value()) {
-    nodes = torch::arange(indptr.size(0) - 1, indices.options());
-  }
   return c10::make_intrusive<sampling::FusedSampledSubgraph>(
-      output_indptr, output_indices, nodes.value(), torch::nullopt,
-      subgraph_reverse_edge_ids, output_type_per_edge);
+      output_indptr, output_indices, seeds, torch::nullopt,
+      subgraph_reverse_edge_ids, output_type_per_edge, edge_offsets);
 }
 }  // namespace ops
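The arange/remainder permutation in the hunk above is what regroups the per-(seed, edge type) CSC rows so that rows of the same edge type become contiguous. A small torch sketch of that index pattern, assuming (an assumption, not stated verbatim in the diff) that the pre-permutation rows are laid out seed-major with the per-edge-type buckets adjacent:

import torch

num_seeds, num_etypes = 3, 2
num_rows = num_seeds * num_etypes  # one CSC row per (seed, edge type) pair

# Same index arithmetic as in the kernel above.
permutation = torch.arange(0, num_rows * num_etypes, num_etypes)
permutation = permutation.remainder(num_rows) + permutation.div(
    num_rows, rounding_mode="floor"
)
print(permutation)  # tensor([0, 2, 4, 1, 3, 5])

# Applying it to seed-major (seed, etype) rows groups the rows by edge type.
rows = [(s, e) for s in range(num_seeds) for e in range(num_etypes)]
print([rows[int(i)] for i in permutation])
# [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1)]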
graphbolt/src/fused_csc_sampling_graph.cc  (+15 −13)

@@ -617,23 +617,24 @@ FusedCSCSamplingGraph::SampleNeighborsImpl(
 }

 c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
-    torch::optional<torch::Tensor> nodes, const std::vector<int64_t>& fanouts,
-    bool replace, bool layer, bool return_eids,
-    torch::optional<std::string> probs_name,
+    torch::optional<torch::Tensor> seeds,
+    torch::optional<std::vector<int64_t>> seed_offsets,
+    const std::vector<int64_t>& fanouts, bool replace, bool layer,
+    bool return_eids, torch::optional<std::string> probs_name,
     torch::optional<torch::Tensor> random_seed,
     double seed2_contribution) const {
   auto probs_or_mask = this->EdgeAttribute(probs_name);
-  // If nodes does not have a value, then we expect all arguments to be resident
-  // on the GPU. If nodes has a value, then we expect them to be accessible from
+  // If seeds does not have a value, then we expect all arguments to be resident
+  // on the GPU. If seeds has a value, then we expect them to be accessible from
   // GPU. This is required for the dispatch to work when CUDA is not available.
-  if (((!nodes.has_value() && utils::is_on_gpu(indptr_) &&
+  if (((!seeds.has_value() && utils::is_on_gpu(indptr_) &&
         utils::is_on_gpu(indices_) &&
         (!probs_or_mask.has_value() ||
          utils::is_on_gpu(probs_or_mask.value())) &&
         (!type_per_edge_.has_value() ||
          utils::is_on_gpu(type_per_edge_.value()))) ||
-       (nodes.has_value() && utils::is_on_gpu(nodes.value()) &&
+       (seeds.has_value() && utils::is_on_gpu(seeds.value()) &&
        utils::is_accessible_from_gpu(indptr_) &&
        utils::is_accessible_from_gpu(indices_) &&
        (!probs_or_mask.has_value() ||

@@ -644,11 +645,12 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
     GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
         c10::DeviceType::CUDA, "SampleNeighbors", {
           return ops::SampleNeighbors(
-              indptr_, indices_, nodes, fanouts, replace, layer, return_eids,
-              type_per_edge_, probs_or_mask, random_seed, seed2_contribution);
+              indptr_, indices_, seeds, seed_offsets, fanouts, replace, layer,
+              return_eids, type_per_edge_, probs_or_mask, node_type_to_id_,
+              edge_type_to_id_, random_seed, seed2_contribution);
         });
   }
-  TORCH_CHECK(nodes.has_value(), "Nodes can not be None on the CPU.");
+  TORCH_CHECK(seeds.has_value(), "Nodes can not be None on the CPU.");
   if (probs_or_mask.has_value()) {
     // Note probs will be passed as input for 'torch.multinomial' in deeper

@@ -667,7 +669,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
         {random_seed.value(), static_cast<float>(seed2_contribution)},
         NumNodes()};
     return SampleNeighborsImpl(
-        nodes.value(), return_eids,
+        seeds.value(), return_eids,
         GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
         GetPickFn(
             fanouts, replace, indptr_.options(), type_per_edge_,

@@ -686,7 +688,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
     }
   }();
   return SampleNeighborsImpl(
-      nodes.value(), return_eids,
+      seeds.value(), return_eids,
       GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
       GetPickFn(
           fanouts, replace, indptr_.options(), type_per_edge_,

@@ -695,7 +697,7 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
   } else {
     SamplerArgs<SamplerType::NEIGHBOR> args;
     return SampleNeighborsImpl(
-        nodes.value(), return_eids,
+        seeds.value(), return_eids,
         GetNumPickFn(fanouts, replace, type_per_edge_, probs_or_mask),
         GetPickFn(
             fanouts, replace, indptr_.options(), type_per_edge_, probs_or_mask,
graphbolt/src/python_binding.cc  (+2 −1)

@@ -36,7 +36,8 @@ TORCH_LIBRARY(graphbolt, m) {
           &FusedSampledSubgraph::original_column_node_ids)
       .def_readwrite(
           "original_edge_ids", &FusedSampledSubgraph::original_edge_ids)
-      .def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge);
+      .def_readwrite("type_per_edge", &FusedSampledSubgraph::type_per_edge)
+      .def_readwrite("etype_offsets", &FusedSampledSubgraph::etype_offsets);
   m.class_<storage::OnDiskNpyArray>("OnDiskNpyArray")
       .def("index_select", &storage::OnDiskNpyArray::IndexSelect);
   m.class_<FusedCSCSamplingGraph>("FusedCSCSamplingGraph")
graphbolt/src/utils.h  (+11 −0)

@@ -26,6 +26,17 @@ inline bool is_accessible_from_gpu(torch::Tensor tensor) {
   return is_on_gpu(tensor) || tensor.is_pinned();
 }

+/**
+ * @brief Parses the destination node type from a given edge type triple
+ * seperated with ":".
+ */
+inline std::string parse_dst_ntype_from_etype(std::string etype) {
+  auto first_seperator_it = std::find(etype.begin(), etype.end(), ':');
+  auto second_seperator_pos =
+      std::find(first_seperator_it + 1, etype.end(), ':') - etype.begin();
+  return etype.substr(second_seperator_pos + 1);
+}
+
 /**
  * @brief Retrieves the value of the tensor at the given index.
  *
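For orientation, a rough Python equivalent (illustrative only, not part of the commit) of what the helper above extracts from a canonical "src:relation:dst" edge type string:

def parse_dst_ntype_from_etype(etype: str) -> str:
    # Everything after the second ":" separator is the destination node type.
    return etype.split(":", 2)[2]

assert parse_dst_ntype_from_etype("user:follows:item") == "item"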
python/dgl/distributed/graph_services.py  (+1 −1)

@@ -146,7 +146,7 @@ def _sample_neighbors_graphbolt(
     return_eids = g.edge_attributes is not None and EID in g.edge_attributes
     subgraph = g._sample_neighbors(
-        nodes, fanout, replace=replace, return_eids=return_eids
+        nodes, None, fanout, replace=replace, return_eids=return_eids
     )
     # 3. Map local node IDs to global node IDs.
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py  (+112 −52)

@@ -444,7 +444,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
         )}
         """
         if isinstance(nodes, dict):
-            nodes = self._convert_to_homogeneous_nodes(nodes)
+            nodes, _ = self._convert_to_homogeneous_nodes(nodes)

         # Ensure nodes is 1-D tensor.
         assert nodes.dim() == 1, "Nodes should be 1-D tensor."

@@ -453,22 +453,28 @@ class FusedCSCSamplingGraph(SamplingGraph):
     def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
         homogeneous_nodes = []
+        homogeneous_node_offsets = [0]
         homogeneous_timestamps = []
         offset = self._node_type_offset_list
-        for ntype, ids in nodes.items():
-            ntype_id = self.node_type_to_id[ntype]
-            homogeneous_nodes.append(ids + offset[ntype_id])
-            if timestamps is not None:
-                homogeneous_timestamps.append(timestamps[ntype])
+        for ntype, ntype_id in self.node_type_to_id.items():
+            ids = nodes.get(ntype, [])
+            if len(ids) > 0:
+                homogeneous_nodes.append(ids + offset[ntype_id])
+                if timestamps is not None:
+                    homogeneous_timestamps.append(timestamps[ntype])
+            homogeneous_node_offsets.append(
+                homogeneous_node_offsets[-1] + len(ids)
+            )
         if timestamps is not None:
             return torch.cat(homogeneous_nodes), torch.cat(
                 homogeneous_timestamps
             )
-        return torch.cat(homogeneous_nodes)
+        return torch.cat(homogeneous_nodes), homogeneous_node_offsets

     def _convert_to_sampled_subgraph(
         self,
         C_sampled_subgraph: torch.ScriptObject,
+        seed_offsets: Optional[list] = None,
     ) -> SampledSubgraphImpl:
         """An internal function used to convert a fused homogeneous sampled
         subgraph to general struct 'SampledSubgraphImpl'."""

@@ -477,6 +483,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
         type_per_edge = C_sampled_subgraph.type_per_edge
         column = C_sampled_subgraph.original_column_node_ids
         original_edge_ids = C_sampled_subgraph.original_edge_ids
+        etype_offsets = C_sampled_subgraph.etype_offsets
+        if etype_offsets is not None:
+            etype_offsets = etype_offsets.tolist()
         has_original_eids = (
             self.edge_attributes is not None

@@ -486,45 +495,78 @@ class FusedCSCSamplingGraph(SamplingGraph):
             original_edge_ids = torch.ops.graphbolt.index_select(
                 self.edge_attributes[ORIGINAL_EDGE_ID], original_edge_ids
             )
-        if type_per_edge is None:
+        if type_per_edge is None and etype_offsets is None:
             # The sampled graph is already a homogeneous graph.
             sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
         else:
-            # UVA sampling requires us to move node_type_offset to GPU.
-            self.node_type_offset = self.node_type_offset.to(column.device)
-            # 1. Find node types for each nodes in column.
-            node_types = (
-                torch.searchsorted(self.node_type_offset, column, right=True)
-                - 1
-            )
             original_hetero_edge_ids = {}
             sub_indices = {}
             sub_indptr = {}
             offset = self._node_type_offset_list
-            # 2. For loop each node type.
-            for ntype, ntype_id in self.node_type_to_id.items():
-                # Get all nodes of a specific node type in column.
-                nids = torch.nonzero(node_types == ntype_id).view(-1)
-                nids_original_indptr = indptr[nids + 1]
-                for etype, etype_id in self.edge_type_to_id.items():
-                    src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
-                    if dst_ntype != ntype:
-                        continue
-                    # Get all edge ids of a specific edge type.
-                    eids = torch.nonzero(type_per_edge == etype_id).view(-1)
-                    src_ntype_id = self.node_type_to_id[src_ntype]
-                    sub_indices[etype] = indices[eids] - offset[src_ntype_id]
-                    cum_edges = torch.searchsorted(
-                        eids, nids_original_indptr, right=False
-                    )
-                    sub_indptr[etype] = torch.cat(
-                        (torch.tensor([0], device=indptr.device), cum_edges)
-                    )
-                    if has_original_eids:
-                        original_hetero_edge_ids[etype] = original_edge_ids[
-                            eids
-                        ]
+            if etype_offsets is None:
+                # UVA sampling requires us to move node_type_offset to GPU.
+                self.node_type_offset = self.node_type_offset.to(column.device)
+                # 1. Find node types for each nodes in column.
+                node_types = (
+                    torch.searchsorted(
+                        self.node_type_offset, column, right=True
+                    )
+                    - 1
+                )
+                for ntype, ntype_id in self.node_type_to_id.items():
+                    # Get all nodes of a specific node type in column.
+                    nids = torch.nonzero(node_types == ntype_id).view(-1)
+                    nids_original_indptr = indptr[nids + 1]
+                    for etype, etype_id in self.edge_type_to_id.items():
+                        src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
+                        if dst_ntype != ntype:
+                            continue
+                        # Get all edge ids of a specific edge type.
+                        eids = torch.nonzero(type_per_edge == etype_id).view(-1)
+                        src_ntype_id = self.node_type_to_id[src_ntype]
+                        sub_indices[etype] = (
+                            indices[eids] - offset[src_ntype_id]
+                        )
+                        cum_edges = torch.searchsorted(
+                            eids, nids_original_indptr, right=False
+                        )
+                        sub_indptr[etype] = torch.cat(
+                            (torch.tensor([0], device=indptr.device), cum_edges)
+                        )
+                        if has_original_eids:
+                            original_hetero_edge_ids[etype] = original_edge_ids[
+                                eids
+                            ]
+            else:
+                edge_offsets = [0]
+                for etype, etype_id in self.edge_type_to_id.items():
+                    src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
+                    ntype_id = self.node_type_to_id[dst_ntype]
+                    edge_offsets.append(
+                        edge_offsets[-1]
+                        + seed_offsets[ntype_id + 1]
+                        - seed_offsets[ntype_id]
+                    )
+                for etype, etype_id in self.edge_type_to_id.items():
+                    src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
+                    ntype_id = self.node_type_to_id[dst_ntype]
+                    sub_indptr_ = indptr[
+                        edge_offsets[etype_id] : edge_offsets[etype_id + 1] + 1
+                    ]
+                    sub_indptr[etype] = sub_indptr_ - sub_indptr_[0]
+                    sub_indices[etype] = indices[
+                        etype_offsets[etype_id] : etype_offsets[etype_id + 1]
+                    ]
+                    if has_original_eids:
+                        original_hetero_edge_ids[etype] = original_edge_ids[
+                            etype_offsets[etype_id] : etype_offsets[etype_id + 1]
+                        ]
+                    src_ntype_id = self.node_type_to_id[src_ntype]
+                    sub_indices[etype] -= offset[src_ntype_id]
             if has_original_eids:
                 original_edge_ids = original_hetero_edge_ids
             sampled_csc = {

@@ -541,7 +583,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
     def sample_neighbors(
         self,
-        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
+        seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],
         fanouts: torch.Tensor,
         replace: bool = False,
         probs_name: Optional[str] = None,

@@ -551,7 +593,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
         Parameters
         ----------
-        nodes: torch.Tensor or Dict[str, torch.Tensor]
+        seeds: torch.Tensor or Dict[str, torch.Tensor]
             IDs of the given seed nodes.
             - If `nodes` is a tensor: It means the graph is homogeneous
               graph, and ids inside are homogeneous ids.

@@ -615,21 +657,27 @@ class FusedCSCSamplingGraph(SamplingGraph):
             indices=tensor([2]),
         )}
         """
-        if isinstance(nodes, dict):
-            nodes = self._convert_to_homogeneous_nodes(nodes)
-
         return_eids = (
             self.edge_attributes is not None
             and ORIGINAL_EDGE_ID in self.edge_attributes
         )
+        seed_offsets = None
+        if isinstance(seeds, dict):
+            seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
+        elif seeds is None and hasattr(self, "_seed_offset_list"):
+            seed_offsets = self._seed_offset_list  # pylint: disable=no-member
         C_sampled_subgraph = self._sample_neighbors(
-            nodes,
+            seeds,
+            seed_offsets,
             fanouts,
             replace=replace,
             probs_name=probs_name,
             return_eids=return_eids,
         )
-        return self._convert_to_sampled_subgraph(C_sampled_subgraph)
+        return self._convert_to_sampled_subgraph(C_sampled_subgraph, seed_offsets)

     def _check_sampler_arguments(self, nodes, fanouts, probs_name):
         if nodes is not None:

@@ -676,7 +724,8 @@ class FusedCSCSamplingGraph(SamplingGraph):
     def _sample_neighbors(
         self,
-        nodes: torch.Tensor,
+        seeds: torch.Tensor,
+        seed_offsets: Optional[list],
         fanouts: torch.Tensor,
         replace: bool = False,
         probs_name: Optional[str] = None,

@@ -687,8 +736,11 @@ class FusedCSCSamplingGraph(SamplingGraph):
         Parameters
         ----------
-        nodes: torch.Tensor
+        seeds: torch.Tensor
             IDs of the given seed nodes.
+        seeds_offsets: list, optional
+            The offsets of the given seeds,
+            seeds[seed_offsets[i]: seed_offsets[i + 1]] has node type i.
         fanouts: torch.Tensor
             The number of edges to be sampled for each node with or without
             considering edge types.

@@ -726,9 +778,10 @@ class FusedCSCSamplingGraph(SamplingGraph):
             The sampled C subgraph.
         """
         # Ensure nodes is 1-D tensor.
-        self._check_sampler_arguments(nodes, fanouts, probs_name)
+        self._check_sampler_arguments(seeds, fanouts, probs_name)
         return self._c_csc_graph.sample_neighbors(
-            nodes,
+            seeds,
+            seed_offsets,
             fanouts.tolist(),
             replace,
             False,  # is_labor

@@ -740,7 +793,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
     def sample_layer_neighbors(
         self,
-        nodes: Union[torch.Tensor, Dict[str, torch.Tensor]],
+        seeds: Union[torch.Tensor, Dict[str, torch.Tensor]],
         fanouts: torch.Tensor,
         replace: bool = False,
         probs_name: Optional[str] = None,

@@ -754,7 +807,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
         Parameters
         ----------
-        nodes: torch.Tensor or Dict[str, torch.Tensor]
+        seeds: torch.Tensor or Dict[str, torch.Tensor]
             IDs of the given seed nodes.
             - If `nodes` is a tensor: It means the graph is homogeneous
               graph, and ids inside are homogeneous ids.

@@ -844,10 +897,6 @@ class FusedCSCSamplingGraph(SamplingGraph):
             indices=tensor([2]),
         )}
         """
-        if isinstance(nodes, dict):
-            nodes = self._convert_to_homogeneous_nodes(nodes)
-        self._check_sampler_arguments(nodes, fanouts, probs_name)
         if random_seed is not None:
             assert (
                 1 <= len(random_seed) <= 2

@@ -856,12 +905,21 @@ class FusedCSCSamplingGraph(SamplingGraph):
             assert (
                 0 <= seed2_contribution <= 1
             ), "seed2_contribution should be in [0, 1]."
         has_original_eids = (
             self.edge_attributes is not None
             and ORIGINAL_EDGE_ID in self.edge_attributes
         )
+        seed_offsets = None
+        if isinstance(seeds, dict):
+            seeds, seed_offsets = self._convert_to_homogeneous_nodes(seeds)
+        elif seeds is None and hasattr(self, "_seed_offset_list"):
+            seed_offsets = self._seed_offset_list  # pylint: disable=no-member
+        self._check_sampler_arguments(seeds, fanouts, probs_name)
         C_sampled_subgraph = self._c_csc_graph.sample_neighbors(
-            nodes,
+            seeds,
+            seed_offsets,
             fanouts.tolist(),
             replace,
             True,

@@ -870,7 +928,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
             random_seed,
             seed2_contribution,
         )
-        return self._convert_to_sampled_subgraph(C_sampled_subgraph)
+        return self._convert_to_sampled_subgraph(
+            C_sampled_subgraph, seed_offsets
+        )

     def temporal_sample_neighbors(
         self,
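A self-contained sketch (type names, offsets, and ids are made up) of what the updated _convert_to_homogeneous_nodes now returns: the concatenated homogeneous seed tensor plus the per-node-type offsets that _convert_to_sampled_subgraph later uses to slice the fused output per edge type.

import torch

# Illustrative metadata, not taken from the diff.
node_type_to_id = {"n1": 0, "n2": 1}
node_type_offset = [0, 2, 5]  # n1 owns homogeneous ids [0, 2), n2 owns [2, 5).

def convert_to_homogeneous_nodes(nodes):
    homogeneous_nodes, offsets = [], [0]
    for ntype, ntype_id in node_type_to_id.items():
        ids = nodes.get(ntype, torch.empty(0, dtype=torch.int64))
        if len(ids) > 0:
            homogeneous_nodes.append(ids + node_type_offset[ntype_id])
        offsets.append(offsets[-1] + len(ids))
    return torch.cat(homogeneous_nodes), offsets

seeds, seed_offsets = convert_to_homogeneous_nodes(
    {"n1": torch.tensor([0, 1]), "n2": torch.tensor([2])}
)
print(seeds)         # tensor([0, 1, 4])
print(seed_offsets)  # [0, 2, 3]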
python/dgl/graphbolt/impl/neighbor_sampler.py  (+16 −30)

@@ -46,44 +46,37 @@ class FetchInsubgraphData(Mapper):
     def _fetch_per_layer_impl(self, minibatch, stream):
         with torch.cuda.stream(self.stream):
-            index = minibatch._seed_nodes
-            if isinstance(index, dict):
-                for idx in index.values():
+            seeds = minibatch._seed_nodes
+            is_hetero = isinstance(seeds, dict)
+            if is_hetero:
+                for idx in seeds.values():
                     idx.record_stream(torch.cuda.current_stream())
-                index = self.graph._convert_to_homogeneous_nodes(index)
+                (
+                    seeds,
+                    seed_offsets,
+                ) = self.graph._convert_to_homogeneous_nodes(seeds)
             else:
-                index.record_stream(torch.cuda.current_stream())
+                seeds.record_stream(torch.cuda.current_stream())
+                seed_offsets = None

             def record_stream(tensor):
                 if stream is not None and tensor.is_cuda:
                     tensor.record_stream(stream)
                 return tensor

-            if self.graph.node_type_offset is None:
-                # sorting not needed.
-                minibatch._subgraph_seed_nodes = None
-            else:
-                index, original_positions = index.sort()
-                if (original_positions.diff() == 1).all().item():
-                    # already sorted.
-                    minibatch._subgraph_seed_nodes = None
-                else:
-                    minibatch._subgraph_seed_nodes = record_stream(
-                        original_positions.sort()[1]
-                    )
             index_select_csc_with_indptr = partial(
                 torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
             )

             indptr, indices = index_select_csc_with_indptr(
-                self.graph.indices, index, None
+                self.graph.indices, seeds, None
             )
             record_stream(indptr)
             record_stream(indices)
             output_size = len(indices)
             if self.graph.type_per_edge is not None:
                 _, type_per_edge = index_select_csc_with_indptr(
-                    self.graph.type_per_edge, index, output_size
+                    self.graph.type_per_edge, seeds, output_size
                 )
                 record_stream(type_per_edge)
             else:

@@ -94,27 +87,22 @@ class FetchInsubgraphData(Mapper):
             )
             if probs_or_mask is not None:
                 _, probs_or_mask = index_select_csc_with_indptr(
-                    probs_or_mask, index, output_size
+                    probs_or_mask, seeds, output_size
                 )
                 record_stream(probs_or_mask)
             else:
                 probs_or_mask = None
-            if self.graph.node_type_offset is not None:
-                node_type_offset = torch.searchsorted(
-                    index, self.graph.node_type_offset
-                )
-            else:
-                node_type_offset = None
             subgraph = fused_csc_sampling_graph(
                 indptr,
                 indices,
-                node_type_offset=node_type_offset,
+                node_type_offset=self.graph.node_type_offset,
                 type_per_edge=type_per_edge,
                 node_type_to_id=self.graph.node_type_to_id,
                 edge_type_to_id=self.graph.edge_type_to_id,
             )
             if self.prob_name is not None and probs_or_mask is not None:
                 subgraph.edge_attributes = {self.prob_name: probs_or_mask}
+            subgraph._seed_offset_list = seed_offsets

             minibatch.sampled_subgraphs.insert(0, subgraph)

@@ -152,14 +140,12 @@ class SamplePerLayerFromFetchedSubgraph(MiniBatchTransformer):
             if hasattr(minibatch, key)
         }

         sampled_subgraph = getattr(subgraph, self.sampler_name)(
-            minibatch._subgraph_seed_nodes,
+            None,
             self.fanout,
             self.replace,
             self.prob_name,
             **kwargs,
         )
-        delattr(minibatch, "_subgraph_seed_nodes")
-        sampled_subgraph.original_column_node_ids = minibatch._seed_nodes
         minibatch.sampled_subgraphs[0] = sampled_subgraph

         return minibatch
tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py  (+6 −6)

@@ -15,10 +15,10 @@ def get_hetero_graph():
     # [2, 4, 2, 3, 0, 1, 1, 0, 0, 1]
     # [1, 1, 1, 1, 0, 0, 0, 0, 0] - > edge type.
     # num_nodes = 5, num_n1 = 2, num_n2 = 3
-    ntypes = {"n1": 0, "n2": 1}
-    etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
-    indptr = torch.LongTensor([0, 2, 4, 6, 8, 10])
-    indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 0, 1])
+    ntypes = {"n1": 0, "n2": 1, "n3": 2}
+    etypes = {"n2:e1:n3": 0, "n3:e2:n2": 1}
+    indptr = torch.LongTensor([0, 0, 2, 4, 6, 8, 10])
+    indices = torch.LongTensor([3, 5, 3, 4, 1, 2, 2, 1, 1, 2])
     type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0])
     edge_attributes = {
         "weight": torch.FloatTensor(

@@ -26,7 +26,7 @@ def get_hetero_graph():
         ),
         "mask": torch.BoolTensor([1, 0, 1, 0, 1, 1, 1, 0, 1, 1]),
     }
-    node_type_offset = torch.LongTensor([0, 2, 5])
+    node_type_offset = torch.LongTensor([0, 1, 3, 6])
     return gb.fused_csc_sampling_graph(
         indptr,
         indices,

@@ -51,7 +51,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
     itemset = gb.ItemSet(items, names=names)
     graph = get_hetero_graph().to(F.ctx())
     if hetero:
-        itemset = gb.ItemSetDict({"n2": itemset})
+        itemset = gb.ItemSetDict({"n3": itemset})
     else:
         graph.type_per_edge = None
     item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
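A quick sanity check, derived from the tensors in the updated fixture above, of how the new node_type_offset partitions the homogeneous node id space across the three node types; the searchsorted-based lookup mirrors the type resolution used elsewhere in the diff.

import torch

ntypes = {"n1": 0, "n2": 1, "n3": 2}
node_type_offset = torch.LongTensor([0, 1, 3, 6])

# Each node type owns a contiguous homogeneous id range.
for ntype, tid in ntypes.items():
    lo, hi = node_type_offset[tid].item(), node_type_offset[tid + 1].item()
    print(ntype, list(range(lo, hi)))
# n1 [0]
# n2 [1, 2]
# n3 [3, 4, 5]

# Map a homogeneous node id back to its node type id.
nodes = torch.tensor([0, 2, 5])
print(torch.searchsorted(node_type_offset, nodes, right=True) - 1)  # tensor([0, 1, 2])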