Unverified Commit 1f9ae668 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Rename `node_pairs` to `sampled_csc`. (#6836)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 437139f5
...@@ -368,7 +368,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -368,7 +368,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
... edge_type_to_id=etypes) ... edge_type_to_id=etypes)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])} >>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes) >>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs) >>> print(in_subgraph.sampled_csc)
defaultdict(<class 'list'>, { defaultdict(<class 'list'>, {
'N0:R0:N0': (tensor([]), tensor([])), 'N0:R0:N0': (tensor([]), tensor([])),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])), 'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
...@@ -436,7 +436,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -436,7 +436,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
if has_original_eids: if has_original_eids:
original_edge_ids = original_hetero_edge_ids original_edge_ids = original_hetero_edge_ids
return FusedSampledSubgraphImpl( return FusedSampledSubgraphImpl(
node_pairs=node_pairs, original_edge_ids=original_edge_ids sampled_csc=node_pairs, original_edge_ids=original_edge_ids
) )
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None): def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
...@@ -475,7 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -475,7 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
] ]
if type_per_edge is None: if type_per_edge is None:
# The sampled graph is already a homogeneous graph. # The sampled graph is already a homogeneous graph.
node_pairs = CSCFormatBase(indptr=indptr, indices=indices) sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else: else:
# The sampled graph is a fused homogenized graph, which need to be # The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs. # converted to heterogeneous graphs.
...@@ -536,7 +536,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -536,7 +536,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
l = end l = end
if has_original_eids: if has_original_eids:
original_edge_ids = original_hetero_edge_ids original_edge_ids = original_hetero_edge_ids
node_pairs = { sampled_csc = {
etype: CSCFormatBase( etype: CSCFormatBase(
indptr=torch.tensor(subgraph_indptr[etype]), indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype], indices=subgraph_indice[etype],
...@@ -544,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -544,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
for etype in self.edge_type_to_id.keys() for etype in self.edge_type_to_id.keys()
} }
return SampledSubgraphImpl( return SampledSubgraphImpl(
node_pairs=node_pairs, sampled_csc=sampled_csc,
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
) )
...@@ -619,7 +619,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -619,7 +619,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts) >>> subgraph = graph.sample_neighbors(nodes, fanouts)
>>> print(subgraph.node_pairs) >>> print(subgraph.sampled_csc)
defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]), defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))}) tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
""" """
...@@ -811,7 +811,7 @@ class FusedCSCSamplingGraph(SamplingGraph): ...@@ -811,7 +811,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])} >>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1]) >>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts) >>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
>>> print(subgraph.node_pairs) >>> print(subgraph.sampled_csc)
defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]), defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))}) tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
""" """
......
...@@ -41,7 +41,7 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -41,7 +41,7 @@ class InSubgraphSampler(SubgraphSampler):
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2) >>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph) >>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
>>> for _, data in enumerate(insubgraph_sampler): >>> for _, data in enumerate(insubgraph_sampler):
... print(data.sampled_subgraphs[0].node_pairs) ... print(data.sampled_subgraphs[0].sampled_csc)
... print(data.sampled_subgraphs[0].original_row_node_ids) ... print(data.sampled_subgraphs[0].original_row_node_ids)
... print(data.sampled_subgraphs[0].original_column_node_ids) ... print(data.sampled_subgraphs[0].original_column_node_ids)
CSCFormatBase(indptr=tensor([0, 3, 5]), CSCFormatBase(indptr=tensor([0, 3, 5]),
...@@ -79,9 +79,9 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -79,9 +79,9 @@ class InSubgraphSampler(SubgraphSampler):
( (
original_row_node_ids, original_row_node_ids,
compacted_node_pairs, compacted_node_pairs,
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds) ) = unique_and_compact_node_pairs(subgraph.sampled_csc, seeds)
subgraph = FusedSampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs, sampled_csc=compacted_node_pairs,
original_column_node_ids=seeds, original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids, original_edge_ids=subgraph.original_edge_ids,
...@@ -90,9 +90,9 @@ class InSubgraphSampler(SubgraphSampler): ...@@ -90,9 +90,9 @@ class InSubgraphSampler(SubgraphSampler):
( (
original_row_node_ids, original_row_node_ids,
compacted_csc_formats, compacted_csc_formats,
) = unique_and_compact_csc_formats(subgraph.node_pairs, seeds) ) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_formats, sampled_csc=compacted_csc_formats,
original_column_node_ids=seeds, original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids, original_edge_ids=subgraph.original_edge_ids,
......
...@@ -143,10 +143,10 @@ class NeighborSampler(SubgraphSampler): ...@@ -143,10 +143,10 @@ class NeighborSampler(SubgraphSampler):
original_row_node_ids, original_row_node_ids,
compacted_csc_format, compacted_csc_format,
) = unique_and_compact_csc_formats( ) = unique_and_compact_csc_formats(
subgraph.node_pairs, seeds subgraph.sampled_csc, seeds
) )
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format, sampled_csc=compacted_csc_format,
original_column_node_ids=seeds, original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids, original_edge_ids=subgraph.original_edge_ids,
...@@ -156,10 +156,10 @@ class NeighborSampler(SubgraphSampler): ...@@ -156,10 +156,10 @@ class NeighborSampler(SubgraphSampler):
original_row_node_ids, original_row_node_ids,
compacted_node_pairs, compacted_node_pairs,
) = unique_and_compact_node_pairs( ) = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds subgraph.sampled_csc, seeds
) )
subgraph = FusedSampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs, sampled_csc=compacted_node_pairs,
original_column_node_ids=seeds, original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids, original_edge_ids=subgraph.original_edge_ids,
...@@ -168,13 +168,9 @@ class NeighborSampler(SubgraphSampler): ...@@ -168,13 +168,9 @@ class NeighborSampler(SubgraphSampler):
( (
original_row_node_ids, original_row_node_ids,
compacted_csc_format, compacted_csc_format,
) = compact_csc_format(subgraph.node_pairs, seeds) ) = compact_csc_format(subgraph.sampled_csc, seeds)
# [TODO] For node_pairs is defined in SampledSubgraph, which is
# SampledSubgraph's parent class, and it's still inherited by
# other classes, the name cannot be changed currently. This
# part will be cleaned up later.
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format, sampled_csc=compacted_csc_format,
original_column_node_ids=seeds, original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids, original_edge_ids=subgraph.original_edge_ids,
......
...@@ -24,12 +24,12 @@ class FusedSampledSubgraphImpl(SampledSubgraph): ...@@ -24,12 +24,12 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.FusedSampledSubgraphImpl( >>> subgraph = gb.FusedSampledSubgraphImpl(
... node_pairs=node_pairs, ... sampled_csc=node_pairs,
... original_column_node_ids=original_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids ... original_edge_ids=original_edge_ids
... ) ... )
>>> print(subgraph.node_pairs) >>> print(subgraph.sampled_csc)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))} {"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids) >>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])} {'B': tensor([10, 11, 12])}
...@@ -38,7 +38,7 @@ class FusedSampledSubgraphImpl(SampledSubgraph): ...@@ -38,7 +38,7 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
>>> print(subgraph.original_edge_ids) >>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])} {"A:relation:B": tensor([19, 20, 21])}
""" """
node_pairs: Union[ sampled_csc: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
] = None ] = None
...@@ -49,8 +49,8 @@ class FusedSampledSubgraphImpl(SampledSubgraph): ...@@ -49,8 +49,8 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
if isinstance(self.node_pairs, dict): if isinstance(self.sampled_csc, dict):
for etype, pair in self.node_pairs.items(): for etype, pair in self.sampled_csc.items():
assert ( assert (
isinstance(etype, str) isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3 and len(etype_str_to_tuple(etype)) == 3
...@@ -63,10 +63,11 @@ class FusedSampledSubgraphImpl(SampledSubgraph): ...@@ -63,10 +63,11 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
else: else:
assert ( assert (
isinstance(self.node_pairs, tuple) and len(self.node_pairs) == 2 isinstance(self.sampled_csc, tuple)
and len(self.sampled_csc) == 2
), "Node pair should be a source-destination tuple (u, v)." ), "Node pair should be a source-destination tuple (u, v)."
assert all( assert all(
isinstance(item, torch.Tensor) for item in self.node_pairs isinstance(item, torch.Tensor) for item in self.sampled_csc
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -79,18 +80,18 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -79,18 +80,18 @@ class SampledSubgraphImpl(SampledSubgraph):
Examples Examples
-------- --------
>>> node_pairs = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]), >>> sampled_csc = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))} ... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])} >>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])} >>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... sampled_csc=sampled_csc,
... original_column_node_ids=original_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids ... original_edge_ids=original_edge_ids
... ) ... )
>>> print(subgraph.node_pairs) >>> print(subgraph.sampled_csc)
{"A:relation:B": CSCForamtBase(indptr=torch.tensor([0, 1, 2, 3]), {"A:relation:B": CSCForamtBase(indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))} ... indices=torch.tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids) >>> print(subgraph.original_column_node_ids)
...@@ -100,10 +101,7 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -100,10 +101,7 @@ class SampledSubgraphImpl(SampledSubgraph):
>>> print(subgraph.original_edge_ids) >>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])} {"A:relation:B": tensor([19, 20, 21])}
""" """
node_pairs: Union[ sampled_csc: Union[CSCFormatBase, Dict[str, CSCFormatBase]] = None
CSCFormatBase,
Dict[str, CSCFormatBase],
] = None
original_column_node_ids: Union[ original_column_node_ids: Union[
Dict[str, torch.Tensor], torch.Tensor Dict[str, torch.Tensor], torch.Tensor
] = None ] = None
...@@ -111,8 +109,8 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -111,8 +109,8 @@ class SampledSubgraphImpl(SampledSubgraph):
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self): def __post_init__(self):
if isinstance(self.node_pairs, dict): if isinstance(self.sampled_csc, dict):
for etype, pair in self.node_pairs.items(): for etype, pair in self.sampled_csc.items():
assert ( assert (
isinstance(etype, str) isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3 and len(etype_str_to_tuple(etype)) == 3
...@@ -125,13 +123,13 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -125,13 +123,13 @@ class SampledSubgraphImpl(SampledSubgraph):
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
else: else:
assert ( assert (
self.node_pairs.indptr is not None self.sampled_csc.indptr is not None
and self.node_pairs.indices is not None and self.sampled_csc.indices is not None
), "Node pair should be have indptr and indice." ), "Node pair should be have indptr and indice."
assert isinstance( assert isinstance(
self.node_pairs.indptr, torch.Tensor self.sampled_csc.indptr, torch.Tensor
) and isinstance( ) and isinstance(
self.node_pairs.indices, torch.Tensor self.sampled_csc.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str: def __repr__(self) -> str:
......
...@@ -180,7 +180,7 @@ class MiniBatch: ...@@ -180,7 +180,7 @@ class MiniBatch:
return None return None
is_heterogeneous = isinstance( is_heterogeneous = isinstance(
self.sampled_subgraphs[0].node_pairs, Dict self.sampled_subgraphs[0].sampled_csc, Dict
) )
blocks = [] blocks = []
...@@ -195,9 +195,9 @@ class MiniBatch: ...@@ -195,9 +195,9 @@ class MiniBatch:
), "Missing `original_column_node_ids` in sampled subgraph." ), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous: if is_heterogeneous:
if isinstance( if isinstance(
list(subgraph.node_pairs.values())[0], CSCFormatBase list(subgraph.sampled_csc.values())[0], CSCFormatBase
): ):
node_pairs = { sampled_csc = {
etype_str_to_tuple(etype): ( etype_str_to_tuple(etype): (
"csc", "csc",
( (
...@@ -211,12 +211,12 @@ class MiniBatch: ...@@ -211,12 +211,12 @@ class MiniBatch:
), ),
), ),
) )
for etype, v in subgraph.node_pairs.items() for etype, v in subgraph.sampled_csc.items()
} }
else: else:
node_pairs = { sampled_csc = {
etype_str_to_tuple(etype): v etype_str_to_tuple(etype): v
for etype, v in subgraph.node_pairs.items() for etype, v in subgraph.sampled_csc.items()
} }
num_src_nodes = { num_src_nodes = {
ntype: nodes.size(0) ntype: nodes.size(0)
...@@ -227,18 +227,18 @@ class MiniBatch: ...@@ -227,18 +227,18 @@ class MiniBatch:
for ntype, nodes in original_column_node_ids.items() for ntype, nodes in original_column_node_ids.items()
} }
else: else:
node_pairs = subgraph.node_pairs sampled_csc = subgraph.sampled_csc
if isinstance(subgraph.node_pairs, CSCFormatBase): if isinstance(subgraph.sampled_csc, CSCFormatBase):
node_pairs = ( sampled_csc = (
"csc", "csc",
( (
node_pairs.indptr, sampled_csc.indptr,
node_pairs.indices, sampled_csc.indices,
torch.arange( torch.arange(
0, 0,
node_pairs.indptr[-1], sampled_csc.indptr[-1],
device=node_pairs.indptr.device, device=sampled_csc.indptr.device,
dtype=node_pairs.indptr.dtype, dtype=sampled_csc.indptr.dtype,
), ),
), ),
) )
...@@ -246,7 +246,7 @@ class MiniBatch: ...@@ -246,7 +246,7 @@ class MiniBatch:
num_dst_nodes = original_column_node_ids.size(0) num_dst_nodes = original_column_node_ids.size(0)
blocks.append( blocks.append(
dgl.create_block( dgl.create_block(
node_pairs, sampled_csc,
num_src_nodes=num_src_nodes, num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes, num_dst_nodes=num_dst_nodes,
) )
......
...@@ -20,14 +20,14 @@ class SampledSubgraph: ...@@ -20,14 +20,14 @@ class SampledSubgraph:
value type.""" value type."""
@property @property
def node_pairs( def sampled_csc(
self, self,
) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]: ) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]:
"""Returns the node pairs representing edges in csc format. """Returns the node pairs representing edges in csc format.
- If `node_pairs` is a CSCFormatBase: It should be in the csc format. - If `sampled_csc` is a CSCFormatBase: It should be in the csc format.
`indptr` stores the index in the data array where each column `indptr` stores the index in the data array where each column
starts. `indices` stores the row indices of the non-zero elements. starts. `indices` stores the row indices of the non-zero elements.
- If `node_pairs` is a dictionary: The keys should be edge type and - If `sampled_csc` is a dictionary: The keys should be edge type and
the values should be corresponding node pairs. The ids inside the values should be corresponding node pairs. The ids inside
is heterogeneous ids.""" is heterogeneous ids."""
raise NotImplementedError raise NotImplementedError
...@@ -45,7 +45,7 @@ class SampledSubgraph: ...@@ -45,7 +45,7 @@ class SampledSubgraph:
- If `original_column_node_ids` is a dictionary: The keys should be - If `original_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original node type and the values should be corresponding original
heterogeneous node ids. heterogeneous node ids.
If present, it means column IDs are compacted, and `node_pairs` If present, it means column IDs are compacted, and `sampled_csc`
column IDs match these compacted ones. column IDs match these compacted ones.
""" """
return None return None
...@@ -63,7 +63,7 @@ class SampledSubgraph: ...@@ -63,7 +63,7 @@ class SampledSubgraph:
- If `original_row_node_ids` is a dictionary: The keys should be - If `original_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original node type and the values should be corresponding original
heterogeneous node ids. heterogeneous node ids.
If present, it means row IDs are compacted, and `node_pairs` If present, it means row IDs are compacted, and `sampled_csc`
row IDs match these compacted ones.""" row IDs match these compacted ones."""
return None return None
...@@ -119,14 +119,14 @@ class SampledSubgraph: ...@@ -119,14 +119,14 @@ class SampledSubgraph:
-------- --------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> import torch >>> import torch
>>> node_pairs = {"A:relation:B": gb.CSCFormatBase( >>> sampled_csc = {"A:relation:B": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 2, 3]), ... indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))} ... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])} >>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])} >>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} >>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl( >>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs, ... sampled_csc=sampled_csc,
... original_column_node_ids=original_column_node_ids, ... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids, ... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids ... original_edge_ids=original_edge_ids
...@@ -134,7 +134,7 @@ class SampledSubgraph: ...@@ -134,7 +134,7 @@ class SampledSubgraph:
>>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]), >>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]),
... torch.tensor([11, 12]))} ... torch.tensor([11, 12]))}
>>> result = subgraph.exclude_edges(edges_to_exclude) >>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs) >>> print(result.sampled_csc)
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]), {'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
indices=tensor([0]), indices=tensor([0]),
)} )}
...@@ -150,7 +150,7 @@ class SampledSubgraph: ...@@ -150,7 +150,7 @@ class SampledSubgraph:
assume_num_node_within_int32 assume_num_node_within_int32
), "Values > int32 are not supported yet." ), "Values > int32 are not supported yet."
assert ( assert (
isinstance(self.node_pairs, (CSCFormatBase, tuple)) isinstance(self.sampled_csc, (CSCFormatBase, tuple))
) == isinstance(edges, tuple), ( ) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both " "The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous." "homogeneous or both heterogeneous."
...@@ -162,9 +162,9 @@ class SampledSubgraph: ...@@ -162,9 +162,9 @@ class SampledSubgraph:
# 1. Convert the node pairs to the original ids if they are compacted. # 1. Convert the node pairs to the original ids if they are compacted.
# 2. Exclude the edges and get the index of the edges to keep. # 2. Exclude the edges and get the index of the edges to keep.
# 3. Slice the subgraph according to the index. # 3. Slice the subgraph according to the index.
if isinstance(self.node_pairs, tuple): if isinstance(self.sampled_csc, tuple):
reverse_edges = _to_reverse_ids_node_pairs( reverse_edges = _to_reverse_ids_node_pairs(
self.node_pairs, self.sampled_csc,
self.original_row_node_ids, self.original_row_node_ids,
self.original_column_node_ids, self.original_column_node_ids,
) )
...@@ -172,9 +172,9 @@ class SampledSubgraph: ...@@ -172,9 +172,9 @@ class SampledSubgraph:
reverse_edges, edges, assume_num_node_within_int32 reverse_edges, edges, assume_num_node_within_int32
) )
return calling_class(*_slice_subgraph_node_pairs(self, index)) return calling_class(*_slice_subgraph_node_pairs(self, index))
elif isinstance(self.node_pairs, CSCFormatBase): elif isinstance(self.sampled_csc, CSCFormatBase):
reverse_edges = _to_reverse_ids( reverse_edges = _to_reverse_ids(
self.node_pairs, self.sampled_csc,
self.original_row_node_ids, self.original_row_node_ids,
self.original_column_node_ids, self.original_column_node_ids,
) )
...@@ -185,7 +185,7 @@ class SampledSubgraph: ...@@ -185,7 +185,7 @@ class SampledSubgraph:
else: else:
index = {} index = {}
is_cscformat = 0 is_cscformat = 0
for etype, pair in self.node_pairs.items(): for etype, pair in self.sampled_csc.items():
if etype not in edges: if etype not in edges:
# No edges need to be excluded. # No edges need to be excluded.
index[etype] = None index[etype] = None
...@@ -314,7 +314,7 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor): ...@@ -314,7 +314,7 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
return ret return ret
return ( return (
_index_select(subgraph.node_pairs, index), _index_select(subgraph.sampled_csc, index),
subgraph.original_column_node_ids, subgraph.original_column_node_ids,
subgraph.original_row_node_ids, subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index), _index_select(subgraph.original_edge_ids, index),
...@@ -347,7 +347,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor): ...@@ -347,7 +347,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
return ret return ret
return ( return (
_index_select(subgraph.node_pairs, index), _index_select(subgraph.sampled_csc, index),
subgraph.original_column_node_ids, subgraph.original_column_node_ids,
subgraph.original_row_node_ids, subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index), _index_select(subgraph.original_edge_ids, index),
......
...@@ -710,10 +710,10 @@ def test_in_subgraph_node_pairs_homogeneous(): ...@@ -710,10 +710,10 @@ def test_in_subgraph_node_pairs_homogeneous():
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2]) in_subgraph.sampled_csc[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3]) in_subgraph.sampled_csc[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
) )
assert in_subgraph.original_column_node_ids is None assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None assert in_subgraph.original_row_node_ids is None
...@@ -782,28 +782,28 @@ def test_in_subgraph_node_pairs_heterogeneous(): ...@@ -782,28 +782,28 @@ def test_in_subgraph_node_pairs_heterogeneous():
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"][0], torch.LongTensor([]) in_subgraph.sampled_csc["N0:R0:N0"][0], torch.LongTensor([])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([]) in_subgraph.sampled_csc["N0:R0:N0"][1], torch.LongTensor([])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([0, 1]) in_subgraph.sampled_csc["N0:R1:N1"][0], torch.LongTensor([0, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([2, 1]) in_subgraph.sampled_csc["N0:R1:N1"][1], torch.LongTensor([2, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1]) in_subgraph.sampled_csc["N1:R2:N0"][0], torch.LongTensor([0, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1]) in_subgraph.sampled_csc["N1:R2:N0"][1], torch.LongTensor([1, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([1, 2, 0]) in_subgraph.sampled_csc["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([2, 2, 1]) in_subgraph.sampled_csc["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
) )
assert in_subgraph.original_column_node_ids is None assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None assert in_subgraph.original_row_node_ids is None
...@@ -841,10 +841,10 @@ def test_in_subgraph_homo(): ...@@ -841,10 +841,10 @@ def test_in_subgraph_homo():
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs.indices, torch.LongTensor([0, 3, 4, 2, 3, 1, 2]) in_subgraph.sampled_csc.indices, torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs.indptr, torch.LongTensor([0, 3, 5, 7]) in_subgraph.sampled_csc.indptr, torch.LongTensor([0, 3, 5, 7])
) )
assert in_subgraph.original_column_node_ids is None assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None assert in_subgraph.original_row_node_ids is None
...@@ -913,28 +913,28 @@ def test_in_subgraph_hetero(): ...@@ -913,28 +913,28 @@ def test_in_subgraph_hetero():
# Verify in subgraph. # Verify in subgraph.
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"].indices, torch.LongTensor([]) in_subgraph.sampled_csc["N0:R0:N0"].indices, torch.LongTensor([])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"].indptr, torch.LongTensor([0, 0]) in_subgraph.sampled_csc["N0:R0:N0"].indptr, torch.LongTensor([0, 0])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"].indices, torch.LongTensor([0, 1]) in_subgraph.sampled_csc["N0:R1:N1"].indices, torch.LongTensor([0, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"].indptr, torch.LongTensor([0, 1, 2]) in_subgraph.sampled_csc["N0:R1:N1"].indptr, torch.LongTensor([0, 1, 2])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"].indices, torch.LongTensor([0, 1]) in_subgraph.sampled_csc["N1:R2:N0"].indices, torch.LongTensor([0, 1])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"].indptr, torch.LongTensor([0, 2]) in_subgraph.sampled_csc["N1:R2:N0"].indptr, torch.LongTensor([0, 2])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"].indices, torch.LongTensor([1, 2, 0]) in_subgraph.sampled_csc["N1:R3:N1"].indices, torch.LongTensor([1, 2, 0])
) )
assert torch.equal( assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"].indptr, torch.LongTensor([0, 2, 3]) in_subgraph.sampled_csc["N1:R3:N1"].indptr, torch.LongTensor([0, 2, 3])
) )
assert in_subgraph.original_column_node_ids is None assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None assert in_subgraph.original_row_node_ids is None
...@@ -994,7 +994,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype): ...@@ -994,7 +994,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
subgraph = sampler(nodes, fanouts, output_cscformat=False) subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 6 assert sampled_num == 6
assert subgraph.original_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None assert subgraph.original_row_node_ids is None
...@@ -1084,10 +1084,10 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -1084,10 +1084,10 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
torch.LongTensor([0, 0]), torch.LongTensor([0, 0]),
), ),
} }
assert len(subgraph.node_pairs) == 2 assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0]) assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1]) assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None assert subgraph.original_edge_ids is None
...@@ -1128,10 +1128,10 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype): ...@@ -1128,10 +1128,10 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
torch.tensor([], dtype=indices_dtype), torch.tensor([], dtype=indices_dtype),
), ),
} }
assert len(subgraph.node_pairs) == 2 assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0]) assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1]) assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None assert subgraph.original_edge_ids is None
...@@ -1199,11 +1199,11 @@ def test_sample_neighbors_fanouts( ...@@ -1199,11 +1199,11 @@ def test_sample_neighbors_fanouts(
# Verify in subgraph. # Verify in subgraph.
assert ( assert (
expected_sampled_num1 == 0 expected_sampled_num1 == 0
or subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1 or subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
) )
assert ( assert (
expected_sampled_num2 == 0 expected_sampled_num2 == 0
or subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2 or subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
) )
...@@ -1254,8 +1254,8 @@ def test_sample_neighbors_replace( ...@@ -1254,8 +1254,8 @@ def test_sample_neighbors_replace(
) )
# Verify in subgraph. # Verify in subgraph.
assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1 assert subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2 assert subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf( @unittest.skipIf(
...@@ -1408,7 +1408,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name): ...@@ -1408,7 +1408,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
) )
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.sampled_csc[0].size(0)
if replace: if replace:
assert sampled_num == 6 assert sampled_num == 6
else: else:
...@@ -1455,7 +1455,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask): ...@@ -1455,7 +1455,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
) )
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 0 assert sampled_num == 0
...@@ -1965,7 +1965,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name): ...@@ -1965,7 +1965,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
probs_name=probs_name if probs_name != "none" else None, probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False, output_cscformat=False,
) )
sampled_num = subgraph.node_pairs[0].size(0) sampled_num = subgraph.sampled_csc[0].size(0)
# Verify in subgraph. # Verify in subgraph.
if probs_name == "mask": if probs_name == "mask":
...@@ -2059,7 +2059,7 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -2059,7 +2059,7 @@ def test_sample_neighbors_hetero_pick_number(
) )
if probs_name == "none": if probs_name == "none":
for etype, pairs in subgraph.node_pairs.items(): for etype, pairs in subgraph.sampled_csc.items():
fanout = fanouts[etypes[etype]] fanout = fanouts[etypes[etype]]
if fanout == -1: if fanout == -1:
assert pairs[0].size(0) == 3 assert pairs[0].size(0) == 3
...@@ -2070,7 +2070,7 @@ def test_sample_neighbors_hetero_pick_number( ...@@ -2070,7 +2070,7 @@ def test_sample_neighbors_hetero_pick_number(
assert pairs[0].size(0) == min(fanout, 3) assert pairs[0].size(0) == min(fanout, 3)
else: else:
fanout = fanouts[0] # Here fanout is the same for all etypes. fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.node_pairs.items(): for etype, pairs in subgraph.sampled_csc.items():
if etypes[etype] == 0: if etypes[etype] == 0:
# Etype 0: 2 valid neighbors. # Etype 0: 2 valid neighbors.
if fanout == -1: if fanout == -1:
...@@ -2176,8 +2176,8 @@ def test_sample_neighbors_homo_csc_format(): ...@@ -2176,8 +2176,8 @@ def test_sample_neighbors_homo_csc_format():
) )
# Verify in subgraph. # Verify in subgraph.
sampled_indptr_num = subgraph.node_pairs.indptr.size(0) sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
sampled_num = subgraph.node_pairs.indices.size(0) sampled_num = subgraph.sampled_csc.indices.size(0)
assert sampled_indptr_num == 4 assert sampled_indptr_num == 4
assert sampled_num == 6 assert sampled_num == 6
assert subgraph.original_column_node_ids is None assert subgraph.original_column_node_ids is None
...@@ -2228,7 +2228,7 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -2228,7 +2228,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
subgraph = sampler(nodes, fanouts, output_cscformat=True) subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_sampled_csc = {
"n1:e1:n2": gb.CSCFormatBase( "n1:e1:n2": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 2]), indptr=torch.LongTensor([0, 2]),
indices=torch.LongTensor([0, 1]), indices=torch.LongTensor([0, 1]),
...@@ -2238,10 +2238,10 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -2238,10 +2238,10 @@ def test_sample_neighbors_hetero_csc_format(labor):
indices=torch.LongTensor([0, 2]), indices=torch.LongTensor([0, 2]),
), ),
} }
assert len(subgraph.node_pairs) == 2 assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_sampled_csc.items():
assert torch.equal(subgraph.node_pairs[etype].indptr, pairs.indptr) assert torch.equal(subgraph.sampled_csc[etype].indptr, pairs.indptr)
assert torch.equal(subgraph.node_pairs[etype].indices, pairs.indices) assert torch.equal(subgraph.sampled_csc[etype].indices, pairs.indices)
assert subgraph.original_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None assert subgraph.original_edge_ids is None
...@@ -2253,7 +2253,7 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -2253,7 +2253,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
subgraph = sampler(nodes, fanouts, output_cscformat=True) subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph. # Verify in subgraph.
expected_node_pairs = { expected_sampled_csc = {
"n1:e1:n2": gb.CSCFormatBase( "n1:e1:n2": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indptr=torch.LongTensor([0]),
indices=torch.LongTensor([]), indices=torch.LongTensor([]),
...@@ -2263,10 +2263,10 @@ def test_sample_neighbors_hetero_csc_format(labor): ...@@ -2263,10 +2263,10 @@ def test_sample_neighbors_hetero_csc_format(labor):
indices=torch.LongTensor([0, 2]), indices=torch.LongTensor([0, 2]),
), ),
} }
assert len(subgraph.node_pairs) == 2 assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items(): for etype, pairs in expected_sampled_csc.items():
assert torch.equal(subgraph.node_pairs[etype].indptr, pairs.indptr) assert torch.equal(subgraph.sampled_csc[etype].indptr, pairs.indptr)
assert torch.equal(subgraph.node_pairs[etype].indices, pairs.indices) assert torch.equal(subgraph.sampled_csc[etype].indices, pairs.indices)
assert subgraph.original_column_node_ids is None assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None assert subgraph.original_edge_ids is None
...@@ -2334,16 +2334,16 @@ def test_sample_neighbors_fanouts_csc_format( ...@@ -2334,16 +2334,16 @@ def test_sample_neighbors_fanouts_csc_format(
# Verify in subgraph. # Verify in subgraph.
assert ( assert (
expected_sampled_num1 == 0 expected_sampled_num1 == 0
or subgraph.node_pairs["n1:e1:n2"].indices.numel() or subgraph.sampled_csc["n1:e1:n2"].indices.numel()
== expected_sampled_num1 == expected_sampled_num1
) )
assert subgraph.node_pairs["n1:e1:n2"].indptr.size(0) == 2 assert subgraph.sampled_csc["n1:e1:n2"].indptr.size(0) == 2
assert ( assert (
expected_sampled_num2 == 0 expected_sampled_num2 == 0
or subgraph.node_pairs["n2:e2:n1"].indices.numel() or subgraph.sampled_csc["n2:e2:n1"].indices.numel()
== expected_sampled_num2 == expected_sampled_num2
) )
assert subgraph.node_pairs["n2:e2:n1"].indptr.size(0) == 2 assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2
@unittest.skipIf( @unittest.skipIf(
...@@ -2394,13 +2394,15 @@ def test_sample_neighbors_replace_csc_format( ...@@ -2394,13 +2394,15 @@ def test_sample_neighbors_replace_csc_format(
# Verify in subgraph. # Verify in subgraph.
assert ( assert (
subgraph.node_pairs["n1:e1:n2"].indices.numel() == expected_sampled_num1 subgraph.sampled_csc["n1:e1:n2"].indices.numel()
== expected_sampled_num1
) )
assert subgraph.node_pairs["n1:e1:n2"].indptr.size(0) == 2 assert subgraph.sampled_csc["n1:e1:n2"].indptr.size(0) == 2
assert ( assert (
subgraph.node_pairs["n2:e2:n1"].indices.numel() == expected_sampled_num2 subgraph.sampled_csc["n2:e2:n1"].indices.numel()
== expected_sampled_num2
) )
assert subgraph.node_pairs["n2:e2:n1"].indptr.size(0) == 2 assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2
@unittest.skipIf( @unittest.skipIf(
...@@ -2554,8 +2556,8 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name): ...@@ -2554,8 +2556,8 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
) )
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs.indices.size(0) sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 4 assert subgraph.sampled_csc.indptr.size(0) == 4
if replace: if replace:
assert sampled_num == 6 assert sampled_num == 6
else: else:
...@@ -2603,8 +2605,8 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask): ...@@ -2603,8 +2605,8 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
) )
# Verify in subgraph. # Verify in subgraph.
sampled_num = subgraph.node_pairs.indices.size(0) sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 4 assert subgraph.sampled_csc.indptr.size(0) == 4
assert sampled_num == 0 assert sampled_num == 0
...@@ -2673,8 +2675,8 @@ def test_sample_neighbors_homo_pick_number_csc_format( ...@@ -2673,8 +2675,8 @@ def test_sample_neighbors_homo_pick_number_csc_format(
probs_name=probs_name if probs_name != "none" else None, probs_name=probs_name if probs_name != "none" else None,
output_cscformat=True, output_cscformat=True,
) )
sampled_num = subgraph.node_pairs.indices.size(0) sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 3 assert subgraph.sampled_csc.indptr.size(0) == 3
# Verify in subgraph. # Verify in subgraph.
if probs_name == "mask": if probs_name == "mask":
if fanouts[0] == -1: if fanouts[0] == -1:
...@@ -2767,7 +2769,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format( ...@@ -2767,7 +2769,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
) )
print(subgraph) print(subgraph)
if probs_name == "none": if probs_name == "none":
for etype, pairs in subgraph.node_pairs.items(): for etype, pairs in subgraph.sampled_csc.items():
assert pairs.indptr.size(0) == 2 assert pairs.indptr.size(0) == 2
sampled_num = pairs.indices.size(0) sampled_num = pairs.indices.size(0)
fanout = fanouts[etypes[etype]] fanout = fanouts[etypes[etype]]
...@@ -2780,7 +2782,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format( ...@@ -2780,7 +2782,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
assert sampled_num == min(fanout, 3) assert sampled_num == min(fanout, 3)
else: else:
fanout = fanouts[0] # Here fanout is the same for all etypes. fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.node_pairs.items(): for etype, pairs in subgraph.sampled_csc.items():
assert pairs.indptr.size(0) == 2 assert pairs.indptr.size(0) == 2
sampled_num = pairs.indices.size(0) sampled_num = pairs.indices.size(0)
if etypes[etype] == 0: if etypes[etype] == 0:
......
...@@ -90,11 +90,11 @@ def test_InSubgraphSampler_node_pairs_homo(): ...@@ -90,11 +90,11 @@ def test_InSubgraphSampler_node_pairs_homo():
sampled_subgraph = minibatch.sampled_subgraphs[0] sampled_subgraph = minibatch.sampled_subgraphs[0]
_src = [ _src = [
sampled_subgraph.original_row_node_ids[id] sampled_subgraph.original_row_node_ids[id]
for id in sampled_subgraph.node_pairs[0] for id in sampled_subgraph.sampled_csc[0]
] ]
_dst = [ _dst = [
sampled_subgraph.original_column_node_ids[id] sampled_subgraph.original_column_node_ids[id]
for id in sampled_subgraph.node_pairs[1] for id in sampled_subgraph.sampled_csc[1]
] ]
return _src, _dst return _src, _dst
...@@ -172,7 +172,7 @@ def test_InSubgraphSampler_node_pairs_hetero(): ...@@ -172,7 +172,7 @@ def test_InSubgraphSampler_node_pairs_hetero():
"N1:R2:N0": (torch.LongTensor([0, 1]), torch.LongTensor([0, 1])), "N1:R2:N0": (torch.LongTensor([0, 1]), torch.LongTensor([0, 1])),
"N1:R3:N1": (torch.LongTensor([]), torch.LongTensor([])), "N1:R3:N1": (torch.LongTensor([]), torch.LongTensor([])),
} }
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0]) assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1]) assert torch.equal(pairs[1], expected_node_pairs[etype][1])
...@@ -187,7 +187,7 @@ def test_InSubgraphSampler_node_pairs_hetero(): ...@@ -187,7 +187,7 @@ def test_InSubgraphSampler_node_pairs_hetero():
"N1:R2:N0": (torch.LongTensor([1]), torch.LongTensor([0])), "N1:R2:N0": (torch.LongTensor([1]), torch.LongTensor([0])),
"N1:R3:N1": (torch.LongTensor([]), torch.LongTensor([])), "N1:R3:N1": (torch.LongTensor([]), torch.LongTensor([])),
} }
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0]) assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1]) assert torch.equal(pairs[1], expected_node_pairs[etype][1])
...@@ -199,7 +199,7 @@ def test_InSubgraphSampler_node_pairs_hetero(): ...@@ -199,7 +199,7 @@ def test_InSubgraphSampler_node_pairs_hetero():
"N1:R2:N0": (torch.LongTensor([]), torch.LongTensor([])), "N1:R2:N0": (torch.LongTensor([]), torch.LongTensor([])),
"N1:R3:N1": (torch.LongTensor([1, 2, 0]), torch.LongTensor([0, 1, 1])), "N1:R3:N1": (torch.LongTensor([1, 2, 0]), torch.LongTensor([0, 1, 1])),
} }
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0]) assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1]) assert torch.equal(pairs[1], expected_node_pairs[etype][1])
...@@ -231,28 +231,28 @@ def test_InSubgraphSampler_homo(): ...@@ -231,28 +231,28 @@ def test_InSubgraphSampler_homo():
def original_indices(minibatch): def original_indices(minibatch):
sampled_subgraph = minibatch.sampled_subgraphs[0] sampled_subgraph = minibatch.sampled_subgraphs[0]
_indices = sampled_subgraph.original_row_node_ids[ _indices = sampled_subgraph.original_row_node_ids[
sampled_subgraph.node_pairs.indices sampled_subgraph.sampled_csc.indices
] ]
return _indices return _indices
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0])) assert torch.equal(mn.seed_nodes, torch.LongTensor([0]))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].node_pairs.indptr, torch.tensor([0, 3]) mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 3])
) )
assert torch.equal(original_indices(mn), torch.tensor([0, 1, 4])) assert torch.equal(original_indices(mn), torch.tensor([0, 1, 4]))
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5])) assert torch.equal(mn.seed_nodes, torch.LongTensor([5]))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].node_pairs.indptr, torch.tensor([0, 2]) mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2])
) )
assert torch.equal(original_indices(mn), torch.tensor([1, 4])) assert torch.equal(original_indices(mn), torch.tensor([1, 4]))
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3])) assert torch.equal(mn.seed_nodes, torch.LongTensor([3]))
assert torch.equal( assert torch.equal(
mn.sampled_subgraphs[0].node_pairs.indptr, torch.tensor([0, 2]) mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2])
) )
assert torch.equal(original_indices(mn), torch.tensor([1, 2])) assert torch.equal(original_indices(mn), torch.tensor([1, 2]))
...@@ -312,7 +312,7 @@ def test_InSubgraphSampler_hetero(): ...@@ -312,7 +312,7 @@ def test_InSubgraphSampler_hetero():
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0])) assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0]))
expected_node_pairs = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]), indptr=torch.LongTensor([0, 1, 3]),
indices=torch.LongTensor([2, 1, 0]), indices=torch.LongTensor([2, 1, 0]),
...@@ -327,16 +327,16 @@ def test_InSubgraphSampler_hetero(): ...@@ -327,16 +327,16 @@ def test_InSubgraphSampler_hetero():
indptr=torch.LongTensor([0]), indices=torch.LongTensor([]) indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
), ),
} }
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_node_pairs[etype].indices) assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_node_pairs[etype].indptr) assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
mn = next(it) mn = next(it)
assert mn.seed_nodes == { assert mn.seed_nodes == {
"N0": torch.LongTensor([2]), "N0": torch.LongTensor([2]),
"N1": torch.LongTensor([0]), "N1": torch.LongTensor([0]),
} }
expected_node_pairs = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1]) indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1])
), ),
...@@ -350,13 +350,13 @@ def test_InSubgraphSampler_hetero(): ...@@ -350,13 +350,13 @@ def test_InSubgraphSampler_hetero():
indptr=torch.LongTensor([0, 0]), indices=torch.LongTensor([]) indptr=torch.LongTensor([0, 0]), indices=torch.LongTensor([])
), ),
} }
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_node_pairs[etype].indices) assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_node_pairs[etype].indptr) assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
mn = next(it) mn = next(it)
assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1])) assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1]))
expected_node_pairs = { expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase( "N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([]) indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
), ),
...@@ -371,6 +371,6 @@ def test_InSubgraphSampler_hetero(): ...@@ -371,6 +371,6 @@ def test_InSubgraphSampler_hetero():
indices=torch.LongTensor([1, 2, 0]), indices=torch.LongTensor([1, 2, 0]),
), ),
} }
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items(): for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_node_pairs[etype].indices) assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_node_pairs[etype].indptr) assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
...@@ -40,7 +40,7 @@ def create_homo_minibatch(): ...@@ -40,7 +40,7 @@ def create_homo_minibatch():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -94,7 +94,7 @@ def create_hetero_minibatch(): ...@@ -94,7 +94,7 @@ def create_hetero_minibatch():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -143,7 +143,7 @@ def test_minibatch_representation_homo(): ...@@ -143,7 +143,7 @@ def test_minibatch_representation_homo():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=csc_formats[i], sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -197,19 +197,19 @@ def test_minibatch_representation_homo(): ...@@ -197,19 +197,19 @@ def test_minibatch_representation_homo():
) )
expect_result = str( expect_result = str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([10, 11, 12, 13]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
),
original_row_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]), original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_column_node_ids=tensor([10, 11, 12, 13]), original_column_node_ids=tensor([10, 11, 12, 13]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([10, 11, 12]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
),
original_row_node_ids=tensor([10, 11, 12]),
original_edge_ids=tensor([10, 15, 17]), original_edge_ids=tensor([10, 15, 17]),
original_column_node_ids=tensor([10, 11]), original_column_node_ids=tensor([10, 11]),
node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
),
)], )],
positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]), positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]), indices=tensor([3, 4, 5]),
...@@ -304,7 +304,7 @@ def test_minibatch_representation_hetero(): ...@@ -304,7 +304,7 @@ def test_minibatch_representation_hetero():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=csc_formats[i], sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -342,21 +342,21 @@ def test_minibatch_representation_hetero(): ...@@ -342,21 +342,21 @@ def test_minibatch_representation_hetero():
) )
expect_result = str( expect_result = str(
"""MiniBatch(seed_nodes={'B': tensor([10, 15])}, """MiniBatch(seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])}, sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
)},
original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])}, original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])}, original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
)},
), ),
SampledSubgraphImpl(original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])}, SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
)},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])}, original_edge_ids={'A:r:B': tensor([10, 12])},
original_column_node_ids={'B': tensor([10, 11])}, original_column_node_ids={'B': tensor([10, 11])},
node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
)},
)], )],
positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]), indices=tensor([3, 4, 5]),
...@@ -446,7 +446,7 @@ def test_get_dgl_blocks_homo(): ...@@ -446,7 +446,7 @@ def test_get_dgl_blocks_homo():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -521,7 +521,7 @@ def test_get_dgl_blocks_hetero(): ...@@ -521,7 +521,7 @@ def test_get_dgl_blocks_hetero():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i], sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -613,7 +613,7 @@ def test_minibatch_node_pairs_with_labels(mode): ...@@ -613,7 +613,7 @@ def test_minibatch_node_pairs_with_labels(mode):
def check_dgl_blocks_hetero(minibatch, blocks): def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation) etype = gb.etype_str_to_tuple(relation)
node_pairs = [ node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
] ]
original_edge_ids = [ original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
...@@ -643,7 +643,7 @@ def check_dgl_blocks_hetero(minibatch, blocks): ...@@ -643,7 +643,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
def check_dgl_blocks_homo(minibatch, blocks): def check_dgl_blocks_homo(minibatch, blocks):
node_pairs = [ node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
] ]
original_edge_ids = [ original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
...@@ -862,7 +862,7 @@ def create_homo_minibatch_csc_format(): ...@@ -862,7 +862,7 @@ def create_homo_minibatch_csc_format():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=csc_formats[i], sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -877,7 +877,7 @@ def create_homo_minibatch_csc_format(): ...@@ -877,7 +877,7 @@ def create_homo_minibatch_csc_format():
def create_hetero_minibatch_csc_format(): def create_hetero_minibatch_csc_format():
node_pairs = [ sampled_csc = [
{ {
relation: gb.CSCFormatBase( relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]), indptr=torch.tensor([0, 1, 2, 3]),
...@@ -926,7 +926,7 @@ def create_hetero_minibatch_csc_format(): ...@@ -926,7 +926,7 @@ def create_hetero_minibatch_csc_format():
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.SampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=node_pairs[i], sampled_csc=sampled_csc[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -945,8 +945,8 @@ def create_hetero_minibatch_csc_format(): ...@@ -945,8 +945,8 @@ def create_hetero_minibatch_csc_format():
def check_dgl_blocks_hetero_csc_format(minibatch, blocks): def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation) etype = gb.etype_str_to_tuple(relation)
node_pairs = [ sampled_csc = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
] ]
original_edge_ids = [ original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
...@@ -959,24 +959,24 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks): ...@@ -959,24 +959,24 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
edges = block.edges(etype=etype) edges = block.edges(etype=etype)
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(node_pairs[i][relation].indptr) - 1 0, len(sampled_csc[i][relation].indptr) - 1
).repeat_interleave( ).repeat_interleave(
node_pairs[i][relation].indptr[1:] sampled_csc[i][relation].indptr[1:]
- node_pairs[i][relation].indptr[:-1] - sampled_csc[i][relation].indptr[:-1]
) )
assert torch.equal(edges[0], node_pairs[i][relation].indices) assert torch.equal(edges[0], sampled_csc[i][relation].indices)
assert torch.equal(edges[1], dst_ndoes) assert torch.equal(edges[1], dst_ndoes)
assert torch.equal( assert torch.equal(
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation] block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
) )
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation)) edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(node_pairs[0][reverse_relation].indptr) - 1 0, len(sampled_csc[0][reverse_relation].indptr) - 1
).repeat_interleave( ).repeat_interleave(
node_pairs[0][reverse_relation].indptr[1:] sampled_csc[0][reverse_relation].indptr[1:]
- node_pairs[0][reverse_relation].indptr[:-1] - sampled_csc[0][reverse_relation].indptr[:-1]
) )
assert torch.equal(edges[0], node_pairs[0][reverse_relation].indices) assert torch.equal(edges[0], sampled_csc[0][reverse_relation].indices)
assert torch.equal(edges[1], dst_ndoes) assert torch.equal(edges[1], dst_ndoes)
assert torch.equal( assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"] blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
...@@ -987,8 +987,8 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks): ...@@ -987,8 +987,8 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
def check_dgl_blocks_homo_csc_format(minibatch, blocks): def check_dgl_blocks_homo_csc_format(minibatch, blocks):
node_pairs = [ sampled_csc = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
] ]
original_edge_ids = [ original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
...@@ -999,11 +999,11 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks): ...@@ -999,11 +999,11 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks):
] ]
for i, block in enumerate(blocks): for i, block in enumerate(blocks):
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(node_pairs[i].indptr) - 1 0, len(sampled_csc[i].indptr) - 1
).repeat_interleave( ).repeat_interleave(
node_pairs[i].indptr[1:] - node_pairs[i].indptr[:-1] sampled_csc[i].indptr[1:] - sampled_csc[i].indptr[:-1]
) )
assert torch.equal(block.edges()[0], node_pairs[i].indices), print( assert torch.equal(block.edges()[0], sampled_csc[i].indices), print(
block.edges() block.edges()
) )
assert torch.equal(block.edges()[1], dst_ndoes), print(block.edges()) assert torch.equal(block.edges()[1], dst_ndoes), print(block.edges())
......
...@@ -1145,7 +1145,7 @@ def test_OnDiskDataset_preprocess_homogeneous(): ...@@ -1145,7 +1145,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
torch.tensor([fanout]), torch.tensor([fanout]),
output_cscformat=False, output_cscformat=False,
) )
assert len(subgraph.node_pairs[0]) <= num_samples assert len(subgraph.sampled_csc[0]) <= num_samples
with tempfile.TemporaryDirectory() as test_dir: with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified. # All metadata fields are specified.
......
...@@ -71,7 +71,7 @@ def test_exclude_edges_homo_node_pairs(reverse_row, reverse_column): ...@@ -71,7 +71,7 @@ def test_exclude_edges_homo_node_pairs(reverse_row, reverse_column):
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 10]) expected_edge_ids = torch.Tensor([5, 10])
_assert_container_equal(result.node_pairs, expected_node_pairs) _assert_container_equal(result.sampled_csc, expected_node_pairs)
_assert_container_equal( _assert_container_equal(
result.original_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
...@@ -106,7 +106,7 @@ def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column): ...@@ -106,7 +106,7 @@ def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
dst_to_exclude = torch.tensor([0, 2]) dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = FusedSampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs, sampled_csc=node_pairs,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
...@@ -139,7 +139,7 @@ def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column): ...@@ -139,7 +139,7 @@ def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20])} expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_node_pairs) _assert_container_equal(result.sampled_csc, expected_node_pairs)
_assert_container_equal( _assert_container_equal(
result.original_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
...@@ -188,7 +188,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column): ...@@ -188,7 +188,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 9]) expected_edge_ids = torch.Tensor([5, 9])
_assert_container_equal(result.node_pairs, expected_csc_formats) _assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal( _assert_container_equal(
result.original_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
...@@ -237,7 +237,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column): ...@@ -237,7 +237,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
else: else:
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 10, 10]) expected_edge_ids = torch.Tensor([5, 10, 10])
_assert_container_equal(result.node_pairs, expected_csc_formats) _assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal( _assert_container_equal(
result.original_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
...@@ -272,7 +272,7 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column): ...@@ -272,7 +272,7 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
dst_to_exclude = torch.tensor([0, 2]) dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=csc_formats, sampled_csc=csc_formats,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
...@@ -305,7 +305,7 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column): ...@@ -305,7 +305,7 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20])} expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_csc_formats) _assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal( _assert_container_equal(
result.original_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
...@@ -340,7 +340,7 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column): ...@@ -340,7 +340,7 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
dst_to_exclude = torch.tensor([0, 2]) dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 19, 20, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 19, 20, 20, 21])}
subgraph = SampledSubgraphImpl( subgraph = SampledSubgraphImpl(
node_pairs=csc_formats, sampled_csc=csc_formats,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
...@@ -373,7 +373,7 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column): ...@@ -373,7 +373,7 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
expected_column_node_ids = None expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20, 20])} expected_edge_ids = {"A:relation:B": torch.tensor([20, 20])}
_assert_container_equal(result.node_pairs, expected_csc_formats) _assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal( _assert_container_equal(
result.original_column_node_ids, expected_column_node_ids result.original_column_node_ids, expected_column_node_ids
) )
...@@ -403,7 +403,7 @@ def test_sampled_subgraph_to_device(): ...@@ -403,7 +403,7 @@ def test_sampled_subgraph_to_device():
dst_to_exclude = torch.tensor([10, 12]) dst_to_exclude = torch.tensor([10, 12])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])} original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = FusedSampledSubgraphImpl( subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs, sampled_csc=node_pairs,
original_column_node_ids=original_column_node_ids, original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids, original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids, original_edge_ids=original_edge_ids,
...@@ -420,9 +420,9 @@ def test_sampled_subgraph_to_device(): ...@@ -420,9 +420,9 @@ def test_sampled_subgraph_to_device():
graph = graph.to("cuda") graph = graph.to("cuda")
# Check. # Check.
for key in graph.node_pairs: for key in graph.sampled_csc:
assert graph.node_pairs[key][0].device.type == "cuda" assert graph.sampled_csc[key][0].device.type == "cuda"
assert graph.node_pairs[key][1].device.type == "cuda" assert graph.sampled_csc[key][1].device.type == "cuda"
for key in graph.original_column_node_ids: for key in graph.original_column_node_ids:
assert graph.original_column_node_ids[key].device.type == "cuda" assert graph.original_column_node_ids[key].device.type == "cuda"
for key in graph.original_row_node_ids: for key in graph.original_row_node_ids:
...@@ -433,7 +433,7 @@ def test_sampled_subgraph_to_device(): ...@@ -433,7 +433,7 @@ def test_sampled_subgraph_to_device():
def test_sampled_subgraph_impl_representation_homo(): def test_sampled_subgraph_impl_representation_homo():
sampled_subgraph_impl = SampledSubgraphImpl( sampled_subgraph_impl = SampledSubgraphImpl(
node_pairs=gb.CSCFormatBase( sampled_csc=gb.CSCFormatBase(
indptr=torch.arange(0, 101, 10), indptr=torch.arange(0, 101, 10),
indices=torch.arange(10, 110), indices=torch.arange(10, 110),
), ),
...@@ -442,7 +442,17 @@ def test_sampled_subgraph_impl_representation_homo(): ...@@ -442,7 +442,17 @@ def test_sampled_subgraph_impl_representation_homo():
original_edge_ids=None, original_edge_ids=None,
) )
expected_result = str( expected_result = str(
"""SampledSubgraphImpl(original_row_node_ids=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, """SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]),
indices=tensor([ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109]),
),
original_row_node_ids=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
...@@ -452,16 +462,6 @@ def test_sampled_subgraph_impl_representation_homo(): ...@@ -452,16 +462,6 @@ def test_sampled_subgraph_impl_representation_homo():
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]), 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), original_column_node_ids=tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
node_pairs=CSCFormatBase(indptr=tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]),
indices=tensor([ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109]),
),
)""" )"""
) )
assert str(sampled_subgraph_impl) == expected_result, print( assert str(sampled_subgraph_impl) == expected_result, print(
...@@ -471,7 +471,7 @@ def test_sampled_subgraph_impl_representation_homo(): ...@@ -471,7 +471,7 @@ def test_sampled_subgraph_impl_representation_homo():
def test_sampled_subgraph_impl_representation_hetero(): def test_sampled_subgraph_impl_representation_hetero():
sampled_subgraph_impl = SampledSubgraphImpl( sampled_subgraph_impl = SampledSubgraphImpl(
node_pairs={ sampled_csc={
"n1:e1:n2": gb.CSCFormatBase( "n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]), indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([4, 5, 6, 7]), indices=torch.tensor([4, 5, 6, 7]),
...@@ -492,14 +492,16 @@ def test_sampled_subgraph_impl_representation_hetero(): ...@@ -492,14 +492,16 @@ def test_sampled_subgraph_impl_representation_hetero():
original_edge_ids=None, original_edge_ids=None,
) )
expected_result = str( expected_result = str(
"""SampledSubgraphImpl(original_row_node_ids={'n1': tensor([1, 0, 0, 1, 1, 0, 0, 1]), 'n2': tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1])}, """SampledSubgraphImpl(sampled_csc={'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 2, 4]),
indices=tensor([4, 5, 6, 7]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 2, 4, 6, 8]),
indices=tensor([2, 3, 4, 5, 6, 7, 8, 9]),
)},
original_row_node_ids={'n1': tensor([1, 0, 0, 1, 1, 0, 0, 1]), 'n2': tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1])},
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids={'n1': tensor([1, 0, 0, 1]), 'n2': tensor([1, 2])}, original_column_node_ids={'n1': tensor([1, 0, 0, 1]), 'n2': tensor([1, 2])},
node_pairs={'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 2, 4]),
indices=tensor([4, 5, 6, 7]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 2, 4, 6, 8]),
indices=tensor([2, 3, 4, 5, 6, 7, 8, 9]),
)},
)""" )"""
) )
assert str(sampled_subgraph_impl) == expected_result, print(expected_result) assert str(sampled_subgraph_impl) == expected_result, print(
sampled_subgraph_impl
)
...@@ -83,7 +83,7 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -83,7 +83,7 @@ def test_FeatureFetcher_with_edges_homo():
range_tensor = torch.arange(10) range_tensor = torch.arange(10)
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs=(range_tensor, range_tensor), sampled_csc=(range_tensor, range_tensor),
original_column_node_ids=range_tensor, original_column_node_ids=range_tensor,
original_row_node_ids=range_tensor, original_row_node_ids=range_tensor,
original_edge_ids=torch.randint( original_edge_ids=torch.randint(
...@@ -184,7 +184,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -184,7 +184,7 @@ def test_FeatureFetcher_with_edges_hetero():
for _ in range(3): for _ in range(3):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.FusedSampledSubgraphImpl(
node_pairs={ sampled_csc={
"n1:e1:n2": ( "n1:e1:n2": (
torch.arange(10), torch.arange(10),
torch.arange(10), torch.arange(10),
......
...@@ -61,19 +61,19 @@ def test_integration_link_prediction(): ...@@ -61,19 +61,19 @@ def test_integration_link_prediction():
expected = [ expected = [
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]), original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]),
indices=tensor([5]),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0]), original_column_node_ids=tensor([5, 3, 1, 2, 0]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]),
indices=tensor([5]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 1]), positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])), tensor([2, 3, 3, 1])),
...@@ -117,19 +117,19 @@ def test_integration_link_prediction(): ...@@ -117,19 +117,19 @@ def test_integration_link_prediction():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]), original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]), original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 2]), positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])), tensor([0, 0, 1, 1])),
...@@ -172,19 +172,19 @@ def test_integration_link_prediction(): ...@@ -172,19 +172,19 @@ def test_integration_link_prediction():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
)], )],
positive_node_pairs=(tensor([0, 1]), positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])), tensor([0, 0])),
...@@ -274,19 +274,19 @@ def test_integration_node_classification(): ...@@ -274,19 +274,19 @@ def test_integration_node_classification():
expected = [ expected = [
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
),
original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]), original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 1]), positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])), tensor([2, 3, 3, 1])),
...@@ -315,19 +315,19 @@ def test_integration_node_classification(): ...@@ -315,19 +315,19 @@ def test_integration_node_classification():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([3, 4, 0]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([3, 4, 0]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
)], )],
positive_node_pairs=(tensor([0, 1, 1, 2]), positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])), tensor([0, 0, 1, 1])),
...@@ -354,19 +354,19 @@ def test_integration_node_classification(): ...@@ -354,19 +354,19 @@ def test_integration_node_classification():
), ),
str( str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
), ),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
),
)], )],
positive_node_pairs=(tensor([0, 1]), positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])), tensor([0, 0])),
......
...@@ -246,7 +246,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor): ...@@ -246,7 +246,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
for data in sampler_dp: for data in sampler_dp:
for sampledsubgraph in data.sampled_subgraphs: for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.node_pairs.items(): for _, value in sampledsubgraph.sampled_csc.items():
assert torch.equal( assert torch.equal(
torch.ge(value.indices, torch.zeros(len(value.indices))), torch.ge(value.indices, torch.zeros(len(value.indices))),
torch.ones(len(value.indices)), torch.ones(len(value.indices)),
...@@ -297,9 +297,11 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor): ...@@ -297,9 +297,11 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
for step, sampled_subgraph in enumerate(data.sampled_subgraphs): for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert len(sampled_subgraph.original_row_node_ids) == length[step] assert len(sampled_subgraph.original_row_node_ids) == length[step]
assert torch.equal( assert torch.equal(
sampled_subgraph.node_pairs.indices, compacted_indices[step] sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
) )
assert torch.equal(sampled_subgraph.node_pairs.indptr, indptr[step])
assert torch.equal( assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step] sampled_subgraph.original_column_node_ids, seeds[step]
) )
...@@ -372,11 +374,11 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor): ...@@ -372,11 +374,11 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor):
) )
for etype in ["n1:e1:n2", "n2:e2:n1"]: for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal( assert torch.equal(
sampled_subgraph.node_pairs[etype].indices, sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices, csc_formats[step][etype].indices,
) )
assert torch.equal( assert torch.equal(
sampled_subgraph.node_pairs[etype].indptr, sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr, csc_formats[step][etype].indptr,
) )
...@@ -423,9 +425,11 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor): ...@@ -423,9 +425,11 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
original_row_node_ids[step], original_row_node_ids[step],
) )
assert torch.equal( assert torch.equal(
sampled_subgraph.node_pairs.indices, compacted_indices[step] sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
) )
assert torch.equal(sampled_subgraph.node_pairs.indptr, indptr[step])
assert torch.equal( assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step] sampled_subgraph.original_column_node_ids, seeds[step]
) )
...@@ -504,10 +508,10 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor): ...@@ -504,10 +508,10 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
) )
for etype in ["n1:e1:n2", "n2:e2:n1"]: for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal( assert torch.equal(
sampled_subgraph.node_pairs[etype].indices, sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices, csc_formats[step][etype].indices,
) )
assert torch.equal( assert torch.equal(
sampled_subgraph.node_pairs[etype].indptr, sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr, csc_formats[step][etype].indptr,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment