Unverified Commit 1f9ae668 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Rename `node_pairs` to `sampled_csc`. (#6836)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 437139f5
......@@ -368,7 +368,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
... edge_type_to_id=etypes)
>>> nodes = {"N0":torch.LongTensor([1]), "N1":torch.LongTensor([1, 2])}
>>> in_subgraph = graph.in_subgraph(nodes)
>>> print(in_subgraph.node_pairs)
>>> print(in_subgraph.sampled_csc)
defaultdict(<class 'list'>, {
'N0:R0:N0': (tensor([]), tensor([])),
'N0:R1:N1': (tensor([1, 0]), tensor([1, 2])),
......@@ -436,7 +436,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
return FusedSampledSubgraphImpl(
node_pairs=node_pairs, original_edge_ids=original_edge_ids
sampled_csc=node_pairs, original_edge_ids=original_edge_ids
)
def _convert_to_homogeneous_nodes(self, nodes, timestamps=None):
......@@ -475,7 +475,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
]
if type_per_edge is None:
# The sampled graph is already a homogeneous graph.
node_pairs = CSCFormatBase(indptr=indptr, indices=indices)
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
else:
# The sampled graph is a fused homogenized graph, which need to be
# converted to heterogeneous graphs.
......@@ -536,7 +536,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
l = end
if has_original_eids:
original_edge_ids = original_hetero_edge_ids
node_pairs = {
sampled_csc = {
etype: CSCFormatBase(
indptr=torch.tensor(subgraph_indptr[etype]),
indices=subgraph_indice[etype],
......@@ -544,7 +544,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
for etype in self.edge_type_to_id.keys()
}
return SampledSubgraphImpl(
node_pairs=node_pairs,
sampled_csc=sampled_csc,
original_edge_ids=original_edge_ids,
)
......@@ -619,7 +619,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_neighbors(nodes, fanouts)
>>> print(subgraph.node_pairs)
>>> print(subgraph.sampled_csc)
defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([0]),
tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
"""
......@@ -811,7 +811,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
>>> nodes = {'n1': torch.LongTensor([0]), 'n2': torch.LongTensor([0])}
>>> fanouts = torch.tensor([1, 1])
>>> subgraph = graph.sample_layer_neighbors(nodes, fanouts)
>>> print(subgraph.node_pairs)
>>> print(subgraph.sampled_csc)
defaultdict(<class 'list'>, {'n1:e1:n2': (tensor([1]),
tensor([0])), 'n2:e2:n1': (tensor([2]), tensor([0]))})
"""
......
......@@ -41,7 +41,7 @@ class InSubgraphSampler(SubgraphSampler):
>>> item_sampler = gb.ItemSampler(item_set, batch_size=2)
>>> insubgraph_sampler = gb.InSubgraphSampler(item_sampler, graph)
>>> for _, data in enumerate(insubgraph_sampler):
... print(data.sampled_subgraphs[0].node_pairs)
... print(data.sampled_subgraphs[0].sampled_csc)
... print(data.sampled_subgraphs[0].original_row_node_ids)
... print(data.sampled_subgraphs[0].original_column_node_ids)
CSCFormatBase(indptr=tensor([0, 3, 5]),
......@@ -79,9 +79,9 @@ class InSubgraphSampler(SubgraphSampler):
(
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds)
) = unique_and_compact_node_pairs(subgraph.sampled_csc, seeds)
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
sampled_csc=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
......@@ -90,9 +90,9 @@ class InSubgraphSampler(SubgraphSampler):
(
original_row_node_ids,
compacted_csc_formats,
) = unique_and_compact_csc_formats(subgraph.node_pairs, seeds)
) = unique_and_compact_csc_formats(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_formats,
sampled_csc=compacted_csc_formats,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
......
......@@ -143,10 +143,10 @@ class NeighborSampler(SubgraphSampler):
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(
subgraph.node_pairs, seeds
subgraph.sampled_csc, seeds
)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format,
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
......@@ -156,10 +156,10 @@ class NeighborSampler(SubgraphSampler):
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds
subgraph.sampled_csc, seeds
)
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
sampled_csc=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
......@@ -168,13 +168,9 @@ class NeighborSampler(SubgraphSampler):
(
original_row_node_ids,
compacted_csc_format,
) = compact_csc_format(subgraph.node_pairs, seeds)
# [TODO] For node_pairs is defined in SampledSubgraph, which is
# SampledSubgraph's parent class, and it's still inherited by
# other classes, the name cannot be changed currently. This
# part will be cleaned up later.
) = compact_csc_format(subgraph.sampled_csc, seeds)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format,
sampled_csc=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
......
......@@ -24,12 +24,12 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.FusedSampledSubgraphImpl(
... node_pairs=node_pairs,
... sampled_csc=node_pairs,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> print(subgraph.node_pairs)
>>> print(subgraph.sampled_csc)
{"A:relation:B": (tensor([0, 1, 2]), tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids)
{'B': tensor([10, 11, 12])}
......@@ -38,7 +38,7 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
>>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
node_pairs: Union[
sampled_csc: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor],
] = None
......@@ -49,8 +49,8 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.node_pairs, dict):
for etype, pair in self.node_pairs.items():
if isinstance(self.sampled_csc, dict):
for etype, pair in self.sampled_csc.items():
assert (
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
......@@ -63,10 +63,11 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
), "Nodes in pairs should be of type torch.Tensor."
else:
assert (
isinstance(self.node_pairs, tuple) and len(self.node_pairs) == 2
isinstance(self.sampled_csc, tuple)
and len(self.sampled_csc) == 2
), "Node pair should be a source-destination tuple (u, v)."
assert all(
isinstance(item, torch.Tensor) for item in self.node_pairs
isinstance(item, torch.Tensor) for item in self.sampled_csc
), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str:
......@@ -79,18 +80,18 @@ class SampledSubgraphImpl(SampledSubgraph):
Examples
--------
>>> node_pairs = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
>>> sampled_csc = {"A:relation:B": CSCFormatBase(indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {'B': torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {'A': torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... sampled_csc=sampled_csc,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> print(subgraph.node_pairs)
>>> print(subgraph.sampled_csc)
{"A:relation:B": CSCForamtBase(indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> print(subgraph.original_column_node_ids)
......@@ -100,10 +101,7 @@ class SampledSubgraphImpl(SampledSubgraph):
>>> print(subgraph.original_edge_ids)
{"A:relation:B": tensor([19, 20, 21])}
"""
node_pairs: Union[
CSCFormatBase,
Dict[str, CSCFormatBase],
] = None
sampled_csc: Union[CSCFormatBase, Dict[str, CSCFormatBase]] = None
original_column_node_ids: Union[
Dict[str, torch.Tensor], torch.Tensor
] = None
......@@ -111,8 +109,8 @@ class SampledSubgraphImpl(SampledSubgraph):
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
def __post_init__(self):
if isinstance(self.node_pairs, dict):
for etype, pair in self.node_pairs.items():
if isinstance(self.sampled_csc, dict):
for etype, pair in self.sampled_csc.items():
assert (
isinstance(etype, str)
and len(etype_str_to_tuple(etype)) == 3
......@@ -125,13 +123,13 @@ class SampledSubgraphImpl(SampledSubgraph):
), "Nodes in pairs should be of type torch.Tensor."
else:
assert (
self.node_pairs.indptr is not None
and self.node_pairs.indices is not None
self.sampled_csc.indptr is not None
and self.sampled_csc.indices is not None
), "Node pair should be have indptr and indice."
assert isinstance(
self.node_pairs.indptr, torch.Tensor
self.sampled_csc.indptr, torch.Tensor
) and isinstance(
self.node_pairs.indices, torch.Tensor
self.sampled_csc.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str:
......
......@@ -180,7 +180,7 @@ class MiniBatch:
return None
is_heterogeneous = isinstance(
self.sampled_subgraphs[0].node_pairs, Dict
self.sampled_subgraphs[0].sampled_csc, Dict
)
blocks = []
......@@ -195,9 +195,9 @@ class MiniBatch:
), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous:
if isinstance(
list(subgraph.node_pairs.values())[0], CSCFormatBase
list(subgraph.sampled_csc.values())[0], CSCFormatBase
):
node_pairs = {
sampled_csc = {
etype_str_to_tuple(etype): (
"csc",
(
......@@ -211,12 +211,12 @@ class MiniBatch:
),
),
)
for etype, v in subgraph.node_pairs.items()
for etype, v in subgraph.sampled_csc.items()
}
else:
node_pairs = {
sampled_csc = {
etype_str_to_tuple(etype): v
for etype, v in subgraph.node_pairs.items()
for etype, v in subgraph.sampled_csc.items()
}
num_src_nodes = {
ntype: nodes.size(0)
......@@ -227,18 +227,18 @@ class MiniBatch:
for ntype, nodes in original_column_node_ids.items()
}
else:
node_pairs = subgraph.node_pairs
if isinstance(subgraph.node_pairs, CSCFormatBase):
node_pairs = (
sampled_csc = subgraph.sampled_csc
if isinstance(subgraph.sampled_csc, CSCFormatBase):
sampled_csc = (
"csc",
(
node_pairs.indptr,
node_pairs.indices,
sampled_csc.indptr,
sampled_csc.indices,
torch.arange(
0,
node_pairs.indptr[-1],
device=node_pairs.indptr.device,
dtype=node_pairs.indptr.dtype,
sampled_csc.indptr[-1],
device=sampled_csc.indptr.device,
dtype=sampled_csc.indptr.dtype,
),
),
)
......@@ -246,7 +246,7 @@ class MiniBatch:
num_dst_nodes = original_column_node_ids.size(0)
blocks.append(
dgl.create_block(
node_pairs,
sampled_csc,
num_src_nodes=num_src_nodes,
num_dst_nodes=num_dst_nodes,
)
......
......@@ -20,14 +20,14 @@ class SampledSubgraph:
value type."""
@property
def node_pairs(
def sampled_csc(
self,
) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]:
"""Returns the node pairs representing edges in csc format.
- If `node_pairs` is a CSCFormatBase: It should be in the csc format.
- If `sampled_csc` is a CSCFormatBase: It should be in the csc format.
`indptr` stores the index in the data array where each column
starts. `indices` stores the row indices of the non-zero elements.
- If `node_pairs` is a dictionary: The keys should be edge type and
- If `sampled_csc` is a dictionary: The keys should be edge type and
the values should be corresponding node pairs. The ids inside
is heterogeneous ids."""
raise NotImplementedError
......@@ -45,7 +45,7 @@ class SampledSubgraph:
- If `original_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means column IDs are compacted, and `node_pairs`
If present, it means column IDs are compacted, and `sampled_csc`
column IDs match these compacted ones.
"""
return None
......@@ -63,7 +63,7 @@ class SampledSubgraph:
- If `original_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means row IDs are compacted, and `node_pairs`
If present, it means row IDs are compacted, and `sampled_csc`
row IDs match these compacted ones."""
return None
......@@ -119,14 +119,14 @@ class SampledSubgraph:
--------
>>> import dgl.graphbolt as gb
>>> import torch
>>> node_pairs = {"A:relation:B": gb.CSCFormatBase(
>>> sampled_csc = {"A:relation:B": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> original_column_node_ids = {"B": torch.tensor([10, 11, 12])}
>>> original_row_node_ids = {"A": torch.tensor([13, 14, 15])}
>>> original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
>>> subgraph = gb.SampledSubgraphImpl(
... node_pairs=node_pairs,
... sampled_csc=sampled_csc,
... original_column_node_ids=original_column_node_ids,
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
......@@ -134,7 +134,7 @@ class SampledSubgraph:
>>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]),
... torch.tensor([11, 12]))}
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.node_pairs)
>>> print(result.sampled_csc)
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
indices=tensor([0]),
)}
......@@ -150,7 +150,7 @@ class SampledSubgraph:
assume_num_node_within_int32
), "Values > int32 are not supported yet."
assert (
isinstance(self.node_pairs, (CSCFormatBase, tuple))
isinstance(self.sampled_csc, (CSCFormatBase, tuple))
) == isinstance(edges, tuple), (
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
......@@ -162,9 +162,9 @@ class SampledSubgraph:
# 1. Convert the node pairs to the original ids if they are compacted.
# 2. Exclude the edges and get the index of the edges to keep.
# 3. Slice the subgraph according to the index.
if isinstance(self.node_pairs, tuple):
if isinstance(self.sampled_csc, tuple):
reverse_edges = _to_reverse_ids_node_pairs(
self.node_pairs,
self.sampled_csc,
self.original_row_node_ids,
self.original_column_node_ids,
)
......@@ -172,9 +172,9 @@ class SampledSubgraph:
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph_node_pairs(self, index))
elif isinstance(self.node_pairs, CSCFormatBase):
elif isinstance(self.sampled_csc, CSCFormatBase):
reverse_edges = _to_reverse_ids(
self.node_pairs,
self.sampled_csc,
self.original_row_node_ids,
self.original_column_node_ids,
)
......@@ -185,7 +185,7 @@ class SampledSubgraph:
else:
index = {}
is_cscformat = 0
for etype, pair in self.node_pairs.items():
for etype, pair in self.sampled_csc.items():
if etype not in edges:
# No edges need to be excluded.
index[etype] = None
......@@ -314,7 +314,7 @@ def _slice_subgraph_node_pairs(subgraph: SampledSubgraph, index: torch.Tensor):
return ret
return (
_index_select(subgraph.node_pairs, index),
_index_select(subgraph.sampled_csc, index),
subgraph.original_column_node_ids,
subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index),
......@@ -347,7 +347,7 @@ def _slice_subgraph(subgraph: SampledSubgraph, index: torch.Tensor):
return ret
return (
_index_select(subgraph.node_pairs, index),
_index_select(subgraph.sampled_csc, index),
subgraph.original_column_node_ids,
subgraph.original_row_node_ids,
_index_select(subgraph.original_edge_ids, index),
......
......@@ -710,10 +710,10 @@ def test_in_subgraph_node_pairs_homogeneous():
# Verify in subgraph.
assert torch.equal(
in_subgraph.node_pairs[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
in_subgraph.sampled_csc[0], torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
)
assert torch.equal(
in_subgraph.node_pairs[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
in_subgraph.sampled_csc[1], torch.LongTensor([4, 4, 4, 1, 1, 3, 3])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
......@@ -782,28 +782,28 @@ def test_in_subgraph_node_pairs_heterogeneous():
# Verify in subgraph.
assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"][0], torch.LongTensor([])
in_subgraph.sampled_csc["N0:R0:N0"][0], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"][1], torch.LongTensor([])
in_subgraph.sampled_csc["N0:R0:N0"][1], torch.LongTensor([])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][0], torch.LongTensor([0, 1])
in_subgraph.sampled_csc["N0:R1:N1"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"][1], torch.LongTensor([2, 1])
in_subgraph.sampled_csc["N0:R1:N1"][1], torch.LongTensor([2, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][0], torch.LongTensor([0, 1])
in_subgraph.sampled_csc["N1:R2:N0"][0], torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"][1], torch.LongTensor([1, 1])
in_subgraph.sampled_csc["N1:R2:N0"][1], torch.LongTensor([1, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
in_subgraph.sampled_csc["N1:R3:N1"][0], torch.LongTensor([1, 2, 0])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
in_subgraph.sampled_csc["N1:R3:N1"][1], torch.LongTensor([2, 2, 1])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
......@@ -841,10 +841,10 @@ def test_in_subgraph_homo():
# Verify in subgraph.
assert torch.equal(
in_subgraph.node_pairs.indices, torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
in_subgraph.sampled_csc.indices, torch.LongTensor([0, 3, 4, 2, 3, 1, 2])
)
assert torch.equal(
in_subgraph.node_pairs.indptr, torch.LongTensor([0, 3, 5, 7])
in_subgraph.sampled_csc.indptr, torch.LongTensor([0, 3, 5, 7])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
......@@ -913,28 +913,28 @@ def test_in_subgraph_hetero():
# Verify in subgraph.
assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"].indices, torch.LongTensor([])
in_subgraph.sampled_csc["N0:R0:N0"].indices, torch.LongTensor([])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R0:N0"].indptr, torch.LongTensor([0, 0])
in_subgraph.sampled_csc["N0:R0:N0"].indptr, torch.LongTensor([0, 0])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"].indices, torch.LongTensor([0, 1])
in_subgraph.sampled_csc["N0:R1:N1"].indices, torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N0:R1:N1"].indptr, torch.LongTensor([0, 1, 2])
in_subgraph.sampled_csc["N0:R1:N1"].indptr, torch.LongTensor([0, 1, 2])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"].indices, torch.LongTensor([0, 1])
in_subgraph.sampled_csc["N1:R2:N0"].indices, torch.LongTensor([0, 1])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R2:N0"].indptr, torch.LongTensor([0, 2])
in_subgraph.sampled_csc["N1:R2:N0"].indptr, torch.LongTensor([0, 2])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"].indices, torch.LongTensor([1, 2, 0])
in_subgraph.sampled_csc["N1:R3:N1"].indices, torch.LongTensor([1, 2, 0])
)
assert torch.equal(
in_subgraph.node_pairs["N1:R3:N1"].indptr, torch.LongTensor([0, 2, 3])
in_subgraph.sampled_csc["N1:R3:N1"].indptr, torch.LongTensor([0, 2, 3])
)
assert in_subgraph.original_column_node_ids is None
assert in_subgraph.original_row_node_ids is None
......@@ -994,7 +994,7 @@ def test_sample_neighbors_homo(labor, indptr_dtype, indices_dtype):
subgraph = sampler(nodes, fanouts, output_cscformat=False)
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 6
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
......@@ -1084,10 +1084,10 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
torch.LongTensor([0, 0]),
),
}
assert len(subgraph.node_pairs) == 2
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
......@@ -1128,10 +1128,10 @@ def test_sample_neighbors_hetero(labor, indptr_dtype, indices_dtype):
torch.tensor([], dtype=indices_dtype),
),
}
assert len(subgraph.node_pairs) == 2
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype][0], pairs[0])
assert torch.equal(subgraph.node_pairs[etype][1], pairs[1])
assert torch.equal(subgraph.sampled_csc[etype][0], pairs[0])
assert torch.equal(subgraph.sampled_csc[etype][1], pairs[1])
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
......@@ -1199,11 +1199,11 @@ def test_sample_neighbors_fanouts(
# Verify in subgraph.
assert (
expected_sampled_num1 == 0
or subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
or subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
)
assert (
expected_sampled_num2 == 0
or subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
or subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
)
......@@ -1254,8 +1254,8 @@ def test_sample_neighbors_replace(
)
# Verify in subgraph.
assert subgraph.node_pairs["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.node_pairs["n2:e2:n1"][0].numel() == expected_sampled_num2
assert subgraph.sampled_csc["n1:e1:n2"][0].numel() == expected_sampled_num1
assert subgraph.sampled_csc["n2:e2:n1"][0].numel() == expected_sampled_num2
@unittest.skipIf(
......@@ -1408,7 +1408,7 @@ def test_sample_neighbors_probs(replace, labor, probs_name):
)
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
sampled_num = subgraph.sampled_csc[0].size(0)
if replace:
assert sampled_num == 6
else:
......@@ -1455,7 +1455,7 @@ def test_sample_neighbors_zero_probs(replace, labor, probs_or_mask):
)
# Verify in subgraph.
sampled_num = subgraph.node_pairs[0].size(0)
sampled_num = subgraph.sampled_csc[0].size(0)
assert sampled_num == 0
......@@ -1965,7 +1965,7 @@ def test_sample_neighbors_homo_pick_number(fanouts, replace, labor, probs_name):
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=False,
)
sampled_num = subgraph.node_pairs[0].size(0)
sampled_num = subgraph.sampled_csc[0].size(0)
# Verify in subgraph.
if probs_name == "mask":
......@@ -2059,7 +2059,7 @@ def test_sample_neighbors_hetero_pick_number(
)
if probs_name == "none":
for etype, pairs in subgraph.node_pairs.items():
for etype, pairs in subgraph.sampled_csc.items():
fanout = fanouts[etypes[etype]]
if fanout == -1:
assert pairs[0].size(0) == 3
......@@ -2070,7 +2070,7 @@ def test_sample_neighbors_hetero_pick_number(
assert pairs[0].size(0) == min(fanout, 3)
else:
fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.node_pairs.items():
for etype, pairs in subgraph.sampled_csc.items():
if etypes[etype] == 0:
# Etype 0: 2 valid neighbors.
if fanout == -1:
......@@ -2176,8 +2176,8 @@ def test_sample_neighbors_homo_csc_format():
)
# Verify in subgraph.
sampled_indptr_num = subgraph.node_pairs.indptr.size(0)
sampled_num = subgraph.node_pairs.indices.size(0)
sampled_indptr_num = subgraph.sampled_csc.indptr.size(0)
sampled_num = subgraph.sampled_csc.indices.size(0)
assert sampled_indptr_num == 4
assert sampled_num == 6
assert subgraph.original_column_node_ids is None
......@@ -2228,7 +2228,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph.
expected_node_pairs = {
expected_sampled_csc = {
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 2]),
indices=torch.LongTensor([0, 1]),
......@@ -2238,10 +2238,10 @@ def test_sample_neighbors_hetero_csc_format(labor):
indices=torch.LongTensor([0, 2]),
),
}
assert len(subgraph.node_pairs) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype].indptr, pairs.indptr)
assert torch.equal(subgraph.node_pairs[etype].indices, pairs.indices)
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_sampled_csc.items():
assert torch.equal(subgraph.sampled_csc[etype].indptr, pairs.indptr)
assert torch.equal(subgraph.sampled_csc[etype].indices, pairs.indices)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
......@@ -2253,7 +2253,7 @@ def test_sample_neighbors_hetero_csc_format(labor):
subgraph = sampler(nodes, fanouts, output_cscformat=True)
# Verify in subgraph.
expected_node_pairs = {
expected_sampled_csc = {
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.LongTensor([0]),
indices=torch.LongTensor([]),
......@@ -2263,10 +2263,10 @@ def test_sample_neighbors_hetero_csc_format(labor):
indices=torch.LongTensor([0, 2]),
),
}
assert len(subgraph.node_pairs) == 2
for etype, pairs in expected_node_pairs.items():
assert torch.equal(subgraph.node_pairs[etype].indptr, pairs.indptr)
assert torch.equal(subgraph.node_pairs[etype].indices, pairs.indices)
assert len(subgraph.sampled_csc) == 2
for etype, pairs in expected_sampled_csc.items():
assert torch.equal(subgraph.sampled_csc[etype].indptr, pairs.indptr)
assert torch.equal(subgraph.sampled_csc[etype].indices, pairs.indices)
assert subgraph.original_column_node_ids is None
assert subgraph.original_row_node_ids is None
assert subgraph.original_edge_ids is None
......@@ -2334,16 +2334,16 @@ def test_sample_neighbors_fanouts_csc_format(
# Verify in subgraph.
assert (
expected_sampled_num1 == 0
or subgraph.node_pairs["n1:e1:n2"].indices.numel()
or subgraph.sampled_csc["n1:e1:n2"].indices.numel()
== expected_sampled_num1
)
assert subgraph.node_pairs["n1:e1:n2"].indptr.size(0) == 2
assert subgraph.sampled_csc["n1:e1:n2"].indptr.size(0) == 2
assert (
expected_sampled_num2 == 0
or subgraph.node_pairs["n2:e2:n1"].indices.numel()
or subgraph.sampled_csc["n2:e2:n1"].indices.numel()
== expected_sampled_num2
)
assert subgraph.node_pairs["n2:e2:n1"].indptr.size(0) == 2
assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2
@unittest.skipIf(
......@@ -2394,13 +2394,15 @@ def test_sample_neighbors_replace_csc_format(
# Verify in subgraph.
assert (
subgraph.node_pairs["n1:e1:n2"].indices.numel() == expected_sampled_num1
subgraph.sampled_csc["n1:e1:n2"].indices.numel()
== expected_sampled_num1
)
assert subgraph.node_pairs["n1:e1:n2"].indptr.size(0) == 2
assert subgraph.sampled_csc["n1:e1:n2"].indptr.size(0) == 2
assert (
subgraph.node_pairs["n2:e2:n1"].indices.numel() == expected_sampled_num2
subgraph.sampled_csc["n2:e2:n1"].indices.numel()
== expected_sampled_num2
)
assert subgraph.node_pairs["n2:e2:n1"].indptr.size(0) == 2
assert subgraph.sampled_csc["n2:e2:n1"].indptr.size(0) == 2
@unittest.skipIf(
......@@ -2554,8 +2556,8 @@ def test_sample_neighbors_probs_csc_format(replace, labor, probs_name):
)
# Verify in subgraph.
sampled_num = subgraph.node_pairs.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 4
sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.sampled_csc.indptr.size(0) == 4
if replace:
assert sampled_num == 6
else:
......@@ -2603,8 +2605,8 @@ def test_sample_neighbors_zero_probs_csc_format(replace, labor, probs_or_mask):
)
# Verify in subgraph.
sampled_num = subgraph.node_pairs.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 4
sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.sampled_csc.indptr.size(0) == 4
assert sampled_num == 0
......@@ -2673,8 +2675,8 @@ def test_sample_neighbors_homo_pick_number_csc_format(
probs_name=probs_name if probs_name != "none" else None,
output_cscformat=True,
)
sampled_num = subgraph.node_pairs.indices.size(0)
assert subgraph.node_pairs.indptr.size(0) == 3
sampled_num = subgraph.sampled_csc.indices.size(0)
assert subgraph.sampled_csc.indptr.size(0) == 3
# Verify in subgraph.
if probs_name == "mask":
if fanouts[0] == -1:
......@@ -2767,7 +2769,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
)
print(subgraph)
if probs_name == "none":
for etype, pairs in subgraph.node_pairs.items():
for etype, pairs in subgraph.sampled_csc.items():
assert pairs.indptr.size(0) == 2
sampled_num = pairs.indices.size(0)
fanout = fanouts[etypes[etype]]
......@@ -2780,7 +2782,7 @@ def test_sample_neighbors_hetero_pick_number_csc_format(
assert sampled_num == min(fanout, 3)
else:
fanout = fanouts[0] # Here fanout is the same for all etypes.
for etype, pairs in subgraph.node_pairs.items():
for etype, pairs in subgraph.sampled_csc.items():
assert pairs.indptr.size(0) == 2
sampled_num = pairs.indices.size(0)
if etypes[etype] == 0:
......
......@@ -90,11 +90,11 @@ def test_InSubgraphSampler_node_pairs_homo():
sampled_subgraph = minibatch.sampled_subgraphs[0]
_src = [
sampled_subgraph.original_row_node_ids[id]
for id in sampled_subgraph.node_pairs[0]
for id in sampled_subgraph.sampled_csc[0]
]
_dst = [
sampled_subgraph.original_column_node_ids[id]
for id in sampled_subgraph.node_pairs[1]
for id in sampled_subgraph.sampled_csc[1]
]
return _src, _dst
......@@ -172,7 +172,7 @@ def test_InSubgraphSampler_node_pairs_hetero():
"N1:R2:N0": (torch.LongTensor([0, 1]), torch.LongTensor([0, 1])),
"N1:R3:N1": (torch.LongTensor([]), torch.LongTensor([])),
}
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1])
......@@ -187,7 +187,7 @@ def test_InSubgraphSampler_node_pairs_hetero():
"N1:R2:N0": (torch.LongTensor([1]), torch.LongTensor([0])),
"N1:R3:N1": (torch.LongTensor([]), torch.LongTensor([])),
}
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1])
......@@ -199,7 +199,7 @@ def test_InSubgraphSampler_node_pairs_hetero():
"N1:R2:N0": (torch.LongTensor([]), torch.LongTensor([])),
"N1:R3:N1": (torch.LongTensor([1, 2, 0]), torch.LongTensor([0, 1, 1])),
}
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs[0], expected_node_pairs[etype][0])
assert torch.equal(pairs[1], expected_node_pairs[etype][1])
......@@ -231,28 +231,28 @@ def test_InSubgraphSampler_homo():
def original_indices(minibatch):
sampled_subgraph = minibatch.sampled_subgraphs[0]
_indices = sampled_subgraph.original_row_node_ids[
sampled_subgraph.node_pairs.indices
sampled_subgraph.sampled_csc.indices
]
return _indices
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([0]))
assert torch.equal(
mn.sampled_subgraphs[0].node_pairs.indptr, torch.tensor([0, 3])
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 3])
)
assert torch.equal(original_indices(mn), torch.tensor([0, 1, 4]))
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([5]))
assert torch.equal(
mn.sampled_subgraphs[0].node_pairs.indptr, torch.tensor([0, 2])
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2])
)
assert torch.equal(original_indices(mn), torch.tensor([1, 4]))
mn = next(it)
assert torch.equal(mn.seed_nodes, torch.LongTensor([3]))
assert torch.equal(
mn.sampled_subgraphs[0].node_pairs.indptr, torch.tensor([0, 2])
mn.sampled_subgraphs[0].sampled_csc.indptr, torch.tensor([0, 2])
)
assert torch.equal(original_indices(mn), torch.tensor([1, 2]))
......@@ -312,7 +312,7 @@ def test_InSubgraphSampler_hetero():
mn = next(it)
assert torch.equal(mn.seed_nodes["N0"], torch.LongTensor([1, 0]))
expected_node_pairs = {
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1, 3]),
indices=torch.LongTensor([2, 1, 0]),
......@@ -327,16 +327,16 @@ def test_InSubgraphSampler_hetero():
indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
),
}
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
assert torch.equal(pairs.indices, expected_node_pairs[etype].indices)
assert torch.equal(pairs.indptr, expected_node_pairs[etype].indptr)
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
mn = next(it)
assert mn.seed_nodes == {
"N0": torch.LongTensor([2]),
"N1": torch.LongTensor([0]),
}
expected_node_pairs = {
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0, 1]), indices=torch.LongTensor([1])
),
......@@ -350,13 +350,13 @@ def test_InSubgraphSampler_hetero():
indptr=torch.LongTensor([0, 0]), indices=torch.LongTensor([])
),
}
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
assert torch.equal(pairs.indices, expected_node_pairs[etype].indices)
assert torch.equal(pairs.indptr, expected_node_pairs[etype].indptr)
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
mn = next(it)
assert torch.equal(mn.seed_nodes["N1"], torch.LongTensor([2, 1]))
expected_node_pairs = {
expected_sampled_csc = {
"N0:R0:N0": gb.CSCFormatBase(
indptr=torch.LongTensor([0]), indices=torch.LongTensor([])
),
......@@ -371,6 +371,6 @@ def test_InSubgraphSampler_hetero():
indices=torch.LongTensor([1, 2, 0]),
),
}
for etype, pairs in mn.sampled_subgraphs[0].node_pairs.items():
assert torch.equal(pairs.indices, expected_node_pairs[etype].indices)
assert torch.equal(pairs.indptr, expected_node_pairs[etype].indptr)
for etype, pairs in mn.sampled_subgraphs[0].sampled_csc.items():
assert torch.equal(pairs.indices, expected_sampled_csc[etype].indices)
assert torch.equal(pairs.indptr, expected_sampled_csc[etype].indptr)
......@@ -40,7 +40,7 @@ def create_homo_minibatch():
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -94,7 +94,7 @@ def create_hetero_minibatch():
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -143,7 +143,7 @@ def test_minibatch_representation_homo():
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=csc_formats[i],
sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -197,19 +197,19 @@ def test_minibatch_representation_homo():
)
expect_result = str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([10, 11, 12, 13]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
),
original_row_node_ids=tensor([10, 11, 12, 13]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_column_node_ids=tensor([10, 11, 12, 13]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([10, 11, 12]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
),
original_row_node_ids=tensor([10, 11, 12]),
original_edge_ids=tensor([10, 15, 17]),
original_column_node_ids=tensor([10, 11]),
node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
),
)],
positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
......@@ -304,7 +304,7 @@ def test_minibatch_representation_hetero():
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=csc_formats[i],
sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -342,21 +342,21 @@ def test_minibatch_representation_hetero():
)
expect_result = str(
"""MiniBatch(seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
)},
original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
)},
),
SampledSubgraphImpl(original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
)},
original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])},
original_column_node_ids={'B': tensor([10, 11])},
node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
)},
)],
positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
......@@ -446,7 +446,7 @@ def test_get_dgl_blocks_homo():
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -521,7 +521,7 @@ def test_get_dgl_blocks_hetero():
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=node_pairs[i],
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -613,7 +613,7 @@ def test_minibatch_node_pairs_with_labels(mode):
def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation)
node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
......@@ -643,7 +643,7 @@ def check_dgl_blocks_hetero(minibatch, blocks):
def check_dgl_blocks_homo(minibatch, blocks):
node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
......@@ -862,7 +862,7 @@ def create_homo_minibatch_csc_format():
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=csc_formats[i],
sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -877,7 +877,7 @@ def create_homo_minibatch_csc_format():
def create_hetero_minibatch_csc_format():
node_pairs = [
sampled_csc = [
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
......@@ -926,7 +926,7 @@ def create_hetero_minibatch_csc_format():
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
sampled_csc=sampled_csc[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -945,8 +945,8 @@ def create_hetero_minibatch_csc_format():
def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation)
node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs
sampled_csc = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
......@@ -959,24 +959,24 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
for i, block in enumerate(blocks):
edges = block.edges(etype=etype)
dst_ndoes = torch.arange(
0, len(node_pairs[i][relation].indptr) - 1
0, len(sampled_csc[i][relation].indptr) - 1
).repeat_interleave(
node_pairs[i][relation].indptr[1:]
- node_pairs[i][relation].indptr[:-1]
sampled_csc[i][relation].indptr[1:]
- sampled_csc[i][relation].indptr[:-1]
)
assert torch.equal(edges[0], node_pairs[i][relation].indices)
assert torch.equal(edges[0], sampled_csc[i][relation].indices)
assert torch.equal(edges[1], dst_ndoes)
assert torch.equal(
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
)
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
dst_ndoes = torch.arange(
0, len(node_pairs[0][reverse_relation].indptr) - 1
0, len(sampled_csc[0][reverse_relation].indptr) - 1
).repeat_interleave(
node_pairs[0][reverse_relation].indptr[1:]
- node_pairs[0][reverse_relation].indptr[:-1]
sampled_csc[0][reverse_relation].indptr[1:]
- sampled_csc[0][reverse_relation].indptr[:-1]
)
assert torch.equal(edges[0], node_pairs[0][reverse_relation].indices)
assert torch.equal(edges[0], sampled_csc[0][reverse_relation].indices)
assert torch.equal(edges[1], dst_ndoes)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
......@@ -987,8 +987,8 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
def check_dgl_blocks_homo_csc_format(minibatch, blocks):
node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs
sampled_csc = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
......@@ -999,11 +999,11 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks):
]
for i, block in enumerate(blocks):
dst_ndoes = torch.arange(
0, len(node_pairs[i].indptr) - 1
0, len(sampled_csc[i].indptr) - 1
).repeat_interleave(
node_pairs[i].indptr[1:] - node_pairs[i].indptr[:-1]
sampled_csc[i].indptr[1:] - sampled_csc[i].indptr[:-1]
)
assert torch.equal(block.edges()[0], node_pairs[i].indices), print(
assert torch.equal(block.edges()[0], sampled_csc[i].indices), print(
block.edges()
)
assert torch.equal(block.edges()[1], dst_ndoes), print(block.edges())
......
......@@ -1145,7 +1145,7 @@ def test_OnDiskDataset_preprocess_homogeneous():
torch.tensor([fanout]),
output_cscformat=False,
)
assert len(subgraph.node_pairs[0]) <= num_samples
assert len(subgraph.sampled_csc[0]) <= num_samples
with tempfile.TemporaryDirectory() as test_dir:
# All metadata fields are specified.
......
......@@ -71,7 +71,7 @@ def test_exclude_edges_homo_node_pairs(reverse_row, reverse_column):
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 10])
_assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal(result.sampled_csc, expected_node_pairs)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
......@@ -106,7 +106,7 @@ def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs,
sampled_csc=node_pairs,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
......@@ -139,7 +139,7 @@ def test_exclude_edges_hetero_node_pairs(reverse_row, reverse_column):
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_node_pairs)
_assert_container_equal(result.sampled_csc, expected_node_pairs)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
......@@ -188,7 +188,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 9])
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
......@@ -237,7 +237,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
else:
expected_column_node_ids = None
expected_edge_ids = torch.Tensor([5, 10, 10])
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
......@@ -272,7 +272,7 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=csc_formats,
sampled_csc=csc_formats,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
......@@ -305,7 +305,7 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20])}
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
......@@ -340,7 +340,7 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
dst_to_exclude = torch.tensor([0, 2])
original_edge_ids = {"A:relation:B": torch.tensor([19, 19, 20, 20, 21])}
subgraph = SampledSubgraphImpl(
node_pairs=csc_formats,
sampled_csc=csc_formats,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
......@@ -373,7 +373,7 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
expected_column_node_ids = None
expected_edge_ids = {"A:relation:B": torch.tensor([20, 20])}
_assert_container_equal(result.node_pairs, expected_csc_formats)
_assert_container_equal(result.sampled_csc, expected_csc_formats)
_assert_container_equal(
result.original_column_node_ids, expected_column_node_ids
)
......@@ -403,7 +403,7 @@ def test_sampled_subgraph_to_device():
dst_to_exclude = torch.tensor([10, 12])
original_edge_ids = {"A:relation:B": torch.tensor([19, 20, 21])}
subgraph = FusedSampledSubgraphImpl(
node_pairs=node_pairs,
sampled_csc=node_pairs,
original_column_node_ids=original_column_node_ids,
original_row_node_ids=original_row_node_ids,
original_edge_ids=original_edge_ids,
......@@ -420,9 +420,9 @@ def test_sampled_subgraph_to_device():
graph = graph.to("cuda")
# Check.
for key in graph.node_pairs:
assert graph.node_pairs[key][0].device.type == "cuda"
assert graph.node_pairs[key][1].device.type == "cuda"
for key in graph.sampled_csc:
assert graph.sampled_csc[key][0].device.type == "cuda"
assert graph.sampled_csc[key][1].device.type == "cuda"
for key in graph.original_column_node_ids:
assert graph.original_column_node_ids[key].device.type == "cuda"
for key in graph.original_row_node_ids:
......@@ -433,7 +433,7 @@ def test_sampled_subgraph_to_device():
def test_sampled_subgraph_impl_representation_homo():
sampled_subgraph_impl = SampledSubgraphImpl(
node_pairs=gb.CSCFormatBase(
sampled_csc=gb.CSCFormatBase(
indptr=torch.arange(0, 101, 10),
indices=torch.arange(10, 110),
),
......@@ -442,7 +442,17 @@ def test_sampled_subgraph_impl_representation_homo():
original_edge_ids=None,
)
expected_result = str(
"""SampledSubgraphImpl(original_row_node_ids=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
"""SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]),
indices=tensor([ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109]),
),
original_row_node_ids=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
......@@ -452,16 +462,6 @@ def test_sampled_subgraph_impl_representation_homo():
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
node_pairs=CSCFormatBase(indptr=tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]),
indices=tensor([ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109]),
),
)"""
)
assert str(sampled_subgraph_impl) == expected_result, print(
......@@ -471,7 +471,7 @@ def test_sampled_subgraph_impl_representation_homo():
def test_sampled_subgraph_impl_representation_hetero():
sampled_subgraph_impl = SampledSubgraphImpl(
node_pairs={
sampled_csc={
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([4, 5, 6, 7]),
......@@ -492,14 +492,16 @@ def test_sampled_subgraph_impl_representation_hetero():
original_edge_ids=None,
)
expected_result = str(
"""SampledSubgraphImpl(original_row_node_ids={'n1': tensor([1, 0, 0, 1, 1, 0, 0, 1]), 'n2': tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1])},
"""SampledSubgraphImpl(sampled_csc={'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 2, 4]),
indices=tensor([4, 5, 6, 7]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 2, 4, 6, 8]),
indices=tensor([2, 3, 4, 5, 6, 7, 8, 9]),
)},
original_row_node_ids={'n1': tensor([1, 0, 0, 1, 1, 0, 0, 1]), 'n2': tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1])},
original_edge_ids=None,
original_column_node_ids={'n1': tensor([1, 0, 0, 1]), 'n2': tensor([1, 2])},
node_pairs={'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 2, 4]),
indices=tensor([4, 5, 6, 7]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 2, 4, 6, 8]),
indices=tensor([2, 3, 4, 5, 6, 7, 8, 9]),
)},
)"""
)
assert str(sampled_subgraph_impl) == expected_result, print(expected_result)
assert str(sampled_subgraph_impl) == expected_result, print(
sampled_subgraph_impl
)
......@@ -83,7 +83,7 @@ def test_FeatureFetcher_with_edges_homo():
range_tensor = torch.arange(10)
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs=(range_tensor, range_tensor),
sampled_csc=(range_tensor, range_tensor),
original_column_node_ids=range_tensor,
original_row_node_ids=range_tensor,
original_edge_ids=torch.randint(
......@@ -184,7 +184,7 @@ def test_FeatureFetcher_with_edges_hetero():
for _ in range(3):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
node_pairs={
sampled_csc={
"n1:e1:n2": (
torch.arange(10),
torch.arange(10),
......
......@@ -61,19 +61,19 @@ def test_integration_link_prediction():
expected = [
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([5, 4]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]),
indices=tensor([5]),
),
original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1]),
indices=tensor([5]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
......@@ -117,19 +117,19 @@ def test_integration_link_prediction():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2]),
indices=tensor([1, 3]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
......@@ -172,19 +172,19 @@ def test_integration_link_prediction():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 0, 1]),
indices=tensor([1]),
),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
......@@ -274,19 +274,19 @@ def test_integration_node_classification():
expected = [
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
),
original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([0, 1, 0, 1]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
......@@ -315,19 +315,19 @@ def test_integration_node_classification():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
......@@ -354,19 +354,19 @@ def test_integration_node_classification():
),
str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
),
SampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
),
original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 1]),
),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
......
......@@ -246,7 +246,7 @@ def test_SubgraphSampler_Random_Hetero_Graph(labor):
for data in sampler_dp:
for sampledsubgraph in data.sampled_subgraphs:
for _, value in sampledsubgraph.node_pairs.items():
for _, value in sampledsubgraph.sampled_csc.items():
assert torch.equal(
torch.ge(value.indices, torch.zeros(len(value.indices))),
torch.ones(len(value.indices)),
......@@ -297,9 +297,11 @@ def test_SubgraphSampler_without_dedpulication_Homo(labor):
for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
assert len(sampled_subgraph.original_row_node_ids) == length[step]
assert torch.equal(
sampled_subgraph.node_pairs.indices, compacted_indices[step]
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(sampled_subgraph.node_pairs.indptr, indptr[step])
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
......@@ -372,11 +374,11 @@ def test_SubgraphSampler_without_dedpulication_Hetero(labor):
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.node_pairs[etype].indices,
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices,
)
assert torch.equal(
sampled_subgraph.node_pairs[etype].indptr,
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr,
)
......@@ -423,9 +425,11 @@ def test_SubgraphSampler_unique_csc_format_Homo(labor):
original_row_node_ids[step],
)
assert torch.equal(
sampled_subgraph.node_pairs.indices, compacted_indices[step]
sampled_subgraph.sampled_csc.indices, compacted_indices[step]
)
assert torch.equal(
sampled_subgraph.sampled_csc.indptr, indptr[step]
)
assert torch.equal(sampled_subgraph.node_pairs.indptr, indptr[step])
assert torch.equal(
sampled_subgraph.original_column_node_ids, seeds[step]
)
......@@ -504,10 +508,10 @@ def test_SubgraphSampler_unique_csc_format_Hetero(labor):
)
for etype in ["n1:e1:n2", "n2:e2:n1"]:
assert torch.equal(
sampled_subgraph.node_pairs[etype].indices,
sampled_subgraph.sampled_csc[etype].indices,
csc_formats[step][etype].indices,
)
assert torch.equal(
sampled_subgraph.node_pairs[etype].indptr,
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment