Unverified Commit 697a9977 authored by Xinyu Yao, committed by GitHub

[GraphBolt] Remove unused attributes in minibatch. (#7337)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent c997434c
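
For context, a minimal, hypothetical sketch (not part of this commit) of how link-prediction items that were previously described with the removed `node_pairs` / `negative_dsts` attributes are expressed with the unified `seeds`, `labels`, and `indexes` fields; the tensors are illustrative and mirror the test changes in this diff.

import torch
import dgl.graphbolt as gb

node_pairs = torch.tensor([[0, 1], [2, 3]])      # positive (src, dst) pairs
neg_dsts = torch.tensor([[4, 5], [6, 7]])        # two negative tails per positive

# Old style (removed by this commit):
#   item_set = gb.ItemSet((node_pairs, neg_dsts),
#                         names=("node_pairs", "negative_dsts"))

# New style: positives and negatives are concatenated into one `seeds` tensor,
# `labels` marks positive (1.) vs. negative (0.) rows, and `indexes` maps every
# row back to the positive pair it was generated from.
neg_srcs = node_pairs[:, 0].repeat_interleave(2)
neg_pairs = torch.stack((neg_srcs, neg_dsts.flatten()), dim=1)
seeds = torch.cat((node_pairs, neg_pairs))
labels = torch.cat((torch.ones(2), torch.zeros(4)))
indexes = torch.cat((torch.arange(2), torch.arange(2).repeat_interleave(2)))
item_set = gb.ItemSet((seeds, labels, indexes),
                      names=("seeds", "labels", "indexes"))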
......@@ -24,27 +24,6 @@ class MiniBatch:
representation of input and output data across different stages, ensuring
consistency and ease of use throughout the loading process."""
seed_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of seed nodes used for sampling in the graph.
- If `seed_nodes` is a tensor: It indicates the graph is homogeneous.
- If `seed_nodes` is a dictionary: The keys should be node types and the
values should be the corresponding heterogeneous node ids.
"""
node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pairs` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
- If `node_pairs` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
type.
"""
labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with seed nodes / node pairs in the graph.
......@@ -93,26 +72,6 @@ class MiniBatch:
key, indexes are consecutive integers starting from zero.
"""
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
prediction task.
- If `negative_srcs` is a tensor: It indicates a homogeneous graph.
- If `negative_srcs` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the tail nodes in the link
prediction task.
- If `negative_dsts` is a tensor: It indicates a homogeneous graph.
- If `negative_dsts` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for tail nodes of the
given type.
"""
sampled_subgraphs: List[SampledSubgraph] = None
"""A list of 'SampledSubgraph's, each one corresponding to one layer,
representing a subset of a larger graph structure.
......@@ -147,15 +106,6 @@ class MiniBatch:
string of format 'str:str:str'.
"""
compacted_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of compacted node pairs corresponding to 'node_pairs', where
all node ids inside are compacted.
"""
compacted_seeds: Union[
torch.Tensor,
Dict[str, torch.Tensor],
......@@ -165,18 +115,6 @@ class MiniBatch:
all node ids inside are compacted.
"""
compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_srcs', where
all node ids inside are compacted.
"""
compacted_negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_dsts', where
all node ids inside are compacted.
"""
def __repr__(self) -> str:
return _minibatch_str(self)
......@@ -333,163 +271,6 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks
@property
def positive_node_pairs(self):
"""`positive_node_pairs` is a representation of positive graphs used for
evaluating or computing loss in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
return self.compacted_node_pairs
@property
def negative_node_pairs(self):
"""`negative_node_pairs` is a representation of negative graphs used for
evaluating or computing loss in link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
# Build negative graph.
if (
self.compacted_negative_srcs is not None
and self.compacted_negative_dsts is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_node_pairs = (
self.compacted_negative_srcs,
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_node_pairs = {
etype: (
neg_src,
self.compacted_negative_dsts[etype],
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif (
self.compacted_negative_srcs is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
negative_node_pairs = (
self.compacted_negative_srcs,
self.compacted_node_pairs[1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
# For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_srcs.values())[
0
].size(1)
negative_node_pairs = {
etype: (
neg_src,
self.compacted_node_pairs[etype][1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif (
self.compacted_negative_dsts is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
negative_node_pairs = (
self.compacted_node_pairs[0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_dsts.values())[
0
].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
neg_dst,
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
else:
negative_node_pairs = None
return negative_node_pairs
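# Illustrative sketch only: when only negative sources were sampled, the removed
# property above broadcast the positive destinations against them with
# repeat_interleave + view to form (src, dst) negative pairs.
import torch
pos_dst = torch.tensor([3, 7])                    # compacted positive destinations
neg_src = torch.tensor([[0, 1], [2, 4]])          # negative_ratio == 2
negative_ratio = neg_src.size(1)
neg_dst = pos_dst.repeat_interleave(negative_ratio).view(-1, negative_ratio)
# neg_dst == tensor([[3, 3], [7, 7]]); (neg_src, neg_dst) are the negative pairs.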
@property
def node_pairs_with_labels(self):
"""Get a node pair tensor and a label tensor from MiniBatch. They are
used for evaluating or computing loss. For homogeneous graph, it will
return `(node_pairs, labels)` as result; for heterogeneous graph, the
`node_pairs` and `labels` will both be a dict with etype as the key.
- If it's a link prediction task, `node_pairs` will contain both
negative and positive node pairs and `labels` will consist of 0 and 1,
indicating whether the corresponding node pair is negative or positive.
- If it's an edge classification task, this function will directly
return `compacted_node_pairs` for each etype and the corresponding
`labels`.
- Otherwise it will return None.
"""
if self.labels is None:
# Link prediction.
positive_node_pairs = self.positive_node_pairs
negative_node_pairs = self.negative_node_pairs
if positive_node_pairs is None or negative_node_pairs is None:
return None
if isinstance(positive_node_pairs, Dict):
# Heterogeneous graph.
node_pairs_by_etype = {}
labels_by_etype = {}
for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype]
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels_by_etype[etype] = torch.cat(
[pos_label, neg_label], dim=0
)
return (node_pairs_by_etype, labels_by_etype)
else:
# Homogeneous graph.
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels = torch.cat([pos_label, neg_label], dim=0)
return (node_pairs, labels.float())
elif self.compacted_node_pairs is not None:
# Edge classification.
return (self.compacted_node_pairs, self.labels)
else:
return None
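# Illustrative sketch only: the concatenated positive/negative pairs and the
# 0/1 labels that this removed property assembled are now carried directly by
# the unified `seeds` and `labels` fields, and can be split the same way the
# updated tests below do.
import torch
seeds = torch.tensor([[5, 1], [3, 2], [5, 2], [5, 4], [3, 3], [3, 4]])
labels = torch.tensor([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
src, dst = seeds.T
pos_src, pos_dst = src[labels.to(bool)], dst[labels.to(bool)]
neg_src, neg_dst = src[~labels.to(bool)], dst[~labels.to(bool)]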
def to_pyg_data(self):
"""Construct a PyG Data from `MiniBatch`. This function only supports
node classification task on a homogeneous graph and the number of
......@@ -527,17 +308,7 @@ class MiniBatch:
), "`to_pyg_data` only supports single feature homogeneous graph."
node_features = next(iter(self.node_features.values()))
if self.seed_nodes is not None:
if isinstance(self.seed_nodes, Dict):
batch_size = len(next(iter(self.seed_nodes.values())))
else:
batch_size = len(self.seed_nodes)
elif self.node_pairs is not None:
if isinstance(self.node_pairs, Dict):
batch_size = len(next(iter(self.node_pairs.values()))[0])
else:
batch_size = len(self.node_pairs[0])
elif self.seeds is not None:
if self.seeds is not None:
if isinstance(self.seeds, Dict):
batch_size = len(next(iter(self.seeds.values())))
else:
......
......@@ -11,7 +11,7 @@ def test_integration_link_prediction():
indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])
matrix_a = dglsp.from_csc(indptr, indices)
node_pairs = torch.t(torch.stack(matrix_a.coo()))
seeds = torch.t(torch.stack(matrix_a.coo()))
node_feature_data = torch.tensor(
[
[0.9634, 0.2294],
......@@ -37,7 +37,7 @@ def test_integration_link_prediction():
]
)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_set = gb.ItemSet(seeds, names="seeds")
graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data)
......@@ -72,7 +72,6 @@ def test_integration_link_prediction():
[3, 3],
[3, 3],
[3, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
indices=tensor([0, 5, 4], dtype=torch.int32),
),
......@@ -87,18 +86,12 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([5, 1, 3, 2, 0, 4]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
......@@ -116,9 +109,6 @@ def test_integration_link_prediction():
[2, 2],
[2, 2],
[2, 5]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
)"""
......@@ -136,7 +126,6 @@ def test_integration_link_prediction():
[4, 3],
[0, 1],
[0, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([4, 0], dtype=torch.int32),
),
......@@ -151,18 +140,12 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865],
[0.5160, 0.2486],
[0.2109, 0.1089]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
......@@ -180,9 +163,6 @@ def test_integration_link_prediction():
[1, 0],
[2, 3],
[2, 4]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
)"""
......@@ -194,7 +174,6 @@ def test_integration_link_prediction():
[5, 4],
[4, 0],
[4, 1]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
),
......@@ -209,16 +188,10 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 0., 0., 0., 0.]),
input_nodes=tensor([5, 4, 0, 1]),
indexes=tensor([0, 1, 0, 0, 1, 1]),
......@@ -230,9 +203,6 @@ def test_integration_link_prediction():
[0, 1],
[1, 2],
[1, 3]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
)"""
......@@ -248,8 +218,7 @@ def test_integration_node_classification():
indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])
matrix_a = dglsp.from_csc(indptr, indices)
node_pairs = torch.t(torch.stack(matrix_a.coo()))
seeds = torch.tensor([5, 1, 2, 4, 3, 0])
node_feature_data = torch.tensor(
[
[0.9634, 0.2294],
......@@ -275,7 +244,7 @@ def test_integration_node_classification():
]
)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_set = gb.ItemSet(seeds, names="seeds")
graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data)
......@@ -285,7 +254,7 @@ def test_integration_node_classification():
("edge", None, "feat"): edge_feature,
}
feature_store = gb.BasicFeatureStore(features)
datapipe = gb.ItemSampler(item_set, batch_size=4)
datapipe = gb.ItemSampler(item_set, batch_size=2)
fanouts = torch.LongTensor([1])
datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
datapipe = datapipe.fetch_feature(
......@@ -296,138 +265,92 @@ def test_integration_node_classification():
)
expected = [
str(
"""MiniBatch(seeds=tensor([[5, 1],
[3, 2],
[3, 2],
[3, 3]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 0, 2, 2], dtype=torch.int32),
"""MiniBatch(seeds=tensor([5, 1]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([2, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2, 4]),
original_row_node_ids=tensor([5, 1, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2]),
original_column_node_ids=tensor([5, 1]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([0, 0, 2, 2], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2]),
original_row_node_ids=tensor([5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2]),
original_column_node_ids=tensor([5, 1]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 1, 3, 2, 4]),
input_nodes=tensor([5, 1, 4]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=tensor([[0, 1],
[2, 3],
[2, 3],
[2, 2]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
compacted_seeds=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=tensor([[3, 3],
[4, 3],
[4, 4],
[0, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
"""MiniBatch(seeds=tensor([2, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 3], dtype=torch.int32),
indices=tensor([2, 1, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0]),
original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
original_column_node_ids=tensor([2, 4, 3, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([2, 3], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0]),
original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
original_column_node_ids=tensor([2, 4]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
node_features={'feat': tensor([[0.2109, 0.1089],
[0.5503, 0.8223],
[0.8672, 0.2276],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([3, 4, 0]),
input_nodes=tensor([2, 4, 3, 0]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=tensor([[0, 0],
[1, 0],
[1, 1],
[2, 1]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
compacted_seeds=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=3),
Block(num_src_nodes=4, num_dst_nodes=2, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=tensor([[5, 5],
[4, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
"""MiniBatch(seeds=tensor([3, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
indices=tensor([0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0]),
original_row_node_ids=tensor([3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
original_column_node_ids=tensor([3, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 1], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
indices=tensor([0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4]),
original_row_node_ids=tensor([3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
original_column_node_ids=tensor([3, 0]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
node_features={'feat': tensor([[0.8672, 0.2276],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 4, 0]),
input_nodes=tensor([3, 0]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=tensor([[0, 0],
[1, 0]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
compacted_seeds=None,
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
)"""
),
]
for step, data in enumerate(dataloader):
assert expected[step] == str(data), print(data)
assert expected[step] == str(data), print(step, data)
import os
import re
import unittest
from collections import defaultdict
from sys import platform
import backend as F
......@@ -47,7 +48,7 @@ def test_ItemSampler_minibatcher():
# Default minibatcher is used if not specified.
# `MiniBatch` is returned if expected names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
......@@ -78,7 +79,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
return iter(torch.arange(0, num_ids))
seed_nodes = gb.ItemSet(InvalidLength())
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_set = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -106,7 +107,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
def test_ItemSet_integer(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
item_set = gb.ItemSet(num_ids, names="seed_nodes")
item_set = gb.ItemSet(num_ids, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -135,7 +136,7 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
seed_nodes = torch.arange(0, num_ids)
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_set = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -165,7 +166,7 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
num_ids = 103
seed_nodes = torch.arange(0, num_ids)
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -249,7 +250,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_set = gb.ItemSet(node_pairs, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -289,7 +290,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
num_ids = 103
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
labels = node_pairs[:, 0]
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -333,16 +334,26 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
def test_ItemSet_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
# Node pairs and negative destinations.
num_ids = 103
num_negs = 2
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
neg_dsts = torch.arange(
2 * num_ids, 2 * num_ids + num_ids * num_negs
).reshape(-1, num_negs)
neg_srcs = node_pairs[:, 0].repeat_interleave(num_negs)
neg_dsts = torch.arange(2 * num_ids, 2 * num_ids + num_ids * num_negs)
neg_node_pairs = torch.cat((neg_srcs, neg_dsts)).reshape(2, -1).T
labels = torch.empty(num_ids * 3)
labels[:num_ids] = 1
labels[num_ids:] = 0
indexes = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
node_pairs = torch.cat((node_pairs, neg_node_pairs))
item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
(node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
......@@ -350,6 +361,8 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src_ids = []
dst_ids = []
negs_ids = []
final_labels = []
final_indexes = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
......@@ -358,46 +371,43 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src, dst = minibatch.seeds.T
negs_src = src[~minibatch.labels.to(bool)]
negs_dst = dst[~minibatch.labels.to(bool)]
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
is_last = (i + 1) * batch_size >= num_ids * 3
if not is_last or num_ids * 3 % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
expected_batch_size = num_ids * 3 % batch_size
else:
assert False
assert len(src) == expected_batch_size * 3
assert len(dst) == expected_batch_size * 3
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert negs_src.dim() == 1
assert negs_dst.dim() == 1
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
expected_indexes = torch.arange(expected_batch_size)
expected_indexes = torch.cat(
(expected_indexes, expected_indexes.repeat_interleave(2))
)
assert torch.equal(minibatch.indexes, expected_indexes)
# Verify node pairs and negative destinations.
assert torch.equal(
src[minibatch.labels.to(bool)] + 1, dst[minibatch.labels.to(bool)]
)
assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
# Archive batch.
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs_dst)
final_labels.append(minibatch.labels)
final_indexes.append(minibatch.indexes)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids)
final_labels = torch.cat(final_labels)
final_indexes = torch.cat(final_indexes)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
assert torch.all(final_labels[:-1] >= final_labels[1:]) is not shuffle
if not drop_last:
assert final_labels.sum() == num_ids
assert torch.equal(final_indexes, indexes) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds(batch_size, shuffle, drop_last):
def test_ItemSet_hyperlink(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
......@@ -477,7 +487,7 @@ def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
def test_append_with_other_datapipes():
num_ids = 100
batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
data_pipe = gb.ItemSampler(item_set, batch_size)
# torchdata.datapipes.iter.Enumerator
data_pipe = data_pipe.enumerate()
......@@ -500,8 +510,8 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
num_ids = 205
ids = {
"user": gb.ItemSet(IterableOnly(0, 99), names="seed_nodes"),
"item": gb.ItemSet(IterableOnly(99, num_ids), names="seed_nodes"),
"user": gb.ItemSet(IterableOnly(0, 99), names="seeds"),
"item": gb.ItemSet(IterableOnly(99, num_ids), names="seeds"),
}
chained_ids = []
for key, value in ids.items():
......@@ -539,8 +549,8 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
"user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
"user": gb.ItemSet(torch.arange(0, 99), names="seeds"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seeds"),
}
chained_ids = []
for key, value in ids.items():
......@@ -580,11 +590,11 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
ids = {
"user": gb.ItemSet(
(torch.arange(0, 99), torch.arange(0, 99)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
"item": gb.ItemSet(
(torch.arange(99, num_ids), torch.arange(99, num_ids)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
}
chained_ids = []
......@@ -638,8 +648,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"),
"user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"),
"user:like:item": gb.ItemSet(node_pairs_like, names="seeds"),
"user:follow:user": gb.ItemSet(node_pairs_follow, names="seeds"),
}
item_set = gb.ItemSetDict(node_pairs_dict)
item_sampler = gb.ItemSampler(
......@@ -691,11 +701,11 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
node_pairs_dict = {
"user:like:item": gb.ItemSet(
(node_pairs_like, node_pairs_like[:, 0]),
names=("node_pairs", "labels"),
names=("seeds", "labels"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_follow, node_pairs_follow[:, 0]),
names=("node_pairs", "labels"),
names=("seeds", "labels"),
),
}
item_set = gb.ItemSetDict(node_pairs_dict)
......@@ -709,7 +719,6 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -749,27 +758,64 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
num_ids = 103
total_ids = 2 * num_ids
total_ids = 6 * num_ids
num_negs = 2
node_paris_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
neg_dsts_like = torch.arange(
num_ids * 4, num_ids * 4 + num_ids * num_negs
).reshape(-1, num_negs)
neg_dsts_like = torch.arange(num_ids * 4, num_ids * 4 + num_ids * num_negs)
neg_node_pairs_like = (
torch.cat(
(node_pairs_like[:, 0].repeat_interleave(num_negs), neg_dsts_like)
)
.view(2, -1)
.T
)
all_node_pairs_like = torch.cat((node_pairs_like, neg_node_pairs_like))
labels_like = torch.empty(num_ids * 3)
labels_like[:num_ids] = 1
labels_like[num_ids:] = 0
indexes_like = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
neg_dsts_follow = torch.arange(
num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
).reshape(-1, num_negs)
)
neg_node_pairs_follow = (
torch.cat(
(
node_pairs_follow[:, 0].repeat_interleave(num_negs),
neg_dsts_follow,
)
)
.view(2, -1)
.T
)
all_node_pairs_follow = torch.cat(
(node_pairs_follow, neg_node_pairs_follow)
)
labels_follow = torch.empty(num_ids * 3)
labels_follow[:num_ids] = 1
labels_follow[num_ids:] = 0
indexes_follow = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
data_dict = {
"user:like:item": gb.ItemSet(
(node_paris_like, neg_dsts_like),
names=("node_pairs", "negative_dsts"),
(all_node_pairs_like, labels_like, indexes_like),
names=("seeds", "labels", "indexes"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_follow, neg_dsts_follow),
names=("node_pairs", "negative_dsts"),
(all_node_pairs_follow, labels_follow, indexes_follow),
names=("seeds", "labels", "indexes"),
),
}
item_set = gb.ItemSetDict(data_dict)
......@@ -779,11 +825,13 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src_ids = []
dst_ids = []
negs_ids = []
final_labels = defaultdict(list)
final_indexes = defaultdict(list)
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.negative_dsts is None
assert minibatch.indexes is not None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -800,24 +848,23 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert isinstance(seeds, torch.Tensor)
src_etype = seeds[:, 0]
dst_etype = seeds[:, 1]
src.append(src_etype[minibatch.labels[etype].to(bool)])
dst.append(dst_etype[minibatch.labels[etype].to(bool)])
src.append(src_etype)
dst.append(dst_etype)
negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
final_labels[etype].append(minibatch.labels[etype])
final_indexes[etype].append(minibatch.indexes[etype])
src = torch.cat(src)
dst = torch.cat(dst)
negs_src = torch.cat(negs_src)
negs_dst = torch.cat(negs_dst)
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs_dst)
assert negs_src.dim() == 1
assert negs_dst.dim() == 1
assert torch.equal(src + 1, dst)
assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
......@@ -825,12 +872,24 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids <= negs_ids) is not shuffle
for etype in data_dict.keys():
final_labels_etype = torch.cat(final_labels[etype])
final_indexes_etype = torch.cat(final_indexes[etype])
assert (
torch.all(final_labels_etype[:-1] >= final_labels_etype[1:])
is not shuffle
)
if not drop_last:
assert final_labels_etype.sum() == num_ids
assert (
torch.equal(final_indexes_etype, indexes_follow) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
def test_ItemSetDict_hyperlink(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
total_pairs = 2 * num_ids
......@@ -876,7 +935,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last):
def test_ItemSetDict_hyperlink_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
total_ids = 2 * num_ids
......
......@@ -11,6 +11,7 @@ reverse_relation = "B:rr:A"
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
seeds = torch.tensor([10, 11])
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
......@@ -48,36 +49,20 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 3]), indices=torch.tensor([3, 4, 5])
)
compacted_negative_srcs = torch.tensor([[0], [1], [2]])
compacted_negative_dsts = torch.tensor([[6], [0], [0]])
labels = torch.tensor([0.0, 1.0, 2.0])
compacted_seeds = torch.tensor([0, 1])
labels = torch.tensor([1.0, 2.0])
# Test minibatch without data.
minibatch = gb.MiniBatch()
expect_result = str(
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=None,
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features=None,
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=None,
indexes=None,
edge_features=None,
compacted_seeds=None,
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=None,
)"""
)
......@@ -85,21 +70,16 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
assert result == expect_result, print(expect_result, result)
# Test minibatch with all attributes.
minibatch = gb.MiniBatch(
node_pairs=csc_formats,
seeds=seeds,
sampled_subgraphs=subgraphs,
labels=labels,
node_features=node_features,
edge_features=edge_features,
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_csc_formats,
compacted_seeds=compacted_seeds,
input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
expect_result = str(
"""MiniBatch(seeds=None,
seed_nodes=None,
"""MiniBatch(seeds=tensor([10, 11]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
),
......@@ -114,47 +94,13 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
original_edge_ids=tensor([10, 15, 17]),
original_column_node_ids=tensor([10, 11]),
)],
positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
node_pairs_with_labels=(CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
tensor([0., 1., 2.])),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6], dtype=torch.int32),
indices=tensor([0, 1, 2, 2, 1, 2], dtype=torch.int32),
),
CSCFormatBase(indptr=tensor([0, 2, 3], dtype=torch.int32),
indices=tensor([1, 2, 0], dtype=torch.int32),
)],
node_features={'x': tensor([5, 0, 2, 1])},
negative_srcs=tensor([[8],
[1],
[6]]),
negative_node_pairs=(tensor([[0],
[1],
[2]]),
tensor([[6],
[0],
[0]])),
negative_dsts=tensor([[2],
[8],
[8]]),
labels=tensor([0., 1., 2.]),
labels=tensor([1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
indexes=None,
edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
{'x': tensor([0, 2, 2])}],
compacted_seeds=None,
compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
compacted_negative_srcs=tensor([[0],
[1],
[2]]),
compacted_negative_dsts=tensor([[6],
[0],
[0]]),
compacted_seeds=tensor([0, 1]),
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6),
Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)],
)"""
......@@ -166,6 +112,7 @@ def test_minibatch_representation_homo(indptr_dtype, indices_dtype):
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
seeds = {relation: torch.tensor([10, 11])}
csc_formats = [
{
relation: gb.CSCFormatBase(
......@@ -222,39 +169,22 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = {"B": torch.tensor([[8], [1], [6]])}
negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
compacted_csc_formats = {
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]), indices=torch.tensor([3, 4, 5])
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]), indices=torch.tensor([0, 1])
),
}
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
compacted_seeds = {relation: torch.tensor([0, 1])}
# Test minibatch with all attributes.
minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])},
node_pairs=csc_formats,
seeds=seeds,
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
labels={"B": torch.tensor([2, 5])},
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_csc_formats,
compacted_seeds=compacted_seeds,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
expect_result = str(
"""MiniBatch(seeds=None,
seed_nodes={'B': tensor([10, 15])},
"""MiniBatch(seeds={'A:r:B': tensor([10, 11])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
......@@ -271,54 +201,13 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
original_edge_ids={'A:r:B': tensor([10, 12])},
original_column_node_ids={'B': tensor([10, 11])},
)],
positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
node_pairs_with_labels=({'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
{'B': tensor([2, 5])}),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3], dtype=torch.int32),
indices=tensor([0, 1, 1], dtype=torch.int32),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)},
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
)}],
node_features={('A', 'x'): tensor([6, 4, 0, 1])},
negative_srcs={'B': tensor([[8],
[1],
[6]])},
negative_node_pairs={'A:r:B': (tensor([[0],
[1],
[2]]), tensor([[6],
[0],
[0]]))},
negative_dsts={'B': tensor([[2],
[8],
[8]])},
labels={'B': tensor([2, 5])},
input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
indexes=None,
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}],
compacted_seeds=None,
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
compacted_negative_srcs={'A:r:B': tensor([[0],
[1],
[2]])},
compacted_negative_dsts={'A:r:B': tensor([[6],
[0],
[0]])},
compacted_seeds={'A:r:B': tensor([0, 1])},
blocks=[Block(num_src_nodes={'A': 4, 'B': 3},
num_dst_nodes={'A': 4, 'B': 3},
num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
......@@ -330,22 +219,12 @@ def test_minibatch_representation_hetero(indptr_dtype, indices_dtype):
)"""
)
result = str(minibatch)
assert result == expect_result, print(expect_result, result)
assert result == expect_result, print(result)
@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6], dtype=indptr_dtype),
......@@ -368,11 +247,6 @@ def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.tensor([7, 6, 2, 2])}
edge_features = [
{"x": torch.tensor([[8], [1], [6]])},
{"x": torch.tensor([[2], [8], [8]])},
]
subgraphs = []
for i in range(2):
subgraphs.append(
......@@ -383,43 +257,19 @@ def test_get_dgl_blocks_homo(indptr_dtype, indices_dtype):
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
compacted_negative_srcs = torch.tensor([[0], [1], [2]])
compacted_negative_dsts = torch.tensor([[6], [0], [0]])
labels = torch.tensor([0.0, 1.0, 2.0])
# Test minibatch with all attributes.
minibatch = gb.MiniBatch(
node_pairs=node_pairs,
sampled_subgraphs=subgraphs,
labels=labels,
node_features=node_features,
edge_features=edge_features,
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs,
input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
dgl_blocks = minibatch.blocks
expect_result = str(
"""[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6), Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)]"""
)
result = str(dgl_blocks)
assert result == expect_result, print(result)
assert result == expect_result
def test_get_dgl_blocks_hetero():
node_pairs = [
{
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
csc_formats = [
{
relation: gb.CSCFormatBase(
......@@ -458,13 +308,6 @@ def test_get_dgl_blocks_hetero():
},
{relation: torch.tensor([10, 12])},
]
node_features = {
("A", "x"): torch.tensor([6, 4, 0, 1]),
}
edge_features = [
{(relation, "x"): torch.tensor([4, 2, 4])},
{(relation, "x"): torch.tensor([0, 6])},
]
subgraphs = []
for i in range(2):
subgraphs.append(
......@@ -475,31 +318,9 @@ def test_get_dgl_blocks_hetero():
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = {"B": torch.tensor([[8], [1], [6]])}
negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
compacted_node_pairs = {
relation: (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])),
reverse_relation: (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])),
}
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
# Test minibatch with all attributes.
minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])},
node_pairs=node_pairs,
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
labels={"B": torch.tensor([2, 5])},
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
dgl_blocks = minibatch.blocks
expect_result = str(
......@@ -512,50 +333,7 @@ def test_get_dgl_blocks_hetero():
metagraph=[('A', 'B', 'r')])]"""
)
result = str(dgl_blocks)
assert result == expect_result, print(result)
@pytest.mark.parametrize(
"mode", ["neg_graph", "neg_src", "neg_dst", "edge_classification"]
)
def test_minibatch_node_pairs_with_labels(mode):
# Arrange
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
if mode == "edge_classification":
minibatch.labels = torch.tensor([0, 1]).long()
# Act
node_pairs, labels = minibatch.node_pairs_with_labels
# Assert
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
elif mode != "edge_classification":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
else:
expect_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
expect_labels = torch.tensor([0, 1]).long()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
assert result == expect_result
def create_homo_minibatch():
......@@ -723,16 +501,10 @@ def check_dgl_blocks_homo(minibatch, blocks):
dst_ndoes = torch.arange(
0, len(sampled_csc[i].indptr) - 1
).repeat_interleave(sampled_csc[i].indptr.diff())
assert torch.equal(block.edges()[0], sampled_csc[i].indices), print(
block.edges()
)
assert torch.equal(block.edges()[1], dst_ndoes), print(block.edges())
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i]), print(
block.edata[dgl.EID]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID], original_row_node_ids[0]
), print(blocks[0].srcdata[dgl.NID])
assert torch.equal(block.edges()[0], sampled_csc[i].indices)
assert torch.equal(block.edges()[1], dst_ndoes)
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
def test_dgl_node_classification_without_feature():
......@@ -740,7 +512,7 @@ def test_dgl_node_classification_without_feature():
minibatch = create_homo_minibatch()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.seeds = torch.tensor([10, 15])
# Act
dgl_blocks = minibatch.blocks
......@@ -754,7 +526,7 @@ def test_dgl_node_classification_without_feature():
def test_dgl_node_classification_homo():
# Arrange
minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.seeds = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5])
# Act
dgl_blocks = minibatch.blocks
......@@ -767,7 +539,7 @@ def test_dgl_node_classification_homo():
def test_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
minibatch.seeds = {"B": torch.tensor([10, 15])}
# Act
dgl_blocks = minibatch.blocks
......@@ -776,52 +548,19 @@ def test_dgl_node_classification_hetero():
check_dgl_blocks_hetero(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_homo(mode):
def test_dgl_link_predication_homo():
# Arrange
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
minibatch.compacted_seeds = (
torch.tensor([[0, 1, 0, 0, 1, 1], [1, 0, 1, 1, 0, 0]]).T,
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
minibatch.labels = torch.tensor([1, 1, 0, 0, 0, 0])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs,
)
if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal(
minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts,
)
(
node_pairs,
labels,
) = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
else:
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
......@@ -829,113 +568,21 @@ def test_dgl_link_predication_hetero(mode):
# Arrange
minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = {
relation: (
torch.tensor([1, 1]),
torch.tensor([1, 0]),
),
relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),
reverse_relation: (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T,
),
}
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = {
relation: torch.tensor([[2, 0], [1, 2]]),
reverse_relation: torch.tensor([[1, 2], [0, 2]]),
}
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = {
relation: torch.tensor([[1, 3], [2, 1]]),
reverse_relation: torch.tensor([[2, 1], [3, 1]]),
}
minibatch.labels = {
relation: (torch.tensor([1, 1, 0, 0, 0, 0]),),
reverse_relation: (torch.tensor([1, 1, 0, 0, 0, 0]),),
}
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][0],
src,
)
if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype],
)
def test_to_pyg_data_original():
test_minibatch = create_homo_minibatch()
test_minibatch.seed_nodes = torch.tensor([0, 1])
test_minibatch.labels = torch.tensor([7, 8])
expected_edge_index = torch.tensor(
[[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
)
expected_node_features = next(iter(test_minibatch.node_features.values()))
expected_labels = torch.tensor([7, 8])
expected_batch_size = 2
expected_n_id = torch.tensor([10, 11, 12, 13])
pyg_data = test_minibatch.to_pyg_data()
pyg_data.validate()
assert torch.equal(pyg_data.edge_index, expected_edge_index)
assert torch.equal(pyg_data.x, expected_node_features)
assert torch.equal(pyg_data.y, expected_labels)
assert pyg_data.batch_size == expected_batch_size
assert torch.equal(pyg_data.n_id, expected_n_id)
subgraph = test_minibatch.sampled_subgraphs[0]
# Test with sampled_csc as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=None,
node_features={"feat": expected_node_features},
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.edge_index is None, "Edge index should be none."
# Test with node_features as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features=None,
labels=expected_labels,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.x is None, "Node features should be None."
# Test with labels as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features={"feat": expected_node_features},
labels=None,
)
pyg_data = test_minibatch.to_pyg_data()
assert pyg_data.y is None, "Labels should be None."
# Test with multiple features.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features={
"feat": expected_node_features,
"extra_feat": torch.tensor([[3], [4]]),
},
labels=expected_labels,
)
try:
pyg_data = test_minibatch.to_pyg_data()
assert (
pyg_data.x is None
), "Multiple features case should raise an error."
except AssertionError as e:
assert (
str(e)
== "`to_pyg_data` only supports single feature homogeneous graph."
)
def test_to_pyg_data():
......