Unverified Commit ed3840fc authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Replace reverse with original (#6371)

parent 7e49ccef
...@@ -206,8 +206,8 @@ def to_dgl_blocks(sampled_subgraphs: gb.SampledSubgraphImpl): ...@@ -206,8 +206,8 @@ def to_dgl_blocks(sampled_subgraphs: gb.SampledSubgraphImpl):
blocks = [ blocks = [
dgl.create_block( dgl.create_block(
sampled_subgraph.node_pairs, sampled_subgraph.node_pairs,
num_src_nodes=sampled_subgraph.reverse_row_node_ids.shape[0], num_src_nodes=sampled_subgraph.original_row_node_ids.shape[0],
num_dst_nodes=sampled_subgraph.reverse_column_node_ids.shape[0], num_dst_nodes=sampled_subgraph.original_column_node_ids.shape[0],
) )
for sampled_subgraph in sampled_subgraphs for sampled_subgraph in sampled_subgraphs
] ]
......
...@@ -22,17 +22,17 @@ namespace sampling { ...@@ -22,17 +22,17 @@ namespace sampling {
* ``` * ```
* auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64}); * auto indptr = torch::tensor({0, 2, 3, 4}, {torch::kInt64});
* auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64}); * auto indices = torch::tensor({55, 101, 3, 3}, {torch::kInt64});
* auto reverse_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64}); * auto original_column_node_ids = torch::tensor({3, 3, 101}, {torch::kInt64});
* *
* SampledSubgraph sampledSubgraph(indptr, indices, reverse_column_node_ids); * SampledSubgraph sampledSubgraph(indptr, indices, original_column_node_ids);
* ``` * ```
* *
* The `reverse_column_node_ids` indicates that nodes `[3, 3, 101]` in the * The `original_column_node_ids` indicates that nodes `[3, 3, 101]` in the
* original graph are mapped to `[0, 1, 2]` in this subgraph, and because * original graph are mapped to `[0, 1, 2]` in this subgraph, and because
* `reverse_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just * `original_row_node_ids` is `Null`, `{55, 101, 3, 3}` in `indices` is just
* the original node ids without compaction. * the original node ids without compaction.
* *
* If `reverse_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`, * If `original_row_node_ids = torch::tensor({55, 101, 3}, {torch::kInt64})`,
* it would indicate a different mapping for the row nodes. Note this is * it would indicate a different mapping for the row nodes. Note this is
* inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in * inconsistent with column, which is legal, as `3` is mapped to `0` and `1` in
* the column while `2` in the row. * the column while `2` in the row.
...@@ -44,24 +44,24 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -44,24 +44,24 @@ struct SampledSubgraph : torch::CustomClassHolder {
* *
* @param indptr CSC format index pointer array. * @param indptr CSC format index pointer array.
* @param indices CSC format index array. * @param indices CSC format index array.
* @param reverse_column_node_ids Row's reverse node ids in the original * @param original_column_node_ids Row's reverse node ids in the original
* graph. * graph.
* @param reverse_row_node_ids Column's reverse node ids in the original * @param original_row_node_ids Column's reverse node ids in the original
* graph. * graph.
* @param reverse_edge_ids Reverse edge ids in the original graph. * @param original_edge_ids Reverse edge ids in the original graph.
* @param type_per_edge Type id of each edge. * @param type_per_edge Type id of each edge.
*/ */
SampledSubgraph( SampledSubgraph(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor indptr, torch::Tensor indices,
torch::Tensor reverse_column_node_ids, torch::Tensor original_column_node_ids,
torch::optional<torch::Tensor> reverse_row_node_ids = torch::nullopt, torch::optional<torch::Tensor> original_row_node_ids = torch::nullopt,
torch::optional<torch::Tensor> reverse_edge_ids = torch::nullopt, torch::optional<torch::Tensor> original_edge_ids = torch::nullopt,
torch::optional<torch::Tensor> type_per_edge = torch::nullopt) torch::optional<torch::Tensor> type_per_edge = torch::nullopt)
: indptr(indptr), : indptr(indptr),
indices(indices), indices(indices),
reverse_column_node_ids(reverse_column_node_ids), original_column_node_ids(original_column_node_ids),
reverse_row_node_ids(reverse_row_node_ids), original_row_node_ids(original_row_node_ids),
reverse_edge_ids(reverse_edge_ids), original_edge_ids(original_edge_ids),
type_per_edge(type_per_edge) {} type_per_edge(type_per_edge) {}
SampledSubgraph() = default; SampledSubgraph() = default;
...@@ -69,14 +69,14 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -69,14 +69,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
/** /**
* @brief CSC format index pointer array, where the implicit node ids are * @brief CSC format index pointer array, where the implicit node ids are
* already compacted. And the original ids are stored in the * already compacted. And the original ids are stored in the
* `reverse_column_node_ids` field. * `original_column_node_ids` field.
*/ */
torch::Tensor indptr; torch::Tensor indptr;
/** /**
* @brief CSC format index array, where the node ids can be compacted ids or * @brief CSC format index array, where the node ids can be compacted ids or
* original ids. If compacted, the original ids are stored in the * original ids. If compacted, the original ids are stored in the
* `reverse_row_node_ids` field. * `original_row_node_ids` field.
*/ */
torch::Tensor indices; torch::Tensor indices;
...@@ -88,7 +88,7 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -88,7 +88,7 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is required and the mapping relations can be inconsistent with * @note This is required and the mapping relations can be inconsistent with
* column's. * column's.
*/ */
torch::Tensor reverse_column_node_ids; torch::Tensor original_column_node_ids;
/** /**
* @brief Row's reverse node ids in the original graph. A graph structure * @brief Row's reverse node ids in the original graph. A graph structure
...@@ -98,14 +98,14 @@ struct SampledSubgraph : torch::CustomClassHolder { ...@@ -98,14 +98,14 @@ struct SampledSubgraph : torch::CustomClassHolder {
* @note This is optional and the mapping relations can be inconsistent with * @note This is optional and the mapping relations can be inconsistent with
* row's. * row's.
*/ */
torch::optional<torch::Tensor> reverse_row_node_ids; torch::optional<torch::Tensor> original_row_node_ids;
/** /**
* @brief Reverse edge ids in the original graph, the edge with id * @brief Reverse edge ids in the original graph, the edge with id
* `reverse_edge_ids[i]` in the original graph is mapped to `i` in this * `original_edge_ids[i]` in the original graph is mapped to `i` in this
* subgraph. This is useful when edge features are needed. * subgraph. This is useful when edge features are needed.
*/ */
torch::optional<torch::Tensor> reverse_edge_ids; torch::optional<torch::Tensor> original_edge_ids;
/** /**
* @brief Type id of each edge, where type id is the corresponding index of * @brief Type id of each edge, where type id is the corresponding index of
......
...@@ -17,10 +17,11 @@ TORCH_LIBRARY(graphbolt, m) { ...@@ -17,10 +17,11 @@ TORCH_LIBRARY(graphbolt, m) {
.def_readwrite("indptr", &SampledSubgraph::indptr) .def_readwrite("indptr", &SampledSubgraph::indptr)
.def_readwrite("indices", &SampledSubgraph::indices) .def_readwrite("indices", &SampledSubgraph::indices)
.def_readwrite( .def_readwrite(
"reverse_row_node_ids", &SampledSubgraph::reverse_row_node_ids) "original_row_node_ids", &SampledSubgraph::original_row_node_ids)
.def_readwrite( .def_readwrite(
"reverse_column_node_ids", &SampledSubgraph::reverse_column_node_ids) "original_column_node_ids",
.def_readwrite("reverse_edge_ids", &SampledSubgraph::reverse_edge_ids) &SampledSubgraph::original_column_node_ids)
.def_readwrite("original_edge_ids", &SampledSubgraph::original_edge_ids)
.def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge); .def_readwrite("type_per_edge", &SampledSubgraph::type_per_edge);
m.class_<CSCSamplingGraph>("CSCSamplingGraph") m.class_<CSCSamplingGraph>("CSCSamplingGraph")
.def("num_nodes", &CSCSamplingGraph::NumNodes) .def("num_nodes", &CSCSamplingGraph::NumNodes)
......
...@@ -95,14 +95,14 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -95,14 +95,14 @@ class FeatureFetcher(MiniBatchTransformer):
# Read Edge features. # Read Edge features.
if self.edge_feature_keys and data.sampled_subgraphs: if self.edge_feature_keys and data.sampled_subgraphs:
for i, subgraph in enumerate(data.sampled_subgraphs): for i, subgraph in enumerate(data.sampled_subgraphs):
if subgraph.reverse_edge_ids is None: if subgraph.original_edge_ids is None:
continue continue
if is_heterogeneous: if is_heterogeneous:
for ( for (
type_name, type_name,
feature_names, feature_names,
) in self.edge_feature_keys.items(): ) in self.edge_feature_keys.items():
edges = subgraph.reverse_edge_ids.get(type_name, None) edges = subgraph.original_edge_ids.get(type_name, None)
if edges is None: if edges is None:
continue continue
for feature_name in feature_names: for feature_name in feature_names:
...@@ -119,6 +119,6 @@ class FeatureFetcher(MiniBatchTransformer): ...@@ -119,6 +119,6 @@ class FeatureFetcher(MiniBatchTransformer):
"edge", "edge",
None, None,
feature_name, feature_name,
subgraph.reverse_edge_ids, subgraph.original_edge_ids,
) )
return data return data
...@@ -225,7 +225,7 @@ class CSCSamplingGraph: ...@@ -225,7 +225,7 @@ class CSCSamplingGraph:
column_num = ( column_num = (
C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1] C_sampled_subgraph.indptr[1:] - C_sampled_subgraph.indptr[:-1]
) )
column = C_sampled_subgraph.reverse_column_node_ids.repeat_interleave( column = C_sampled_subgraph.original_column_node_ids.repeat_interleave(
column_num column_num
) )
row = C_sampled_subgraph.indices row = C_sampled_subgraph.indices
......
...@@ -106,14 +106,14 @@ class NeighborSampler(SubgraphSampler): ...@@ -106,14 +106,14 @@ class NeighborSampler(SubgraphSampler):
self.replace, self.replace,
self.prob_name, self.prob_name,
) )
reverse_column_node_ids = seeds original_column_node_ids = seeds
seeds, compacted_node_pairs = unique_and_compact_node_pairs( seeds, compacted_node_pairs = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds subgraph.node_pairs, seeds
) )
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=compacted_node_pairs, node_pairs=compacted_node_pairs,
reverse_column_node_ids=reverse_column_node_ids, original_column_node_ids=original_column_node_ids,
reverse_row_node_ids=seeds, original_row_node_ids=seeds,
) )
subgraphs.insert(0, subgraph) subgraphs.insert(0, subgraph)
return seeds, subgraphs return seeds, subgraphs
......
...@@ -17,31 +17,33 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -17,31 +17,33 @@ class SampledSubgraphImpl(SampledSubgraph):
-------- --------
>>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]), >>> node_pairs = {"A:relation:B"): (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))} ... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... reverse_row_node_ids=reverse_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... reverse_edge_ids=reverse_edge_ids ... original_edge_ids=original_edge_ids
... ) ... )
>>> print(subgraph.node_pairs) >>> print(subgraph.node_pairs)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))} {"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.reverse_column_node_ids) >>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])} {'B': tensor([10, 11, 12])}
>>> print(subgraph.reverse_row_node_ids) >>> print(subgraph.original_row_node_ids)
{'A': tensor([13, 14, 15])} {'A': tensor([13, 14, 15])}
>>> print(subgraph.reverse_edge_ids) >>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])} {"A:relation:B": tensor([19, 20, 21])}
""" """
node_pairs: Union[ node_pairs: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
] = None ] = None
reverse_column_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None original_column_node_ids: Union[
reverse_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None Dict[str, torch.Tensor], torch.Tensor
reverse_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None ] = None
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
if isinstance(self.node_pairs, dict): if isinstance(self.node_pairs, dict):
......
...@@ -140,14 +140,14 @@ class MiniBatch: ...@@ -140,14 +140,14 @@ class MiniBatch:
blocks = [] blocks = []
for subgraph in self.sampled_subgraphs: for subgraph in self.sampled_subgraphs:
reverse_row_node_ids = subgraph.reverse_row_node_ids original_row_node_ids = subgraph.original_row_node_ids
assert ( assert (
reverse_row_node_ids is not None original_row_node_ids is not None
), "Missing `reverse_row_node_ids` in sampled subgraph." ), "Missing `original_row_node_ids` in sampled subgraph."
reverse_column_node_ids = subgraph.reverse_column_node_ids original_column_node_ids = subgraph.original_column_node_ids
assert ( assert (
reverse_column_node_ids is not None original_column_node_ids is not None
), "Missing `reverse_column_node_ids` in sampled subgraph." ), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous: if is_heterogeneous:
node_pairs = { node_pairs = {
etype_str_to_tuple(etype): v etype_str_to_tuple(etype): v
...@@ -155,16 +155,16 @@ class MiniBatch: ...@@ -155,16 +155,16 @@ class MiniBatch:
} }
num_src_nodes = { num_src_nodes = {
ntype: nodes.size(0) ntype: nodes.size(0)
for ntype, nodes in reverse_row_node_ids.items() for ntype, nodes in original_row_node_ids.items()
} }
num_dst_nodes = { num_dst_nodes = {
ntype: nodes.size(0) ntype: nodes.size(0)
for ntype, nodes in reverse_column_node_ids.items() for ntype, nodes in original_column_node_ids.items()
} }
else: else:
node_pairs = subgraph.node_pairs node_pairs = subgraph.node_pairs
num_src_nodes = reverse_row_node_ids.size(0) num_src_nodes = original_row_node_ids.size(0)
num_dst_nodes = reverse_column_node_ids.size(0) num_dst_nodes = original_column_node_ids.size(0)
blocks.append( blocks.append(
dgl.create_block( dgl.create_block(
node_pairs, node_pairs,
...@@ -194,15 +194,15 @@ class MiniBatch: ...@@ -194,15 +194,15 @@ class MiniBatch:
# Assign reverse node ids to the outermost layer's source nodes. # Assign reverse node ids to the outermost layer's source nodes.
for node_type, reverse_ids in self.sampled_subgraphs[ for node_type, reverse_ids in self.sampled_subgraphs[
0 0
].reverse_row_node_ids.items(): ].original_row_node_ids.items():
blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids blocks[0].srcnodes[node_type].data[dgl.NID] = reverse_ids
# Assign reverse edges ids. # Assign reverse edges ids.
for block, subgraph in zip(blocks, self.sampled_subgraphs): for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids: if subgraph.original_edge_ids:
for ( for (
edge_type, edge_type,
reverse_ids, reverse_ids,
) in subgraph.reverse_edge_ids.items(): ) in subgraph.original_edge_ids.items():
block.edges[etype_str_to_tuple(edge_type)].data[ block.edges[etype_str_to_tuple(edge_type)].data[
dgl.EID dgl.EID
] = reverse_ids ] = reverse_ids
...@@ -218,11 +218,11 @@ class MiniBatch: ...@@ -218,11 +218,11 @@ class MiniBatch:
block.edata[feature_name] = feature block.edata[feature_name] = feature
blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[ blocks[0].srcdata[dgl.NID] = self.sampled_subgraphs[
0 0
].reverse_row_node_ids ].original_row_node_ids
# Assign reverse edges ids. # Assign reverse edges ids.
for block, subgraph in zip(blocks, self.sampled_subgraphs): for block, subgraph in zip(blocks, self.sampled_subgraphs):
if subgraph.reverse_edge_ids is not None: if subgraph.original_edge_ids is not None:
block.edata[dgl.EID] = subgraph.reverse_edge_ids block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks return blocks
......
...@@ -29,16 +29,16 @@ class SampledSubgraph: ...@@ -29,16 +29,16 @@ class SampledSubgraph:
raise NotImplementedError raise NotImplementedError
@property @property
def reverse_column_node_ids( def original_column_node_ids(
self, self,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse column node ids the original graph. """Returns corresponding reverse column node ids the original graph.
Column's reverse node ids in the original graph. A graph structure Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is can be treated as a coordinated row and column pair, and this is
the mapped ids of the column. the mapped ids of the column.
- If `reverse_column_node_ids` is a tensor: It represents the - If `original_column_node_ids` is a tensor: It represents the
original node ids. original node ids.
- If `reverse_column_node_ids` is a dictionary: The keys should be - If `original_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original node type and the values should be corresponding original
heterogeneous node ids. heterogeneous node ids.
If present, it means column IDs are compacted, and `node_pairs` If present, it means column IDs are compacted, and `node_pairs`
...@@ -47,16 +47,16 @@ class SampledSubgraph: ...@@ -47,16 +47,16 @@ class SampledSubgraph:
return None return None
@property @property
def reverse_row_node_ids( def original_row_node_ids(
self, self,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse row node ids the original graph. """Returns corresponding reverse row node ids the original graph.
Row's reverse node ids in the original graph. A graph structure Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is can be treated as a coordinated row and column pair, and this is
the mapped ids of the row. the mapped ids of the row.
- If `reverse_row_node_ids` is a tensor: It represents the - If `original_row_node_ids` is a tensor: It represents the
original node ids. original node ids.
- If `reverse_row_node_ids` is a dictionary: The keys should be - If `original_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original node type and the values should be corresponding original
heterogeneous node ids. heterogeneous node ids.
If present, it means row IDs are compacted, and `node_pairs` If present, it means row IDs are compacted, and `node_pairs`
...@@ -64,13 +64,13 @@ class SampledSubgraph: ...@@ -64,13 +64,13 @@ class SampledSubgraph:
return None return None
@property @property
def reverse_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: def original_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse edge ids the original graph. """Returns corresponding reverse edge ids the original graph.
Reverse edge ids in the original graph. This is useful when edge Reverse edge ids in the original graph. This is useful when edge
features are needed. features are needed.
- If `reverse_edge_ids` is a tensor: It represents the - If `original_edge_ids` is a tensor: It represents the
original edge ids. original edge ids.
- If `reverse_edge_ids` is a dictionary: The keys should be - If `original_edge_ids` is a dictionary: The keys should be
edge type and the values should be corresponding original edge type and the values should be corresponding original
heterogeneous edge ids. heterogeneous edge ids.
""" """
...@@ -110,24 +110,24 @@ class SampledSubgraph: ...@@ -110,24 +110,24 @@ class SampledSubgraph:
-------- --------
>>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]), >>> node_pairs = {"A:relation:B": (torch.tensor([0, 1, 2]),
... torch.tensor([0, 1, 2]))} ... torch.tensor([0, 1, 2]))}
>>> reverse_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> reverse_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... node_pairs=node_pairs,
... reverse_column_node_ids=reverse_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... reverse_row_node_ids=reverse_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... reverse_edge_ids=reverse_edge_ids ... original_edge_ids=original_edge_ids
... ) ... )
>>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12])) >>> edges_to_exclude = (torch.tensor([14, 15]), torch.tensor([11, 12]))
>>> result = subgraph.exclude_edges(edges_to_exclude) >>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs) >>> print(result.node_pairs)
{"A:relation:B": (tensor([0]), tensor([0]))} {"A:relation:B": (tensor([0]), tensor([0]))}
>>> print(result.reverse_column_node_ids) >>> print(result.original_column_node_ids)
{'B': tensor([10, 11, 12])} {'B': tensor([10, 11, 12])}
>>> print(result.reverse_row_node_ids) >>> print(result.original_row_node_ids)
{'A': tensor([13, 14, 15])} {'A': tensor([13, 14, 15])}
>>> print(result.reverse_edge_ids) >>> print(result.original_edge_ids)
{"A:relation:B": tensor([19])} {"A:relation:B": tensor([19])}
""" """
assert isinstance(self.node_pairs, tuple) == isinstance(edges, tuple), ( assert isinstance(self.node_pairs, tuple) == isinstance(edges, tuple), (
...@@ -144,8 +144,8 @@ class SampledSubgraph: ...@@ -144,8 +144,8 @@ class SampledSubgraph:
if isinstance(self.node_pairs, tuple): if isinstance(self.node_pairs, tuple):
reverse_edges = _to_reverse_ids( reverse_edges = _to_reverse_ids(
self.node_pairs, self.node_pairs,
self.reverse_row_node_ids, self.original_row_node_ids,
self.reverse_column_node_ids, self.original_column_node_ids,
) )
index = _exclude_homo_edges(reverse_edges, edges) index = _exclude_homo_edges(reverse_edges, edges)
return calling_class(*_slice_subgraph(self, index)) return calling_class(*_slice_subgraph(self, index))
...@@ -153,20 +153,20 @@ class SampledSubgraph: ...@@ -153,20 +153,20 @@ class SampledSubgraph:
index = {} index = {}
for etype, pair in self.node_pairs.items(): for etype, pair in self.node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
reverse_row_node_ids = ( original_row_node_ids = (
None None
if self.reverse_row_node_ids is None if self.original_row_node_ids is None
else self.reverse_row_node_ids.get(src_type) else self.original_row_node_ids.get(src_type)
) )
reverse_column_node_ids = ( original_column_node_ids = (
None None
if self.reverse_column_node_ids is None if self.original_column_node_ids is None
else self.reverse_column_node_ids.get(dst_type) else self.original_column_node_ids.get(dst_type)
) )
reverse_edges = _to_reverse_ids( reverse_edges = _to_reverse_ids(
pair, pair,
reverse_row_node_ids, original_row_node_ids,
reverse_column_node_ids, original_column_node_ids,
) )
index[etype] = _exclude_homo_edges( index[etype] = _exclude_homo_edges(
reverse_edges, edges.get(etype) reverse_edges, edges.get(etype)
...@@ -174,12 +174,12 @@ class SampledSubgraph: ...@@ -174,12 +174,12 @@ class SampledSubgraph:
return calling_class(*_slice_subgraph(self, index)) return calling_class(*_slice_subgraph(self, index))
def _to_reverse_ids(node_pair, reverse_row_node_ids, reverse_column_node_ids): def _to_reverse_ids(node_pair, original_row_node_ids, original_column_node_ids):
u, v = node_pair u, v = node_pair
if reverse_row_node_ids is not None: if original_row_node_ids is not None:
u = reverse_row_node_ids[u] u = original_row_node_ids[u]
if reverse_column_node_ids is not None: if original_column_node_ids is not None:
v = reverse_column_node_ids[v] v = original_column_node_ids[v]
return (u, v) return (u, v)
...@@ -224,7 +224,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): ...@@ -224,7 +224,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
return ( return (
_index_select(subgraph.node_pairs, index), _index_select(subgraph.node_pairs, index),
subgraph.reverse_column_node_ids, subgraph.original_column_node_ids,
subgraph.reverse_row_node_ids, subgraph.original_row_node_ids,
_index_select(subgraph.reverse_edge_ids, index), _index_select(subgraph.original_edge_ids, index),
) )
...@@ -397,12 +397,12 @@ def test_in_subgraph_homogeneous(): ...@@ -397,12 +397,12 @@ def test_in_subgraph_homogeneous():
assert torch.equal( assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4]) in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
) )
assert torch.equal(in_subgraph.reverse_column_node_ids, nodes) assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal( assert torch.equal(
in_subgraph.reverse_row_node_ids, torch.arange(0, num_nodes) in_subgraph.original_row_node_ids, torch.arange(0, num_nodes)
) )
assert torch.equal( assert torch.equal(
in_subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
) )
assert in_subgraph.type_per_edge is None assert in_subgraph.type_per_edge is None
...@@ -463,12 +463,12 @@ def test_in_subgraph_heterogeneous(): ...@@ -463,12 +463,12 @@ def test_in_subgraph_heterogeneous():
assert torch.equal( assert torch.equal(
in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4]) in_subgraph.indices, torch.LongTensor([2, 3, 1, 2, 0, 3, 4])
) )
assert torch.equal(in_subgraph.reverse_column_node_ids, nodes) assert torch.equal(in_subgraph.original_column_node_ids, nodes)
assert torch.equal( assert torch.equal(
in_subgraph.reverse_row_node_ids, torch.arange(0, num_nodes) in_subgraph.original_row_node_ids, torch.arange(0, num_nodes)
) )
assert torch.equal( assert torch.equal(
in_subgraph.reverse_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11]) in_subgraph.original_edge_ids, torch.LongTensor([3, 4, 7, 8, 9, 10, 11])
) )
assert torch.equal( assert torch.equal(
in_subgraph.type_per_edge, torch.LongTensor([2, 2, 1, 3, 1, 3, 3]) in_subgraph.type_per_edge, torch.LongTensor([2, 2, 1, 3, 1, 3, 3])
...@@ -505,9 +505,9 @@ def test_sample_neighbors_homo(): ...@@ -505,9 +505,9 @@ def test_sample_neighbors_homo():
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.node_pairs[0].size(0)
assert sampled_num == 6 assert sampled_num == 6
assert subgraph.reverse_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.reverse_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.reverse_edge_ids is None assert subgraph.original_edge_ids is None
@unittest.skipIf( @unittest.skipIf(
...@@ -568,9 +568,9 @@ def test_sample_neighbors_hetero(labor): ...@@ -568,9 +568,9 @@ def test_sample_neighbors_hetero(labor):
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0]) assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1]) assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert subgraph.reverse_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.reverse_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.reverse_edge_ids is None assert subgraph.original_edge_ids is None
# Sample on single node type. # Sample on single node type.
nodes = {"n1": torch.LongTensor([0])} nodes = {"n1": torch.LongTensor([0])}
...@@ -593,9 +593,9 @@ def test_sample_neighbors_hetero(labor): ...@@ -593,9 +593,9 @@ def test_sample_neighbors_hetero(labor):
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0]) assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1]) assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert subgraph.reverse_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.reverse_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.reverse_edge_ids is None assert subgraph.original_edge_ids is None
@unittest.skipIf( @unittest.skipIf(
......
...@@ -13,11 +13,11 @@ def test_to_dgl_blocks_hetero(): ...@@ -13,11 +13,11 @@ def test_to_dgl_blocks_hetero():
}, },
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))}, {relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
] ]
reverse_column_node_ids = [ original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])}, {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])}, {"B": torch.tensor([10, 11])},
] ]
reverse_row_node_ids = [ original_row_node_ids = [
{ {
"A": torch.tensor([5, 7, 9, 11]), "A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]), "B": torch.tensor([10, 11, 12]),
...@@ -27,7 +27,7 @@ def test_to_dgl_blocks_hetero(): ...@@ -27,7 +27,7 @@ def test_to_dgl_blocks_hetero():
"B": torch.tensor([10, 11]), "B": torch.tensor([10, 11]),
}, },
] ]
reverse_edge_ids = [ original_edge_ids = [
{ {
relation: torch.tensor([19, 20, 21]), relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]), reverse_relation: torch.tensor([23, 26]),
...@@ -46,9 +46,9 @@ def test_to_dgl_blocks_hetero(): ...@@ -46,9 +46,9 @@ def test_to_dgl_blocks_hetero():
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
blocks = gb.MiniBatch( blocks = gb.MiniBatch(
...@@ -63,7 +63,7 @@ def test_to_dgl_blocks_hetero(): ...@@ -63,7 +63,7 @@ def test_to_dgl_blocks_hetero():
assert torch.equal(edges[0], node_pairs[i][relation][0]) assert torch.equal(edges[0], node_pairs[i][relation][0])
assert torch.equal(edges[1], node_pairs[i][relation][1]) assert torch.equal(edges[1], node_pairs[i][relation][1])
assert torch.equal( assert torch.equal(
block.edges[etype].data[dgl.EID], reverse_edge_ids[i][relation] block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
) )
assert torch.equal( assert torch.equal(
block.edges[etype].data["x"], block.edges[etype].data["x"],
...@@ -73,10 +73,10 @@ def test_to_dgl_blocks_hetero(): ...@@ -73,10 +73,10 @@ def test_to_dgl_blocks_hetero():
assert torch.equal(edges[0], node_pairs[0][reverse_relation][0]) assert torch.equal(edges[0], node_pairs[0][reverse_relation][0])
assert torch.equal(edges[1], node_pairs[0][reverse_relation][1]) assert torch.equal(edges[1], node_pairs[0][reverse_relation][1])
assert torch.equal( assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], reverse_row_node_ids[0]["A"] blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
) )
assert torch.equal( assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], reverse_row_node_ids[0]["B"] blocks[0].srcdata[dgl.NID]["B"], original_row_node_ids[0]["B"]
) )
assert torch.equal( assert torch.equal(
blocks[0].srcnodes["A"].data["x"], node_features[("A", "x")] blocks[0].srcnodes["A"].data["x"], node_features[("A", "x")]
...@@ -97,15 +97,15 @@ def test_to_dgl_blocks_homo(): ...@@ -97,15 +97,15 @@ def test_to_dgl_blocks_homo():
torch.tensor([1, 0, 0]), torch.tensor([1, 0, 0]),
), ),
] ]
reverse_column_node_ids = [ original_column_node_ids = [
torch.tensor([10, 11, 12, 13]), torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]), torch.tensor([10, 11]),
] ]
reverse_row_node_ids = [ original_row_node_ids = [
torch.tensor([10, 11, 12, 13]), torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]), torch.tensor([10, 11, 12]),
] ]
reverse_edge_ids = [ original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]), torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]), torch.tensor([10, 15, 17]),
] ]
...@@ -119,9 +119,9 @@ def test_to_dgl_blocks_homo(): ...@@ -119,9 +119,9 @@ def test_to_dgl_blocks_homo():
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
blocks = gb.MiniBatch( blocks = gb.MiniBatch(
...@@ -133,9 +133,9 @@ def test_to_dgl_blocks_homo(): ...@@ -133,9 +133,9 @@ def test_to_dgl_blocks_homo():
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
assert torch.equal(block.edges()[0], node_pairs[i][0]) assert torch.equal(block.edges()[0], node_pairs[i][0])
assert torch.equal(block.edges()[1], node_pairs[i][1]) assert torch.equal(block.edges()[1], node_pairs[i][1])
assert torch.equal(block.edata[dgl.EID], reverse_edge_ids[i]) assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(block.edata["x"], edge_features[i]["x"]) assert torch.equal(block.edata["x"], edge_features[i]["x"])
assert torch.equal(blocks[0].srcdata[dgl.NID], reverse_row_node_ids[0]) assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
assert torch.equal(blocks[0].srcdata["x"], node_features["x"]) assert torch.equal(blocks[0].srcdata["x"], node_features["x"])
...@@ -150,15 +150,15 @@ def test_representation(): ...@@ -150,15 +150,15 @@ def test_representation():
torch.tensor([1, 0, 0]), torch.tensor([1, 0, 0]),
), ),
] ]
reverse_column_node_ids = [ original_column_node_ids = [
torch.tensor([10, 11, 12, 13]), torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]), torch.tensor([10, 11]),
] ]
reverse_row_node_ids = [ original_row_node_ids = [
torch.tensor([10, 11, 12, 13]), torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]), torch.tensor([10, 11, 12]),
] ]
reverse_edge_ids = [ original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]), torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]), torch.tensor([10, 15, 17]),
] ]
...@@ -172,9 +172,9 @@ def test_representation(): ...@@ -172,9 +172,9 @@ def test_representation():
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
negative_srcs = torch.tensor([[8], [1], [6]]) negative_srcs = torch.tensor([[8], [1], [6]])
...@@ -220,13 +220,13 @@ def test_representation(): ...@@ -220,13 +220,13 @@ def test_representation():
expect_result = str( expect_result = str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])), sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
reverse_column_node_ids=tensor([10, 11, 12, 13]), original_column_node_ids=tensor([10, 11, 12, 13]),
reverse_edge_ids=tensor([19, 20, 21, 22, 25, 30]), original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
reverse_row_node_ids=tensor([10, 11, 12, 13]),), original_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])), SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
reverse_column_node_ids=tensor([10, 11]), original_column_node_ids=tensor([10, 11]),
reverse_edge_ids=tensor([10, 15, 17]), original_edge_ids=tensor([10, 15, 17]),
reverse_row_node_ids=tensor([10, 11, 12]),)], original_row_node_ids=tensor([10, 11, 12]),)],
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])), node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
(tensor([0, 1, 2]), tensor([1, 0, 0]))], (tensor([0, 1, 2]), tensor([1, 0, 0]))],
node_features={'x': tensor([7, 6, 2, 2])}, node_features={'x': tensor([7, 6, 2, 2])},
......
...@@ -26,24 +26,24 @@ def _assert_container_equal(lhs, rhs): ...@@ -26,24 +26,24 @@ def _assert_container_equal(lhs, rhs):
def test_exclude_edges_homo(reverse_row, reverse_column): def test_exclude_edges_homo(reverse_row, reverse_column):
node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2])) node_pairs = (torch.tensor([0, 2, 3]), torch.tensor([1, 4, 2]))
if reverse_row: if reverse_row:
reverse_row_node_ids = torch.tensor([10, 15, 11, 24, 9]) original_row_node_ids = torch.tensor([10, 15, 11, 24, 9])
src_to_exclude = torch.tensor([11]) src_to_exclude = torch.tensor([11])
else: else:
reverse_row_node_ids = None original_row_node_ids = None
src_to_exclude = torch.tensor([2]) src_to_exclude = torch.tensor([2])
if reverse_column: if reverse_column:
reverse_column_node_ids = torch.tensor([10, 15, 11, 24, 9]) original_column_node_ids = torch.tensor([10, 15, 11, 24, 9])
dst_to_exclude = torch.tensor([9]) dst_to_exclude = torch.tensor([9])
else: else:
reverse_column_node_ids = None original_column_node_ids = None
dst_to_exclude = torch.tensor([4]) dst_to_exclude = torch.tensor([4])
reverse_edge_ids = torch.Tensor([5, 9, 10]) original_edge_ids = torch.Tensor([5, 9, 10])
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs, node_pairs,
reverse_column_node_ids, original_column_node_ids,
reverse_row_node_ids, original_row_node_ids,
reverse_edge_ids, original_edge_ids,
) )
edges_to_exclude = (src_to_exclude, dst_to_exclude) edges_to_exclude = (src_to_exclude, dst_to_exclude)
result = subgraph.exclude_edges(edges_to_exclude) result = subgraph.exclude_edges(edges_to_exclude)
...@@ -60,10 +60,10 @@ def test_exclude_edges_homo(reverse_row, reverse_column): ...@@ -60,10 +60,10 @@ def test_exclude_edges_homo(reverse_row, reverse_column):
_assert_container_equal(result.node_pairs, expected_node_pairs) _assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal( _assert_container_equal(
result.reverse_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
_assert_container_equal(result.reverse_row_node_ids, expected_row_node_ids) _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.reverse_edge_ids, expected_edge_ids) _assert_container_equal(result.original_edge_ids, expected_edge_ids)
@pytest.mark.parametrize("reverse_row", [True, False]) @pytest.mark.parametrize("reverse_row", [True, False])
...@@ -76,27 +76,27 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -76,27 +76,27 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
) )
} }
if reverse_row: if reverse_row:
reverse_row_node_ids = { original_row_node_ids = {
"A": torch.tensor([13, 14, 15]), "A": torch.tensor([13, 14, 15]),
} }
src_to_exclude = torch.tensor([15, 13]) src_to_exclude = torch.tensor([15, 13])
else: else:
reverse_row_node_ids = None original_row_node_ids = None
src_to_exclude = torch.tensor([2, 0]) src_to_exclude = torch.tensor([2, 0])
if reverse_column: if reverse_column:
reverse_column_node_ids = { original_column_node_ids = {
"B": torch.tensor([10, 11, 12]), "B": torch.tensor([10, 11, 12]),
} }
dst_to_exclude = torch.tensor([10, 12]) dst_to_exclude = torch.tensor([10, 12])
else: else:
reverse_column_node_ids = None original_column_node_ids = None
dst_to_exclude = torch.tensor([0, 2]) dst_to_exclude = torch.tensor([0, 2])
reverse_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=node_pairs, node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids, original_column_node_ids=original_column_node_ids,
reverse_row_node_ids=reverse_row_node_ids, original_row_node_ids=original_row_node_ids,
reverse_edge_ids=reverse_edge_ids, original_edge_ids=original_edge_ids,
) )
edges_to_exclude = { edges_to_exclude = {
...@@ -128,7 +128,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column): ...@@ -128,7 +128,7 @@ def test_exclude_edges_hetero(reverse_row, reverse_column):
_assert_container_equal(result.node_pairs, expected_node_pairs) _assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal( _assert_container_equal(
result.reverse_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
_assert_container_equal(result.reverse_row_node_ids, expected_row_node_ids) _assert_container_equal(result.original_row_node_ids, expected_row_node_ids)
_assert_container_equal(result.reverse_edge_ids, expected_edge_ids) _assert_container_equal(result.original_edge_ids, expected_edge_ids)
...@@ -65,7 +65,7 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -65,7 +65,7 @@ def test_FeatureFetcher_with_edges_homo():
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])), node_pairs=(torch.tensor([]), torch.tensor([])),
reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)), original_edge_ids=torch.randint(0, graph.num_edges, (10,)),
) )
) )
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs) data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
...@@ -146,7 +146,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -146,7 +146,7 @@ def test_FeatureFetcher_with_edges_hetero():
def add_node_and_edge_ids(seeds): def add_node_and_edge_ids(seeds):
subgraphs = [] subgraphs = []
reverse_edge_ids = { original_edge_ids = {
"n1:e1:n2": torch.randint(0, 50, (10,)), "n1:e1:n2": torch.randint(0, 50, (10,)),
"n2:e2:n1": torch.randint(0, 50, (10,)), "n2:e2:n1": torch.randint(0, 50, (10,)),
} }
...@@ -154,7 +154,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -154,7 +154,7 @@ def test_FeatureFetcher_with_edges_hetero():
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=(torch.tensor([]), torch.tensor([])), node_pairs=(torch.tensor([]), torch.tensor([])),
reverse_edge_ids=reverse_edge_ids, original_edge_ids=original_edge_ids,
) )
) )
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs) data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment