Unverified Commit 697a9977 authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Remove unused attributes in minibatch. (#7337)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent c997434c
...@@ -24,27 +24,6 @@ class MiniBatch: ...@@ -24,27 +24,6 @@ class MiniBatch:
representation of input and output data across different stages, ensuring representation of input and output data across different stages, ensuring
consistency and ease of use throughout the loading process.""" consistency and ease of use throughout the loading process."""
seed_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of seed nodes used for sampling in the graph.
- If `seed_nodes` is a tensor: It indicates the graph is homogeneous.
- If `seed_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pairs` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
- If `node_pairs` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
type.
"""
labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Labels associated with seed nodes / node pairs in the graph. Labels associated with seed nodes / node pairs in the graph.
...@@ -93,26 +72,6 @@ class MiniBatch: ...@@ -93,26 +72,6 @@ class MiniBatch:
key, indexes are consecutive integers starting from zero. key, indexes are consecutive integers starting from zero.
""" """
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
prediction task.
- If `negative_srcs` is a tensor: It indicates a homogeneous graph.
- If `negative_srcs` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the tail nodes in the link
prediction task.
- If `negative_dsts` is a tensor: It indicates a homogeneous graph.
- If `negative_dsts` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
sampled_subgraphs: List[SampledSubgraph] = None sampled_subgraphs: List[SampledSubgraph] = None
"""A list of 'SampledSubgraph's, each one corresponding to one layer, """A list of 'SampledSubgraph's, each one corresponding to one layer,
representing a subset of a larger graph structure. representing a subset of a larger graph structure.
...@@ -147,15 +106,6 @@ class MiniBatch: ...@@ -147,15 +106,6 @@ class MiniBatch:
string of format 'str:str:str'. string of format 'str:str:str'.
""" """
compacted_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of compacted node pairs corresponding to 'node_pairs', where
all node ids inside are compacted.
"""
compacted_seeds: Union[ compacted_seeds: Union[
torch.Tensor, torch.Tensor,
Dict[str, torch.Tensor], Dict[str, torch.Tensor],
...@@ -165,18 +115,6 @@ class MiniBatch: ...@@ -165,18 +115,6 @@ class MiniBatch:
all node ids inside are compacted. all node ids inside are compacted.
""" """
compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_srcs', where
all node ids inside are compacted.
"""
compacted_negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_dsts', where
all node ids inside are compacted.
"""
def __repr__(self) -> str: def __repr__(self) -> str:
return _minibatch_str(self) return _minibatch_str(self)
...@@ -333,163 +271,6 @@ class MiniBatch: ...@@ -333,163 +271,6 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.original_edge_ids block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks return blocks
@property
def positive_node_pairs(self):
"""`positive_node_pairs` is a representation of positive graphs used for
evaluating or computing loss in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
return self.compacted_node_pairs
@property
def negative_node_pairs(self):
"""`negative_node_pairs` is a representation of negative graphs used for
evaluating or computing loss in link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
# Build negative graph.
if (
self.compacted_negative_srcs is not None
and self.compacted_negative_dsts is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_node_pairs = (
self.compacted_negative_srcs,
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_node_pairs = {
etype: (
neg_src,
self.compacted_negative_dsts[etype],
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif (
self.compacted_negative_srcs is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
negative_node_pairs = (
self.compacted_negative_srcs,
self.compacted_node_pairs[1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
# For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_srcs.values())[
0
].size(1)
negative_node_pairs = {
etype: (
neg_src,
self.compacted_node_pairs[etype][1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif (
self.compacted_negative_dsts is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
negative_node_pairs = (
self.compacted_node_pairs[0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_dsts.values())[
0
].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
neg_dst,
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
else:
negative_node_pairs = None
return negative_node_pairs
@property
def node_pairs_with_labels(self):
"""Get a node pair tensor and a label tensor from MiniBatch. They are
used for evaluating or computing loss. For homogeneous graph, it will
return `(node_pairs, labels)` as result; for heterogeneous graph, the
`node_pairs` and `labels` will both be a dict with etype as the key.
- If it's a link prediction task, `node_pairs` will contain both
negative and positive node pairs and `labels` will consist of 0 and 1,
indicating whether the corresponding node pair is negative or positive.
- If it's an edge classification task, this function will directly
return `compacted_node_pairs` for each etype and the corresponding
`labels`.
- Otherwise it will return None.
"""
if self.labels is None:
# Link prediction.
positive_node_pairs = self.positive_node_pairs
negative_node_pairs = self.negative_node_pairs
if positive_node_pairs is None or negative_node_pairs is None:
return None
if isinstance(positive_node_pairs, Dict):
# Heterogeneous graph.
node_pairs_by_etype = {}
labels_by_etype = {}
for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype]
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels_by_etype[etype] = torch.cat(
[pos_label, neg_label], dim=0
)
return (node_pairs_by_etype, labels_by_etype)
else:
# Homogeneous graph.
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels = torch.cat([pos_label, neg_label], dim=0)
return (node_pairs, labels.float())
elif self.compacted_node_pairs is not None:
# Edge classification.
return (self.compacted_node_pairs, self.labels)
else:
return None
def to_pyg_data(self): def to_pyg_data(self):
"""Construct a PyG Data from `MiniBatch`. This function only supports """Construct a PyG Data from `MiniBatch`. This function only supports
node classification task on a homogeneous graph and the number of node classification task on a homogeneous graph and the number of
...@@ -527,17 +308,7 @@ class MiniBatch: ...@@ -527,17 +308,7 @@ class MiniBatch:
), "`to_pyg_data` only supports single feature homogeneous graph." ), "`to_pyg_data` only supports single feature homogeneous graph."
node_features = next(iter(self.node_features.values())) node_features = next(iter(self.node_features.values()))
if self.seed_nodes is not None: if self.seeds is not None:
if isinstance(self.seed_nodes, Dict):
batch_size = len(next(iter(self.seed_nodes.values())))
else:
batch_size = len(self.seed_nodes)
elif self.node_pairs is not None:
if isinstance(self.node_pairs, Dict):
batch_size = len(next(iter(self.node_pairs.values()))[0])
else:
batch_size = len(self.node_pairs[0])
elif self.seeds is not None:
if isinstance(self.seeds, Dict): if isinstance(self.seeds, Dict):
batch_size = len(next(iter(self.seeds.values()))) batch_size = len(next(iter(self.seeds.values())))
else: else:
......
...@@ -11,7 +11,7 @@ def test_integration_link_prediction(): ...@@ -11,7 +11,7 @@ def test_integration_link_prediction():
indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4]) indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])
matrix_a = dglsp.from_csc(indptr, indices) matrix_a = dglsp.from_csc(indptr, indices)
node_pairs = torch.t(torch.stack(matrix_a.coo())) seeds = torch.t(torch.stack(matrix_a.coo()))
node_feature_data = torch.tensor( node_feature_data = torch.tensor(
[ [
[0.9634, 0.2294], [0.9634, 0.2294],
...@@ -37,7 +37,7 @@ def test_integration_link_prediction(): ...@@ -37,7 +37,7 @@ def test_integration_link_prediction():
] ]
) )
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(seeds, names="seeds")
graph = gb.fused_csc_sampling_graph(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data) node_feature = gb.TorchBasedFeature(node_feature_data)
...@@ -72,7 +72,6 @@ def test_integration_link_prediction(): ...@@ -72,7 +72,6 @@ def test_integration_link_prediction():
[3, 3], [3, 3],
[3, 3], [3, 3],
[3, 4]]), [3, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
indices=tensor([0, 5, 4], dtype=torch.int32), indices=tensor([0, 5, 4], dtype=torch.int32),
), ),
...@@ -87,18 +86,12 @@ def test_integration_link_prediction(): ...@@ -87,18 +86,12 @@ def test_integration_link_prediction():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]), original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
)], )],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865], [0.6172, 0.7865],
[0.8672, 0.2276], [0.8672, 0.2276],
[0.2109, 0.1089], [0.2109, 0.1089],
[0.9634, 0.2294], [0.9634, 0.2294],
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]), labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([5, 1, 3, 2, 0, 4]), input_nodes=tensor([5, 1, 3, 2, 0, 4]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]), indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
...@@ -116,9 +109,6 @@ def test_integration_link_prediction(): ...@@ -116,9 +109,6 @@ def test_integration_link_prediction():
[2, 2], [2, 2],
[2, 2], [2, 2],
[2, 5]]), [2, 5]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3), blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)], Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
)""" )"""
...@@ -136,7 +126,6 @@ def test_integration_link_prediction(): ...@@ -136,7 +126,6 @@ def test_integration_link_prediction():
[4, 3], [4, 3],
[0, 1], [0, 1],
[0, 5]]), [0, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([4, 0], dtype=torch.int32), indices=tensor([4, 0], dtype=torch.int32),
), ),
...@@ -151,18 +140,12 @@ def test_integration_link_prediction(): ...@@ -151,18 +140,12 @@ def test_integration_link_prediction():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]), original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
)], )],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276], node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.9634, 0.2294], [0.9634, 0.2294],
[0.6172, 0.7865], [0.6172, 0.7865],
[0.5160, 0.2486], [0.5160, 0.2486],
[0.2109, 0.1089]])}, [0.2109, 0.1089]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]), labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([3, 4, 0, 1, 5, 2]), input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]), indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
...@@ -180,9 +163,6 @@ def test_integration_link_prediction(): ...@@ -180,9 +163,6 @@ def test_integration_link_prediction():
[1, 0], [1, 0],
[2, 3], [2, 3],
[2, 4]]), [2, 4]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2), blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)], Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
)""" )"""
...@@ -194,7 +174,6 @@ def test_integration_link_prediction(): ...@@ -194,7 +174,6 @@ def test_integration_link_prediction():
[5, 4], [5, 4],
[4, 0], [4, 0],
[4, 1]]), [4, 1]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32), indices=tensor([1, 0], dtype=torch.int32),
), ),
...@@ -209,16 +188,10 @@ def test_integration_link_prediction(): ...@@ -209,16 +188,10 @@ def test_integration_link_prediction():
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]), original_column_node_ids=tensor([5, 4, 0, 1]),
)], )],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.9634, 0.2294], [0.9634, 0.2294],
[0.6172, 0.7865]])}, [0.6172, 0.7865]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 0., 0., 0., 0.]), labels=tensor([1., 1., 0., 0., 0., 0.]),
input_nodes=tensor([5, 4, 0, 1]), input_nodes=tensor([5, 4, 0, 1]),
indexes=tensor([0, 1, 0, 0, 1, 1]), indexes=tensor([0, 1, 0, 0, 1, 1]),
...@@ -230,9 +203,6 @@ def test_integration_link_prediction(): ...@@ -230,9 +203,6 @@ def test_integration_link_prediction():
[0, 1], [0, 1],
[1, 2], [1, 2],
[1, 3]]), [1, 3]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2), blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)], Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
)""" )"""
...@@ -248,8 +218,7 @@ def test_integration_node_classification(): ...@@ -248,8 +218,7 @@ def test_integration_node_classification():
indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10]) indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4]) indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])
matrix_a = dglsp.from_csc(indptr, indices) seeds = torch.tensor([5, 1, 2, 4, 3, 0])
node_pairs = torch.t(torch.stack(matrix_a.coo()))
node_feature_data = torch.tensor( node_feature_data = torch.tensor(
[ [
[0.9634, 0.2294], [0.9634, 0.2294],
...@@ -275,7 +244,7 @@ def test_integration_node_classification(): ...@@ -275,7 +244,7 @@ def test_integration_node_classification():
] ]
) )
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(seeds, names="seeds")
graph = gb.fused_csc_sampling_graph(indptr, indices) graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data) node_feature = gb.TorchBasedFeature(node_feature_data)
...@@ -285,7 +254,7 @@ def test_integration_node_classification(): ...@@ -285,7 +254,7 @@ def test_integration_node_classification():
("edge", None, "feat"): edge_feature, ("edge", None, "feat"): edge_feature,
} }
feature_store = gb.BasicFeatureStore(features) feature_store = gb.BasicFeatureStore(features)
datapipe = gb.ItemSampler(item_set, batch_size=4) datapipe = gb.ItemSampler(item_set, batch_size=2)
fanouts = torch.LongTensor([1]) fanouts = torch.LongTensor([1])
datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True) datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
datapipe = datapipe.fetch_feature( datapipe = datapipe.fetch_feature(
...@@ -296,138 +265,92 @@ def test_integration_node_classification(): ...@@ -296,138 +265,92 @@ def test_integration_node_classification():
) )
expected = [ expected = [
str( str(
"""MiniBatch(seeds=tensor([[5, 1], """MiniBatch(seeds=tensor([5, 1]),
[3, 2], sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
[3, 2], indices=tensor([2, 0], dtype=torch.int32),
[3, 3]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 0, 2, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 1, 3, 2, 4]), original_row_node_ids=tensor([5, 1, 4]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2]), original_column_node_ids=tensor([5, 1]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 0, 2, 2], dtype=torch.int32), indices=tensor([0, 0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 1, 3, 2]), original_row_node_ids=tensor([5, 1]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2]), original_column_node_ids=tensor([5, 1]),
)], )],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486], node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865], [0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([5, 1, 3, 2, 4]), input_nodes=tensor([5, 1, 4]),
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=tensor([[0, 1], compacted_seeds=None,
[2, 3], blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
[2, 3], Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
[2, 2]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
)""" )"""
), ),
str( str(
"""MiniBatch(seeds=tensor([[3, 3], """MiniBatch(seeds=tensor([2, 4]),
[4, 3], sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 3], dtype=torch.int32),
[4, 4], indices=tensor([2, 1, 2], dtype=torch.int32),
[0, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0]), original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([2, 4, 3, 0]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32), indices=tensor([2, 3], dtype=torch.int32),
), ),
original_row_node_ids=tensor([3, 4, 0]), original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]), original_column_node_ids=tensor([2, 4]),
)], )],
positive_node_pairs=None, node_features={'feat': tensor([[0.2109, 0.1089],
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223], [0.5503, 0.8223],
[0.8672, 0.2276],
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([3, 4, 0]), input_nodes=tensor([2, 4, 3, 0]),
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=tensor([[0, 0], compacted_seeds=None,
[1, 0], blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=3),
[1, 1], Block(num_src_nodes=4, num_dst_nodes=2, num_edges=2)],
[2, 1]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
)""" )"""
), ),
str( str(
"""MiniBatch(seeds=tensor([[5, 5], """MiniBatch(seeds=tensor([3, 0]),
[4, 5]]), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
seed_nodes=None, indices=tensor([0], dtype=torch.int32),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 4, 0]), original_row_node_ids=tensor([3, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([3, 0]),
), ),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32), SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
indices=tensor([1, 1], dtype=torch.int32), indices=tensor([0], dtype=torch.int32),
), ),
original_row_node_ids=tensor([5, 4]), original_row_node_ids=tensor([3, 0]),
original_edge_ids=None, original_edge_ids=None,
original_column_node_ids=tensor([5, 4]), original_column_node_ids=tensor([3, 0]),
)], )],
positive_node_pairs=None, node_features={'feat': tensor([[0.8672, 0.2276],
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294]])}, [0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([5, 4, 0]), input_nodes=tensor([3, 0]),
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=tensor([[0, 0], compacted_seeds=None,
[1, 0]]), blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
compacted_node_pairs=None, Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
)""" )"""
), ),
] ]
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
assert expected[step] == str(data), print(data) assert expected[step] == str(data), print(step, data)
import os import os
import re import re
import unittest import unittest
from collections import defaultdict
from sys import platform from sys import platform
import backend as F import backend as F
...@@ -47,7 +48,7 @@ def test_ItemSampler_minibatcher(): ...@@ -47,7 +48,7 @@ def test_ItemSampler_minibatcher():
# Default minibatcher is used if not specified. # Default minibatcher is used if not specified.
# `MiniBatch` is returned if expected names are specified. # `MiniBatch` is returned if expected names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=4) item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler)) minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
...@@ -78,7 +79,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last): ...@@ -78,7 +79,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
return iter(torch.arange(0, num_ids)) return iter(torch.arange(0, num_ids))
seed_nodes = gb.ItemSet(InvalidLength()) seed_nodes = gb.ItemSet(InvalidLength())
item_set = gb.ItemSet(seed_nodes, names="seed_nodes") item_set = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -106,7 +107,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last): ...@@ -106,7 +107,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
def test_ItemSet_integer(batch_size, shuffle, drop_last): def test_ItemSet_integer(batch_size, shuffle, drop_last):
# Node IDs. # Node IDs.
num_ids = 103 num_ids = 103
item_set = gb.ItemSet(num_ids, names="seed_nodes") item_set = gb.ItemSet(num_ids, names="seeds")
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -135,7 +136,7 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last): ...@@ -135,7 +136,7 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs. # Node IDs.
num_ids = 103 num_ids = 103
seed_nodes = torch.arange(0, num_ids) seed_nodes = torch.arange(0, num_ids)
item_set = gb.ItemSet(seed_nodes, names="seed_nodes") item_set = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -165,7 +166,7 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last): ...@@ -165,7 +166,7 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
num_ids = 103 num_ids = 103
seed_nodes = torch.arange(0, num_ids) seed_nodes = torch.arange(0, num_ids)
labels = torch.arange(0, num_ids) labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels")) item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -249,7 +250,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last): ...@@ -249,7 +250,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs. # Node pairs.
num_ids = 103 num_ids = 103
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2) node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs") item_set = gb.ItemSet(node_pairs, names="seeds")
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -289,7 +290,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -289,7 +290,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
num_ids = 103 num_ids = 103
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2) node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
labels = node_pairs[:, 0] labels = node_pairs[:, 0]
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels")) item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
) )
...@@ -333,16 +334,26 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -333,16 +334,26 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last): def test_ItemSet_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
# Node pairs and negative destinations. # Node pairs and negative destinations.
num_ids = 103 num_ids = 103
num_negs = 2 num_negs = 2
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2) node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
neg_dsts = torch.arange( neg_srcs = node_pairs[:, 0].repeat_interleave(num_negs)
2 * num_ids, 2 * num_ids + num_ids * num_negs neg_dsts = torch.arange(2 * num_ids, 2 * num_ids + num_ids * num_negs)
).reshape(-1, num_negs) neg_node_pairs = torch.cat((neg_srcs, neg_dsts)).reshape(2, -1).T
labels = torch.empty(num_ids * 3)
labels[:num_ids] = 1
labels[num_ids:] = 0
indexes = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
node_pairs = torch.cat((node_pairs, neg_node_pairs))
item_set = gb.ItemSet( item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts") (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
) )
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
...@@ -350,6 +361,8 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -350,6 +361,8 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
negs_ids = [] negs_ids = []
final_labels = []
final_indexes = []
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert minibatch.seeds is not None assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor) assert isinstance(minibatch.seeds, torch.Tensor)
...@@ -358,46 +371,43 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -358,46 +371,43 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src, dst = minibatch.seeds.T src, dst = minibatch.seeds.T
negs_src = src[~minibatch.labels.to(bool)] negs_src = src[~minibatch.labels.to(bool)]
negs_dst = dst[~minibatch.labels.to(bool)] negs_dst = dst[~minibatch.labels.to(bool)]
is_last = (i + 1) * batch_size >= num_ids is_last = (i + 1) * batch_size >= num_ids * 3
if not is_last or num_ids % batch_size == 0: if not is_last or num_ids * 3 % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
else: else:
if not drop_last: if not drop_last:
expected_batch_size = num_ids % batch_size expected_batch_size = num_ids * 3 % batch_size
else: else:
assert False assert False
assert len(src) == expected_batch_size * 3 assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size * 3 assert len(dst) == expected_batch_size
assert negs_src.dim() == 1 assert negs_src.dim() == 1
assert negs_dst.dim() == 1 assert negs_dst.dim() == 1
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
expected_indexes = torch.arange(expected_batch_size)
expected_indexes = torch.cat(
(expected_indexes, expected_indexes.repeat_interleave(2))
)
assert torch.equal(minibatch.indexes, expected_indexes)
# Verify node pairs and negative destinations.
assert torch.equal(
src[minibatch.labels.to(bool)] + 1, dst[minibatch.labels.to(bool)]
)
assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src) assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
# Archive batch. # Archive batch.
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
negs_ids.append(negs_dst) negs_ids.append(negs_dst)
final_labels.append(minibatch.labels)
final_indexes.append(minibatch.indexes)
src_ids = torch.cat(src_ids) src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids) dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids) negs_ids = torch.cat(negs_ids)
final_labels = torch.cat(final_labels)
final_indexes = torch.cat(final_indexes)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
assert torch.all(final_labels[:-1] >= final_labels[1:]) is not shuffle
if not drop_last:
assert final_labels.sum() == num_ids
assert torch.equal(final_indexes, indexes) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds(batch_size, shuffle, drop_last): def test_ItemSet_hyperlink(batch_size, shuffle, drop_last):
# Node pairs. # Node pairs.
num_ids = 103 num_ids = 103
seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3) seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
...@@ -477,7 +487,7 @@ def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last): ...@@ -477,7 +487,7 @@ def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
def test_append_with_other_datapipes(): def test_append_with_other_datapipes():
num_ids = 100 num_ids = 100
batch_size = 4 batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes") item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
data_pipe = gb.ItemSampler(item_set, batch_size) data_pipe = gb.ItemSampler(item_set, batch_size)
# torchdata.datapipes.iter.Enumerator # torchdata.datapipes.iter.Enumerator
data_pipe = data_pipe.enumerate() data_pipe = data_pipe.enumerate()
...@@ -500,8 +510,8 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last): ...@@ -500,8 +510,8 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
num_ids = 205 num_ids = 205
ids = { ids = {
"user": gb.ItemSet(IterableOnly(0, 99), names="seed_nodes"), "user": gb.ItemSet(IterableOnly(0, 99), names="seeds"),
"item": gb.ItemSet(IterableOnly(99, num_ids), names="seed_nodes"), "item": gb.ItemSet(IterableOnly(99, num_ids), names="seeds"),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
...@@ -539,8 +549,8 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last): ...@@ -539,8 +549,8 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs. # Node IDs.
num_ids = 205 num_ids = 205
ids = { ids = {
"user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"), "user": gb.ItemSet(torch.arange(0, 99), names="seeds"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"), "item": gb.ItemSet(torch.arange(99, num_ids), names="seeds"),
} }
chained_ids = [] chained_ids = []
for key, value in ids.items(): for key, value in ids.items():
...@@ -580,11 +590,11 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last): ...@@ -580,11 +590,11 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
ids = { ids = {
"user": gb.ItemSet( "user": gb.ItemSet(
(torch.arange(0, 99), torch.arange(0, 99)), (torch.arange(0, 99), torch.arange(0, 99)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
), ),
"item": gb.ItemSet( "item": gb.ItemSet(
(torch.arange(99, num_ids), torch.arange(99, num_ids)), (torch.arange(99, num_ids), torch.arange(99, num_ids)),
names=("seed_nodes", "labels"), names=("seeds", "labels"),
), ),
} }
chained_ids = [] chained_ids = []
...@@ -638,8 +648,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last): ...@@ -638,8 +648,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2) node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2) node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
node_pairs_dict = { node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"), "user:like:item": gb.ItemSet(node_pairs_like, names="seeds"),
"user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"), "user:follow:user": gb.ItemSet(node_pairs_follow, names="seeds"),
} }
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
item_sampler = gb.ItemSampler( item_sampler = gb.ItemSampler(
...@@ -691,11 +701,11 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -691,11 +701,11 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
node_pairs_dict = { node_pairs_dict = {
"user:like:item": gb.ItemSet( "user:like:item": gb.ItemSet(
(node_pairs_like, node_pairs_like[:, 0]), (node_pairs_like, node_pairs_like[:, 0]),
names=("node_pairs", "labels"), names=("seeds", "labels"),
), ),
"user:follow:user": gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs_follow, node_pairs_follow[:, 0]), (node_pairs_follow, node_pairs_follow[:, 0]),
names=("node_pairs", "labels"), names=("seeds", "labels"),
), ),
} }
item_set = gb.ItemSetDict(node_pairs_dict) item_set = gb.ItemSetDict(node_pairs_dict)
...@@ -709,7 +719,6 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -709,7 +719,6 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None assert minibatch.seeds is not None
assert minibatch.labels is not None assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -749,27 +758,64 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last): ...@@ -749,27 +758,64 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): def test_ItemSetDict_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
# Head, tail and negative tails. # Head, tail and negative tails.
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 6 * num_ids
num_negs = 2 num_negs = 2
node_paris_like = torch.arange(0, num_ids * 2).reshape(-1, 2) node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2) node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
neg_dsts_like = torch.arange( neg_dsts_like = torch.arange(num_ids * 4, num_ids * 4 + num_ids * num_negs)
num_ids * 4, num_ids * 4 + num_ids * num_negs neg_node_pairs_like = (
).reshape(-1, num_negs) torch.cat(
(node_pairs_like[:, 0].repeat_interleave(num_negs), neg_dsts_like)
)
.view(2, -1)
.T
)
all_node_pairs_like = torch.cat((node_pairs_like, neg_node_pairs_like))
labels_like = torch.empty(num_ids * 3)
labels_like[:num_ids] = 1
labels_like[num_ids:] = 0
indexes_like = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
neg_dsts_follow = torch.arange( neg_dsts_follow = torch.arange(
num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2 num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
).reshape(-1, num_negs) )
neg_node_pairs_follow = (
torch.cat(
(
node_pairs_follow[:, 0].repeat_interleave(num_negs),
neg_dsts_follow,
)
)
.view(2, -1)
.T
)
all_node_pairs_follow = torch.cat(
(node_pairs_follow, neg_node_pairs_follow)
)
labels_follow = torch.empty(num_ids * 3)
labels_follow[:num_ids] = 1
labels_follow[num_ids:] = 0
indexes_follow = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
data_dict = { data_dict = {
"user:like:item": gb.ItemSet( "user:like:item": gb.ItemSet(
(node_paris_like, neg_dsts_like), (all_node_pairs_like, labels_like, indexes_like),
names=("node_pairs", "negative_dsts"), names=("seeds", "labels", "indexes"),
), ),
"user:follow:user": gb.ItemSet( "user:follow:user": gb.ItemSet(
(node_pairs_follow, neg_dsts_follow), (all_node_pairs_follow, labels_follow, indexes_follow),
names=("node_pairs", "negative_dsts"), names=("seeds", "labels", "indexes"),
), ),
} }
item_set = gb.ItemSetDict(data_dict) item_set = gb.ItemSetDict(data_dict)
...@@ -779,11 +825,13 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -779,11 +825,13 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src_ids = [] src_ids = []
dst_ids = [] dst_ids = []
negs_ids = [] negs_ids = []
final_labels = defaultdict(list)
final_indexes = defaultdict(list)
for i, minibatch in enumerate(item_sampler): for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None assert minibatch.seeds is not None
assert minibatch.labels is not None assert minibatch.labels is not None
assert minibatch.negative_dsts is None assert minibatch.indexes is not None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -800,24 +848,23 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -800,24 +848,23 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert isinstance(seeds, torch.Tensor) assert isinstance(seeds, torch.Tensor)
src_etype = seeds[:, 0] src_etype = seeds[:, 0]
dst_etype = seeds[:, 1] dst_etype = seeds[:, 1]
src.append(src_etype[minibatch.labels[etype].to(bool)]) src.append(src_etype)
dst.append(dst_etype[minibatch.labels[etype].to(bool)]) dst.append(dst_etype)
negs_src.append(src_etype[~minibatch.labels[etype].to(bool)]) negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)]) negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
final_labels[etype].append(minibatch.labels[etype])
final_indexes[etype].append(minibatch.indexes[etype])
src = torch.cat(src) src = torch.cat(src)
dst = torch.cat(dst) dst = torch.cat(dst)
negs_src = torch.cat(negs_src) negs_src = torch.cat(negs_src)
negs_dst = torch.cat(negs_dst) negs_dst = torch.cat(negs_dst)
assert len(src) == expected_batch_size assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size assert len(dst) == expected_batch_size
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
src_ids.append(src) src_ids.append(src)
dst_ids.append(dst) dst_ids.append(dst)
negs_ids.append(negs_dst) negs_ids.append(negs_dst)
assert negs_src.dim() == 1 assert negs_src.dim() == 1
assert negs_dst.dim() == 1 assert negs_dst.dim() == 1
assert torch.equal(src + 1, dst)
assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2) assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
src_ids = torch.cat(src_ids) src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids) dst_ids = torch.cat(dst_ids)
...@@ -825,12 +872,24 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last): ...@@ -825,12 +872,24 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids <= negs_ids) is not shuffle assert torch.all(negs_ids <= negs_ids) is not shuffle
for etype in data_dict.keys():
final_labels_etype = torch.cat(final_labels[etype])
final_indexes_etype = torch.cat(final_indexes[etype])
assert (
torch.all(final_labels_etype[:-1] >= final_labels_etype[1:])
is not shuffle
)
if not drop_last:
assert final_labels_etype.sum() == num_ids
assert (
torch.equal(final_indexes_etype, indexes_follow) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds(batch_size, shuffle, drop_last): def test_ItemSetDict_hyperlink(batch_size, shuffle, drop_last):
# Node pairs. # Node pairs.
num_ids = 103 num_ids = 103
total_pairs = 2 * num_ids total_pairs = 2 * num_ids
...@@ -876,7 +935,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last): ...@@ -876,7 +935,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False]) @pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last): def test_ItemSetDict_hyperlink_labels(batch_size, shuffle, drop_last):
# Node pairs and labels # Node pairs and labels
num_ids = 103 num_ids = 103
total_ids = 2 * num_ids total_ids = 2 * num_ids
......
...@@ -11,6 +11,7 @@ reverse_relation = "B:rr:A" ...@@ -11,6 +11,7 @@ reverse_relation = "B:rr:A"
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_homo(indptr_dtype, indices_dtype): def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
seeds = torch.tensor([10, 11])
csc_formats = [ csc_formats = [
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype), indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
...@@ -48,36 +49,20 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype): ...@@ -48,36 +49,20 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4]) input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_csc_formats = gb.CSCFormatBase( compacted_seeds = torch.tensor([0, 1])
indptr=torch.tensor([0, 2, 3]), indices=torch.tensor([3, 4, 5]) labels = torch.tensor([1.0, 2.0])
)
compacted_negative_srcs = torch.tensor([[0], [1], [2]])
compacted_negative_dsts = torch.tensor([[6], [0], [0]])
labels = torch.tensor([0.0, 1.0, 2.0])
# Test minibatch without data. # Test minibatch without data.
minibatch = gb.MiniBatch() minibatch = gb.MiniBatch()
expect_result = str( expect_result = str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=None, sampled_subgraphs=None,
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features=None, node_features=None,
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None, labels=None,
input_nodes=None, input_nodes=None,
indexes=None, indexes=None,
edge_features=None, edge_features=None,
compacted_seeds=None, compacted_seeds=None,
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=None, blocks=None,
)""" )"""
) )
...@@ -85,21 +70,16 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype): ...@@ -85,21 +70,16 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
assert result == expect_result, print(expect_result, result) assert result == expect_result, print(expect_result, result)
# Test minibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
node_pairs=csc_formats, seeds=seeds,
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
labels=labels, labels=labels,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
negative_srcs=negative_srcs, compacted_seeds=compacted_seeds,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_csc_formats,
input_nodes=input_nodes, input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
) )
expect_result = str( expect_result = str(
"""MiniBatch(seeds=None, """MiniBatch(seeds=tensor([10, 11]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32), indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
), ),
...@@ -114,47 +94,13 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype): ...@@ -114,47 +94,13 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
original_edge_ids=tensor([10, 15, 17]), original_edge_ids=tensor([10, 15, 17]),
original_column_node_ids=tensor([10, 11]), original_column_node_ids=tensor([10, 11]),
)], )],
positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
node_pairs_with_labels=(CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
tensor([0., 1., 2.])),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
),
CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
indices=tensor([1, 2, 0], dtype=torch.int32),
)],
node_features={'x': tensor([5, 0, 2, 1])}, node_features={'x': tensor([5, 0, 2, 1])},
negative_srcs=tensor([[8], labels=tensor([1., 2.]),
[1],
[6]]),
negative_node_pairs=(tensor([[0],
[1],
[2]]),
tensor([[6],
[0],
[0]])),
negative_dsts=tensor([[2],
[8],
[8]]),
labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]), input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
indexes=None, indexes=None,
edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])}, edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
{'x': tensor([0, 2, 2])}], {'x': tensor([0, 2, 2])}],
compacted_seeds=None, compacted_seeds=tensor([0, 1]),
compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
compacted_negative_srcs=tensor([[0],
[1],
[2]]),
compacted_negative_dsts=tensor([[6],
[0],
[0]]),
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6), blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6),
Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)], Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)],
)""" )"""
...@@ -166,6 +112,7 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype): ...@@ -166,6 +112,7 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
seeds = {relation: torch.tensor([10, 11])}
csc_formats = [ csc_formats = [
{ {
relation: gb.CSCFormatBase( relation: gb.CSCFormatBase(
...@@ -222,39 +169,22 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): ...@@ -222,39 +169,22 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
negative_srcs = {"B": torch.tensor([[8], [1], [6]])} compacted_seeds = {relation: torch.tensor([0, 1])}
negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
compacted_csc_formats = {
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]), indices=torch.tensor([3, 4, 5])
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]), indices=torch.tensor([0, 1])
),
}
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
# Test minibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])}, seeds=seeds,
node_pairs=csc_formats,
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
labels={"B": torch.tensor([2, 5])}, labels={"B": torch.tensor([2, 5])},
negative_srcs=negative_srcs, compacted_seeds=compacted_seeds,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_csc_formats,
input_nodes={ input_nodes={
"A": torch.tensor([5, 7, 9, 11]), "A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]), "B": torch.tensor([10, 11, 12]),
}, },
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
) )
expect_result = str( expect_result = str(
"""MiniBatch(seeds=None, """MiniBatch(seeds={'A:r:B': tensor([10, 11])},
seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32), sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1], dtype=torch.int32), indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32), ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
...@@ -271,54 +201,13 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): ...@@ -271,54 +201,13 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
original_edge_ids={'A:r:B': tensor([10, 12])}, original_edge_ids={'A:r:B': tensor([10, 12])},
original_column_node_ids={'B': tensor([10, 11])}, original_column_node_ids={'B': tensor([10, 11])},
)], )],
positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
node_pairs_with_labels=({'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
{'B': tensor([2, 5])}),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)},
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)}],
node_features={('A', 'x'): tensor([6, 4, 0, 1])}, node_features={('A', 'x'): tensor([6, 4, 0, 1])},
negative_srcs={'B': tensor([[8],
[1],
[6]])},
negative_node_pairs={'A:r:B': (tensor([[0],
[1],
[2]]), tensor([[6],
[0],
[0]]))},
negative_dsts={'B': tensor([[2],
[8],
[8]])},
labels={'B': tensor([2, 5])}, labels={'B': tensor([2, 5])},
input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])}, input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
indexes=None, indexes=None,
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])}, edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}], {('A:r:B', 'x'): tensor([0, 6])}],
compacted_seeds=None, compacted_seeds={'A:r:B': tensor([0, 1])},
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
compacted_negative_srcs={'A:r:B': tensor([[0],
[1],
[2]])},
compacted_negative_dsts={'A:r:B': tensor([[6],
[0],
[0]])},
blocks=[Block(num_src_nodes={'A': 4, 'B': 3}, blocks=[Block(num_src_nodes={'A': 4, 'B': 3},
num_dst_nodes={'A': 4, 'B': 3}, num_dst_nodes={'A': 4, 'B': 3},
num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2}, num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
...@@ -330,22 +219,12 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype): ...@@ -330,22 +219,12 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
)""" )"""
) )
result = str(minibatch) result = str(minibatch)
assert result == expect_result, print(expect_result, result) assert result == expect_result, print(result)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64]) @pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype): def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
csc_formats = [ csc_formats = [
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype), indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
...@@ -368,11 +247,6 @@ def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype): ...@@ -368,11 +247,6 @@ def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
torch.tensor([19, 20, 21, 22, 25, 30]), torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]), torch.tensor([10, 15, 17]),
] ]
node_features = {"x": torch.tensor([7, 6, 2, 2])}
edge_features = [
{"x": torch.tensor([[8], [1], [6]])},
{"x": torch.tensor([[2], [8], [8]])},
]
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
...@@ -383,43 +257,19 @@ def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype): ...@@ -383,43 +257,19 @@ def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
compacted_negative_srcs = torch.tensor([[0], [1], [2]])
compacted_negative_dsts = torch.tensor([[6], [0], [0]])
labels = torch.tensor([0.0, 1.0, 2.0])
# Test minibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
node_pairs=node_pairs,
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
labels=labels,
node_features=node_features,
edge_features=edge_features,
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs,
input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
) )
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
expect_result = str( expect_result = str(
"""[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6), Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)]""" """[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6), Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)]"""
) )
result = str(dgl_blocks) result = str(dgl_blocks)
assert result == expect_result, print(result) assert result == expect_result
def test_get_dgl_blocks_hetero(): def test_get_dgl_blocks_hetero():
node_pairs = [
{
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
csc_formats = [ csc_formats = [
{ {
relation: gb.CSCFormatBase( relation: gb.CSCFormatBase(
...@@ -458,13 +308,6 @@ def test_get_dgl_blocks_hetero(): ...@@ -458,13 +308,6 @@ def test_get_dgl_blocks_hetero():
}, },
{relation: torch.tensor([10, 12])}, {relation: torch.tensor([10, 12])},
] ]
node_features = {
("A", "x"): torch.tensor([6, 4, 0, 1]),
}
edge_features = [
{(relation, "x"): torch.tensor([4, 2, 4])},
{(relation, "x"): torch.tensor([0, 6])},
]
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
...@@ -475,31 +318,9 @@ def test_get_dgl_blocks_hetero(): ...@@ -475,31 +318,9 @@ def test_get_dgl_blocks_hetero():
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
) )
) )
negative_srcs = {"B": torch.tensor([[8], [1], [6]])}
negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
compacted_node_pairs = {
relation: (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])),
reverse_relation: (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])),
}
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
# Test minibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])},
node_pairs=node_pairs,
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
labels={"B": torch.tensor([2, 5])},
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
) )
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
expect_result = str( expect_result = str(
...@@ -512,50 +333,7 @@ def test_get_dgl_blocks_hetero(): ...@@ -512,50 +333,7 @@ def test_get_dgl_blocks_hetero():
metagraph=[('A', 'B', 'r')])]""" metagraph=[('A', 'B', 'r')])]"""
) )
result = str(dgl_blocks) result = str(dgl_blocks)
assert result == expect_result, print(result) assert result == expect_result
@pytest.mark.parametrize(
"mode", ["neg_graph", "neg_src", "neg_dst", "edge_classification"]
)
def test_minibatch_node_pairs_with_labels(mode):
# Arrange
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
if mode == "edge_classification":
minibatch.labels = torch.tensor([0, 1]).long()
# Act
node_pairs, labels = minibatch.node_pairs_with_labels
# Assert
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
elif mode != "edge_classification":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
else:
expect_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
expect_labels = torch.tensor([0, 1]).long()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
def create_homo_minibatch(): def create_homo_minibatch():
...@@ -723,16 +501,10 @@ def check_dgl_blocks_homo(minibatch, blocks): ...@@ -723,16 +501,10 @@ def check_dgl_blocks_homo(minibatch, blocks):
dst_ndoes = torch.arange( dst_ndoes = torch.arange(
0, len(sampled_csc[i].indptr) - 1 0, len(sampled_csc[i].indptr) - 1
).repeat_interleave(sampled_csc[i].indptr.diff()) ).repeat_interleave(sampled_csc[i].indptr.diff())
assert torch.equal(block.edges()[0], sampled_csc[i].indices), print( assert torch.equal(block.edges()[0], sampled_csc[i].indices)
block.edges() assert torch.equal(block.edges()[1], dst_ndoes)
) assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(block.edges()[1], dst_ndoes), print(block.edges()) assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i]), print(
block.edata[dgl.EID]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID], original_row_node_ids[0]
), print(blocks[0].srcdata[dgl.NID])
def test_dgl_node_classification_without_feature(): def test_dgl_node_classification_without_feature():
...@@ -740,7 +512,7 @@ def test_dgl_node_classification_without_feature(): ...@@ -740,7 +512,7 @@ def test_dgl_node_classification_without_feature():
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
minibatch.node_features = None minibatch.node_features = None
minibatch.labels = None minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seeds = torch.tensor([10, 15])
# Act # Act
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
...@@ -754,7 +526,7 @@ def test_dgl_node_classification_without_feature(): ...@@ -754,7 +526,7 @@ def test_dgl_node_classification_without_feature():
def test_dgl_node_classification_homo(): def test_dgl_node_classification_homo():
# Arrange # Arrange
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seeds = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5]) minibatch.labels = torch.tensor([2, 5])
# Act # Act
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
...@@ -767,7 +539,7 @@ def test_dgl_node_classification_homo(): ...@@ -767,7 +539,7 @@ def test_dgl_node_classification_homo():
def test_dgl_node_classification_hetero(): def test_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch() minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])} minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])} minibatch.seeds = {"B": torch.tensor([10, 15])}
# Act # Act
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
...@@ -776,52 +548,19 @@ def test_dgl_node_classification_hetero(): ...@@ -776,52 +548,19 @@ def test_dgl_node_classification_hetero():
check_dgl_blocks_hetero(minibatch, dgl_blocks) check_dgl_blocks_hetero(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) def test_dgl_link_predication_homo():
def test_dgl_link_predication_homo(mode):
# Arrange # Arrange
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = ( minibatch.compacted_seeds = (
torch.tensor([0, 1]), torch.tensor([[0, 1, 0, 0, 1, 1], [1, 0, 1, 1, 0, 0]]).T,
torch.tensor([1, 0]),
) )
if mode == "neg_graph" or mode == "neg_src": minibatch.labels = torch.tensor([1, 1, 0, 0, 0, 0])
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
# Act # Act
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
check_dgl_blocks_homo(minibatch, dgl_blocks) check_dgl_blocks_homo(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs,
)
if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal(
minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts,
)
(
node_pairs,
labels,
) = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
else:
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
...@@ -829,113 +568,21 @@ def test_dgl_link_predication_hetero(mode): ...@@ -829,113 +568,21 @@ def test_dgl_link_predication_hetero(mode):
# Arrange # Arrange
minibatch = create_hetero_minibatch() minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = { minibatch.compacted_node_pairs = {
relation: ( relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),
torch.tensor([1, 1]),
torch.tensor([1, 0]),
),
reverse_relation: ( reverse_relation: (
torch.tensor([0, 1]), torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T,
torch.tensor([1, 0]),
), ),
} }
if mode == "neg_graph" or mode == "neg_src": minibatch.labels = {
minibatch.compacted_negative_srcs = { relation: (torch.tensor([1, 1, 0, 0, 0, 0]),),
relation: torch.tensor([[2, 0], [1, 2]]), reverse_relation: (torch.tensor([1, 1, 0, 0, 0, 0]),),
reverse_relation: torch.tensor([[1, 2], [0, 2]]), }
}
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = {
relation: torch.tensor([[1, 3], [2, 1]]),
reverse_relation: torch.tensor([[2, 1], [3, 1]]),
}
# Act # Act
dgl_blocks = minibatch.blocks dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
check_dgl_blocks_hetero(minibatch, dgl_blocks) check_dgl_blocks_hetero(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][0],
src,
)
if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype],
)
def test_to_pyg_data_original():
test_minibatch = create_homo_minibatch()
test_minibatch.seed_nodes = torch.tensor([0, 1])
test_minibatch.labels = torch.tensor([7, 8])
expected_edge_index = torch.tensor(
[[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
)
expected_node_features = next(iter(test_minibatch.node_features.values()))
expected_labels = torch.tensor([7, 8])
expected_batch_size = 2
expected_n_id = torch.tensor([10, 11, 12, 13])
pyg_data = test_minibatch.to_pyg_data()
pyg_data.validate()
assert torch.equal(pyg_data.edge_index, expected_edge_index)
assert torch.equal(pyg_data.x, expected_node_features)
assert torch.equal(pyg_data.y, expected_labels)
assert pyg_data.batch_size == expected_batch_size
assert torch.equal(pyg_data.n_id, expected_n_id)
subgraph = test_minibatch.sampled_subgraphs[0]
# Test with sampled_csc as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=None,
node_features={"feat": expected_node_features},
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.edge_index is None, "Edge index should be none."
# Test with node_features as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features=None,
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.x is None, "Node features should be None."
# Test with labels as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features={"feat": expected_node_features},
labels=None,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.y is None, "Labels should be None."
# Test with multiple features.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features={
"feat": expected_node_features,
"extra_feat": torch.tensor([[3], [4]]),
},
labels=expected_labels,
)
try:
pyg_data = test_minibatch.to_pyg_data()
assert (
pyg_data.x is None
), "Multiple features case should raise an error."
except AssertionError as e:
assert (
str(e)
== "`to_pyg_data` only supports single feature homogeneous graph."
)
def test_to_pyg_data(): def test_to_pyg_data():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment