Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
382a2de7
Unverified
Commit
382a2de7
authored
Nov 08, 2023
by
LastWhisper
Committed by
GitHub
Nov 08, 2023
Browse files
[GraphBolt] Refactor SampledSubgraph and update the corresponding method. (#6533)
parent
a24a38bc
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
61 additions
and
58 deletions
+61
-58
docs/source/api/python/dgl.graphbolt.rst
docs/source/api/python/dgl.graphbolt.rst
+1
-1
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
+7
-7
graphbolt/include/graphbolt/fused_sampled_subgraph.h
graphbolt/include/graphbolt/fused_sampled_subgraph.h
+10
-9
graphbolt/src/fused_csc_sampling_graph.cc
graphbolt/src/fused_csc_sampling_graph.cc
+6
-5
graphbolt/src/python_binding.cc
graphbolt/src/python_binding.cc
+8
-7
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
+8
-8
python/dgl/graphbolt/impl/neighbor_sampler.py
python/dgl/graphbolt/impl/neighbor_sampler.py
+2
-2
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
+3
-3
python/dgl/graphbolt/minibatch.py
python/dgl/graphbolt/minibatch.py
+3
-3
python/dgl/graphbolt/sampled_subgraph.py
python/dgl/graphbolt/sampled_subgraph.py
+1
-1
tests/python/pytorch/graphbolt/impl/test_minibatch.py
tests/python/pytorch/graphbolt/impl/test_minibatch.py
+6
-6
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
...thon/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
+4
-4
tests/python/pytorch/graphbolt/test_feature_fetcher.py
tests/python/pytorch/graphbolt/test_feature_fetcher.py
+2
-2
No files found.
docs/source/api/python/dgl.graphbolt.rst
View file @
382a2de7
...
...
@@ -60,7 +60,7 @@ Standard Implementations
UniformNegativeSampler
NeighborSampler
LayerNeighborSampler
SampledSubgraphImpl
Fused
SampledSubgraphImpl
BasicFeatureStore
TorchBasedFeature
TorchBasedFeatureStore
...
...
graphbolt/include/graphbolt/fused_csc_sampling_graph.h
View file @
382a2de7
...
...
@@ -6,7 +6,7 @@
#ifndef GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#define GRAPHBOLT_CSC_SAMPLING_GRAPH_H_
#include <graphbolt/sampled_subgraph.h>
#include <graphbolt/
fused_
sampled_subgraph.h>
#include <graphbolt/shared_memory.h>
#include <torch/torch.h>
...
...
@@ -172,9 +172,9 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* @brief Return the subgraph induced on the inbound edges of the given nodes.
* @param nodes Type agnostic node IDs to form the subgraph.
*
* @return SampledSubgraph.
* @return
Fused
SampledSubgraph.
*/
c10
::
intrusive_ptr
<
SampledSubgraph
>
InSubgraph
(
c10
::
intrusive_ptr
<
Fused
SampledSubgraph
>
InSubgraph
(
const
torch
::
Tensor
&
nodes
)
const
;
/**
...
...
@@ -208,10 +208,10 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
* a 1D floating-point or boolean tensor, with the number of elements
* equalling the total number of edges.
*
* @return An intrusive pointer to a SampledSubgraph object containing
the
* sampled graph's information.
* @return An intrusive pointer to a
Fused
SampledSubgraph object containing
*
the
sampled graph's information.
*/
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighbors
(
c10
::
intrusive_ptr
<
Fused
SampledSubgraph
>
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
)
const
;
...
...
@@ -276,7 +276,7 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
private:
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
SampledSubgraph
>
SampleNeighborsImpl
(
c10
::
intrusive_ptr
<
Fused
SampledSubgraph
>
SampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
;
...
...
graphbolt/include/graphbolt/sampled_subgraph.h
→
graphbolt/include/graphbolt/
fused_
sampled_subgraph.h
View file @
382a2de7
/**
* Copyright (c) 2023 by Contributors
* @file graphbolt/sampled_subgraph.h
* @file graphbolt/
fused_
sampled_subgraph.h
* @brief Header file of sampled sub graph.
*/
#ifndef GRAPHBOLT_SAMPLED_SUBGRAPH_H_
#define GRAPHBOLT_SAMPLED_SUBGRAPH_H_
#ifndef GRAPHBOLT_
FUSED_
SAMPLED_SUBGRAPH_H_
#define GRAPHBOLT_
FUSED_
SAMPLED_SUBGRAPH_H_
#include <torch/custom_class.h>
#include <torch/torch.h>
...
...
@@ -24,7 +24,8 @@ namespace sampling {
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
*
* SampledSubgraph sampledSubgraph(indptr, indices, original_column_node_ids);
* FusedSampledSubgraph sampledSubgraph(indptr, indices,
* original_column_node_ids);
* ```
*
* The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the
...
...
@@ -37,10 +38,10 @@ namespace sampling {
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row.
*/
struct
SampledSubgraph
:
torch
::
CustomClassHolder
{
struct
Fused
SampledSubgraph
:
torch
::
CustomClassHolder
{
public:
/**
* @brief Constructor for the SampledSubgraph struct.
* @brief Constructor for the
Fused
SampledSubgraph struct.
*
* @param indptr CSC format index pointer array.
* @param indices CSC format index array.
...
...
@@ -51,7 +52,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge.
*/
SampledSubgraph
(
Fused
SampledSubgraph
(
torch
::
Tensor
indptr
,
torch
::
Tensor
indices
,
torch
::
Tensor
original_column_node_ids
,
torch
::
optional
<
torch
::
Tensor
>
original_row_node_ids
=
torch
::
nullopt
,
...
...
@@ -64,7 +65,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
original_edge_ids
(
original_edge_ids
),
type_per_edge
(
type_per_edge
)
{}
SampledSubgraph
()
=
default
;
Fused
SampledSubgraph
()
=
default
;
/**
* @brief CSC format index pointer array, where the implicit node ids are
...
...
@@ -118,4 +119,4 @@ struct SampledSubgraph : torch::CustomClassHolder {
}
// namespace sampling
}
// namespace graphbolt
#endif // GRAPHBOLT_SAMPLED_SUBGRAPH_H_
#endif // GRAPHBOLT_
FUSED_
SAMPLED_SUBGRAPH_H_
graphbolt/src/fused_csc_sampling_graph.cc
View file @
382a2de7
...
...
@@ -182,7 +182,7 @@ FusedCSCSamplingGraph::GetState() const {
return
state
;
}
c10
::
intrusive_ptr
<
SampledSubgraph
>
FusedCSCSamplingGraph
::
InSubgraph
(
c10
::
intrusive_ptr
<
Fused
SampledSubgraph
>
FusedCSCSamplingGraph
::
InSubgraph
(
const
torch
::
Tensor
&
nodes
)
const
{
using
namespace
torch
::
indexing
;
const
int32_t
kDefaultGrainSize
=
100
;
...
...
@@ -211,7 +211,7 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
torch
::
Tensor
compact_indptr
=
torch
::
zeros
({
nonzero_idx
.
size
(
0
)
+
1
},
indptr_
.
dtype
());
compact_indptr
.
index_put_
({
Slice
(
1
,
None
)},
indptr
.
index
({
nonzero_idx
}));
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
return
c10
::
make_intrusive
<
Fused
SampledSubgraph
>
(
compact_indptr
.
cumsum
(
0
),
torch
::
cat
(
indices_arr
),
nonzero_idx
-
1
,
torch
::
arange
(
0
,
NumNodes
()),
torch
::
cat
(
edge_ids_arr
),
type_per_edge_
...
...
@@ -305,7 +305,8 @@ auto GetPickFn(
}
template
<
typename
NumPickFn
,
typename
PickFn
>
c10
::
intrusive_ptr
<
SampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
c10
::
intrusive_ptr
<
FusedSampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighborsImpl
(
const
torch
::
Tensor
&
nodes
,
bool
return_eids
,
NumPickFn
num_pick_fn
,
PickFn
pick_fn
)
const
{
const
int64_t
num_nodes
=
nodes
.
size
(
0
);
...
...
@@ -417,12 +418,12 @@ c10::intrusive_ptr<SampledSubgraph> FusedCSCSamplingGraph::SampleNeighborsImpl(
torch
::
optional
<
torch
::
Tensor
>
subgraph_reverse_edge_ids
=
torch
::
nullopt
;
if
(
return_eids
)
subgraph_reverse_edge_ids
=
std
::
move
(
picked_eids
);
return
c10
::
make_intrusive
<
SampledSubgraph
>
(
return
c10
::
make_intrusive
<
Fused
SampledSubgraph
>
(
subgraph_indptr
,
subgraph_indices
,
nodes
,
torch
::
nullopt
,
subgraph_reverse_edge_ids
,
subgraph_type_per_edge
);
}
c10
::
intrusive_ptr
<
SampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
c10
::
intrusive_ptr
<
Fused
SampledSubgraph
>
FusedCSCSamplingGraph
::
SampleNeighbors
(
const
torch
::
Tensor
&
nodes
,
const
std
::
vector
<
int64_t
>&
fanouts
,
bool
replace
,
bool
layer
,
bool
return_eids
,
torch
::
optional
<
std
::
string
>
probs_name
)
const
{
...
...
graphbolt/src/python_binding.cc
View file @
382a2de7
...
...
@@ -15,17 +15,18 @@ namespace graphbolt {
namespace
sampling
{
TORCH_LIBRARY
(
graphbolt
,
m
)
{
m
.
class_
<
SampledSubgraph
>
(
"SampledSubgraph"
)
m
.
class_
<
Fused
SampledSubgraph
>
(
"
Fused
SampledSubgraph"
)
.
def
(
torch
::
init
<>
())
.
def_readwrite
(
"indptr"
,
&
SampledSubgraph
::
indptr
)
.
def_readwrite
(
"indices"
,
&
SampledSubgraph
::
indices
)
.
def_readwrite
(
"indptr"
,
&
Fused
SampledSubgraph
::
indptr
)
.
def_readwrite
(
"indices"
,
&
Fused
SampledSubgraph
::
indices
)
.
def_readwrite
(
"original_row_node_ids"
,
&
SampledSubgraph
::
original_row_node_ids
)
"original_row_node_ids"
,
&
Fused
SampledSubgraph
::
original_row_node_ids
)
.
def_readwrite
(
"original_column_node_ids"
,
&
SampledSubgraph
::
original_column_node_ids
)
.
def_readwrite
(
"original_edge_ids"
,
&
SampledSubgraph
::
original_edge_ids
)
.
def_readwrite
(
"type_per_edge"
,
&
SampledSubgraph
::
type_per_edge
);
&
FusedSampledSubgraph
::
original_column_node_ids
)
.
def_readwrite
(
"original_edge_ids"
,
&
FusedSampledSubgraph
::
original_edge_ids
)
.
def_readwrite
(
"type_per_edge"
,
&
FusedSampledSubgraph
::
type_per_edge
);
m
.
class_
<
FusedCSCSamplingGraph
>
(
"FusedCSCSamplingGraph"
)
.
def
(
"num_nodes"
,
&
FusedCSCSamplingGraph
::
NumNodes
)
.
def
(
"num_edges"
,
&
FusedCSCSamplingGraph
::
NumEdges
)
...
...
python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
View file @
382a2de7
...
...
@@ -15,7 +15,7 @@ from ...convert import to_homogeneous
from
...heterograph
import
DGLGraph
from
..base
import
etype_str_to_tuple
,
etype_tuple_to_str
,
ORIGINAL_EDGE_ID
from
..sampling_graph
import
SamplingGraph
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
Fused
SampledSubgraphImpl
__all__
=
[
...
...
@@ -305,7 +305,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
assert
len
(
torch
.
unique
(
nodes
))
==
len
(
nodes
),
"Nodes cannot have duplicate values."
# TODO: change the result to 'SampledSubgraphImpl'.
# TODO: change the result to '
Fused
SampledSubgraphImpl'.
return
self
.
_c_csc_graph
.
in_subgraph
(
nodes
)
def
_convert_to_sampled_subgraph
(
...
...
@@ -313,7 +313,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
C_sampled_subgraph
:
torch
.
ScriptObject
,
):
"""An internal function used to convert a fused homogeneous sampled
subgraph to general struct 'SampledSubgraphImpl'."""
subgraph to general struct '
Fused
SampledSubgraphImpl'."""
column_num
=
(
C_sampled_subgraph
.
indptr
[
1
:]
-
C_sampled_subgraph
.
indptr
[:
-
1
]
)
...
...
@@ -353,7 +353,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids
[
etype
]
=
original_edge_ids
[
mask
]
if
has_original_eids
:
original_edge_ids
=
original_hetero_edge_ids
return
SampledSubgraphImpl
(
return
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
original_edge_ids
=
original_edge_ids
)
...
...
@@ -370,7 +370,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
probs_name
:
Optional
[
str
]
=
None
,
)
->
SampledSubgraphImpl
:
)
->
Fused
SampledSubgraphImpl
:
"""Sample neighboring edges of the given nodes and return the induced
subgraph.
...
...
@@ -411,7 +411,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
equalling the total number of edges.
Returns
-------
SampledSubgraphImpl
Fused
SampledSubgraphImpl
The sampled subgraph.
Examples
...
...
@@ -548,7 +548,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
fanouts
:
torch
.
Tensor
,
replace
:
bool
=
False
,
probs_name
:
Optional
[
str
]
=
None
,
)
->
SampledSubgraphImpl
:
)
->
Fused
SampledSubgraphImpl
:
"""Sample neighboring edges of the given nodes and return the induced
subgraph via layer-neighbor sampling from the NeurIPS 2023 paper
`Layer-Neighbor Sampling -- Defusing Neighborhood Explosion in GNNs
...
...
@@ -591,7 +591,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
equalling the total number of edges.
Returns
-------
SampledSubgraphImpl
Fused
SampledSubgraphImpl
The sampled subgraph.
Examples
...
...
python/dgl/graphbolt/impl/neighbor_sampler.py
View file @
382a2de7
...
...
@@ -5,7 +5,7 @@ from torch.utils.data import functional_datapipe
from
..subgraph_sampler
import
SubgraphSampler
from
..utils
import
unique_and_compact_node_pairs
from
.sampled_subgraph_impl
import
SampledSubgraphImpl
from
.sampled_subgraph_impl
import
Fused
SampledSubgraphImpl
__all__
=
[
"NeighborSampler"
,
"LayerNeighborSampler"
]
...
...
@@ -124,7 +124,7 @@ class NeighborSampler(SubgraphSampler):
)
=
unique_and_compact_node_pairs
(
subgraph
.
node_pairs
,
seeds
)
else
:
raise
RuntimeError
(
"Not implemented yet."
)
subgraph
=
SampledSubgraphImpl
(
subgraph
=
Fused
SampledSubgraphImpl
(
node_pairs
=
compacted_node_pairs
,
original_column_node_ids
=
seeds
,
original_row_node_ids
=
original_row_node_ids
,
...
...
python/dgl/graphbolt/impl/sampled_subgraph_impl.py
View file @
382a2de7
...
...
@@ -8,11 +8,11 @@ import torch
from
..base
import
etype_str_to_tuple
from
..sampled_subgraph
import
SampledSubgraph
__all__
=
[
"SampledSubgraphImpl"
]
__all__
=
[
"
Fused
SampledSubgraphImpl"
]
@
dataclass
class
SampledSubgraphImpl
(
SampledSubgraph
):
class
Fused
SampledSubgraphImpl
(
SampledSubgraph
):
r
"""Sampled subgraph of FusedCSCSamplingGraph.
Examples
...
...
@@ -22,7 +22,7 @@ class SampledSubgraphImpl(SampledSubgraph):
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.
Fused
SampledSubgraphImpl(
... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
...
...
python/dgl/graphbolt/minibatch.py
View file @
382a2de7
...
...
@@ -453,16 +453,16 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
if
isinstance
(
val
,
list
):
if
len
(
val
)
==
0
:
val
=
"[]"
# Special handling of SampledSubgraphImpl data. Each element of
# Special handling of
Fused
SampledSubgraphImpl data. Each element of
# the data occupies one row and is further structured.
elif
isinstance
(
val
[
0
],
dgl
.
graphbolt
.
impl
.
sampled_subgraph_impl
.
SampledSubgraphImpl
,
dgl
.
graphbolt
.
impl
.
sampled_subgraph_impl
.
Fused
SampledSubgraphImpl
,
):
sampledsubgraph_strs
=
[]
for
sampledsubgraph
in
val
:
ss_attributes
=
_get_attributes
(
sampledsubgraph
)
sampledsubgraph_str
=
"SampledSubgraphImpl("
sampledsubgraph_str
=
"
Fused
SampledSubgraphImpl("
for
ss_name
in
ss_attributes
:
ss_val
=
str
(
getattr
(
sampledsubgraph
,
ss_name
))
sampledsubgraph_str
=
(
...
...
python/dgl/graphbolt/sampled_subgraph.py
View file @
382a2de7
...
...
@@ -123,7 +123,7 @@ class SampledSubgraph:
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
>>> subgraph = gb.
Fused
SampledSubgraphImpl(
... node_pairs=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
...
...
tests/python/pytorch/graphbolt/impl/test_minibatch.py
View file @
382a2de7
...
...
@@ -39,7 +39,7 @@ def create_homo_minibatch():
subgraphs
=
[]
for
i
in
range
(
2
):
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
original_column_node_ids
=
original_column_node_ids
[
i
],
original_row_node_ids
=
original_row_node_ids
[
i
],
...
...
@@ -93,7 +93,7 @@ def create_hetero_minibatch():
subgraphs
=
[]
for
i
in
range
(
2
):
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
original_column_node_ids
=
original_column_node_ids
[
i
],
original_row_node_ids
=
original_row_node_ids
[
i
],
...
...
@@ -142,7 +142,7 @@ def test_minibatch_representation():
subgraphs
=
[]
for
i
in
range
(
2
):
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
original_column_node_ids
=
original_column_node_ids
[
i
],
original_row_node_ids
=
original_row_node_ids
[
i
],
...
...
@@ -191,11 +191,11 @@ def test_minibatch_representation():
)
expect_result
=
str
(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
sampled_subgraphs=[
Fused
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
original_column_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
Fused
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
original_column_node_ids=tensor([10, 11]),
original_edge_ids=tensor([10, 15, 17]),
original_row_node_ids=tensor([10, 11, 12]),)],
...
...
@@ -260,7 +260,7 @@ def test_dgl_minibatch_representation():
subgraphs
=
[]
for
i
in
range
(
2
):
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
[
i
],
original_column_node_ids
=
original_column_node_ids
[
i
],
original_row_node_ids
=
original_row_node_ids
[
i
],
...
...
tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
View file @
382a2de7
...
...
@@ -4,7 +4,7 @@ import backend as F
import
pytest
import
torch
from
dgl.graphbolt.impl.sampled_subgraph_impl
import
SampledSubgraphImpl
from
dgl.graphbolt.impl.sampled_subgraph_impl
import
Fused
SampledSubgraphImpl
def
_assert_container_equal
(
lhs
,
rhs
):
...
...
@@ -42,7 +42,7 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
original_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
4
])
original_edge_ids
=
torch
.
Tensor
([
5
,
9
,
10
])
subgraph
=
SampledSubgraphImpl
(
subgraph
=
Fused
SampledSubgraphImpl
(
node_pairs
,
original_column_node_ids
,
original_row_node_ids
,
...
...
@@ -95,7 +95,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
original_column_node_ids
=
None
dst_to_exclude
=
torch
.
tensor
([
0
,
2
])
original_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
subgraph
=
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
original_column_node_ids
=
original_column_node_ids
,
original_row_node_ids
=
original_row_node_ids
,
...
...
@@ -158,7 +158,7 @@ def test_sampled_subgraph_to_device():
}
dst_to_exclude
=
torch
.
tensor
([
10
,
12
])
original_edge_ids
=
{
"A:relation:B"
:
torch
.
tensor
([
19
,
20
,
21
])}
subgraph
=
SampledSubgraphImpl
(
subgraph
=
Fused
SampledSubgraphImpl
(
node_pairs
=
node_pairs
,
original_column_node_ids
=
original_column_node_ids
,
original_row_node_ids
=
original_row_node_ids
,
...
...
tests/python/pytorch/graphbolt/test_feature_fetcher.py
View file @
382a2de7
...
...
@@ -77,7 +77,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs
=
[]
for
_
in
range
(
3
):
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
Fused
SampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
original_edge_ids
=
torch
.
randint
(
0
,
graph
.
total_num_edges
,
(
10
,)
...
...
@@ -168,7 +168,7 @@ def test_FeatureFetcher_with_edges_hetero():
}
for
_
in
range
(
3
):
subgraphs
.
append
(
gb
.
SampledSubgraphImpl
(
gb
.
Fused
SampledSubgraphImpl
(
node_pairs
=
(
torch
.
tensor
([]),
torch
.
tensor
([])),
original_edge_ids
=
original_edge_ids
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment