Unverified Commit 697a9977 authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Remove unused attributes in minibatch. (#7337)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent c997434c
......@@ -24,27 +24,6 @@ class MiniBatch:
representation of input and output data across different stages, ensuring
consistency and ease of use throughout the loading process."""
seed_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of seed nodes used for sampling in the graph.
- If `seed_nodes` is a tensor: It indicates the graph is homogeneous.
- If `seed_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pairs` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
- If `node_pairs` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
type.
"""
labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with seed nodes / node pairs in the graph.
......@@ -93,26 +72,6 @@ class MiniBatch:
key, indexes are consecutive integers starting from zero.
"""
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
prediction task.
- If `negative_srcs` is a tensor: It indicates a homogeneous graph.
- If `negative_srcs` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the tail nodes in the link
prediction task.
- If `negative_dsts` is a tensor: It indicates a homogeneous graph.
- If `negative_dsts` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
sampled_subgraphs: List[SampledSubgraph] = None
"""A list of 'SampledSubgraph's, each one corresponding to one layer,
representing a subset of a larger graph structure.
......@@ -147,15 +106,6 @@ class MiniBatch:
string of format 'str:str:str'.
"""
compacted_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of compacted node pairs corresponding to 'node_pairs', where
all node ids inside are compacted.
"""
compacted_seeds: Union[
torch.Tensor,
Dict[str, torch.Tensor],
......@@ -165,18 +115,6 @@ class MiniBatch:
all node ids inside are compacted.
"""
compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_srcs', where
all node ids inside are compacted.
"""
compacted_negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_dsts', where
all node ids inside are compacted.
"""
def __repr__(self) -> str:
return _minibatch_str(self)
......@@ -333,163 +271,6 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks
@property
def positive_node_pairs(self):
"""`positive_node_pairs` is a representation of positive graphs used for
evaluating or computing loss in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
return self.compacted_node_pairs
@property
def negative_node_pairs(self):
"""`negative_node_pairs` is a representation of negative graphs used for
evaluating or computing loss in link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
"""
# Build negative graph.
if (
self.compacted_negative_srcs is not None
and self.compacted_negative_dsts is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_node_pairs = (
self.compacted_negative_srcs,
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_node_pairs = {
etype: (
neg_src,
self.compacted_negative_dsts[etype],
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif (
self.compacted_negative_srcs is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
negative_node_pairs = (
self.compacted_negative_srcs,
self.compacted_node_pairs[1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
# For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_srcs.values())[
0
].size(1)
negative_node_pairs = {
etype: (
neg_src,
self.compacted_node_pairs[etype][1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif (
self.compacted_negative_dsts is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
negative_node_pairs = (
self.compacted_node_pairs[0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_dsts.values())[
0
].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
neg_dst,
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
else:
negative_node_pairs = None
return negative_node_pairs
@property
def node_pairs_with_labels(self):
"""Get a node pair tensor and a label tensor from MiniBatch. They are
used for evaluating or computing loss. For homogeneous graph, it will
return `(node_pairs, labels)` as result; for heterogeneous graph, the
`node_pairs` and `labels` will both be a dict with etype as the key.
- If it's a link prediction task, `node_pairs` will contain both
negative and positive node pairs and `labels` will consist of 0 and 1,
indicating whether the corresponding node pair is negative or positive.
- If it's an edge classification task, this function will directly
return `compacted_node_pairs` for each etype and the corresponding
`labels`.
- Otherwise it will return None.
"""
if self.labels is None:
# Link prediction.
positive_node_pairs = self.positive_node_pairs
negative_node_pairs = self.negative_node_pairs
if positive_node_pairs is None or negative_node_pairs is None:
return None
if isinstance(positive_node_pairs, Dict):
# Heterogeneous graph.
node_pairs_by_etype = {}
labels_by_etype = {}
for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype]
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels_by_etype[etype] = torch.cat(
[pos_label, neg_label], dim=0
)
return (node_pairs_by_etype, labels_by_etype)
else:
# Homogeneous graph.
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels = torch.cat([pos_label, neg_label], dim=0)
return (node_pairs, labels.float())
elif self.compacted_node_pairs is not None:
# Edge classification.
return (self.compacted_node_pairs, self.labels)
else:
return None
def to_pyg_data(self):
"""Construct a PyG Data from `MiniBatch`. This function only supports
node classification task on a homogeneous graph and the number of
......@@ -527,17 +308,7 @@ class MiniBatch:
), "`to_pyg_data` only supports single feature homogeneous graph."
node_features = next(iter(self.node_features.values()))
if self.seed_nodes is not None:
if isinstance(self.seed_nodes, Dict):
batch_size = len(next(iter(self.seed_nodes.values())))
else:
batch_size = len(self.seed_nodes)
elif self.node_pairs is not None:
if isinstance(self.node_pairs, Dict):
batch_size = len(next(iter(self.node_pairs.values()))[0])
else:
batch_size = len(self.node_pairs[0])
elif self.seeds is not None:
if self.seeds is not None:
if isinstance(self.seeds, Dict):
batch_size = len(next(iter(self.seeds.values())))
else:
......
......@@ -11,7 +11,7 @@ def test_integration_link_prediction():
indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])
matrix_a = dglsp.from_csc(indptr, indices)
node_pairs = torch.t(torch.stack(matrix_a.coo()))
seeds = torch.t(torch.stack(matrix_a.coo()))
node_feature_data = torch.tensor(
[
[0.9634, 0.2294],
......@@ -37,7 +37,7 @@ def test_integration_link_prediction():
]
)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_set = gb.ItemSet(seeds, names="seeds")
graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data)
......@@ -72,7 +72,6 @@ def test_integration_link_prediction():
[3, 3],
[3, 3],
[3, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 2, 2, 2, 3], dtype=torch.int32),
indices=tensor([0, 5, 4], dtype=torch.int32),
),
......@@ -87,18 +86,12 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2, 0, 4]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([5, 1, 3, 2, 0, 4]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
......@@ -116,9 +109,6 @@ def test_integration_link_prediction():
[2, 2],
[2, 2],
[2, 5]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2)],
)"""
......@@ -136,7 +126,6 @@ def test_integration_link_prediction():
[4, 3],
[0, 1],
[0, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([4, 0], dtype=torch.int32),
),
......@@ -151,18 +140,12 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 1, 5, 2]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865],
[0.5160, 0.2486],
[0.2109, 0.1089]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.]),
input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=tensor([0, 1, 2, 3, 0, 0, 1, 1, 2, 2, 3, 3]),
......@@ -180,9 +163,6 @@ def test_integration_link_prediction():
[1, 0],
[2, 3],
[2, 4]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=6, num_edges=3)],
)"""
......@@ -194,7 +174,6 @@ def test_integration_link_prediction():
[5, 4],
[4, 0],
[4, 1]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2], dtype=torch.int32),
indices=tensor([1, 0], dtype=torch.int32),
),
......@@ -209,16 +188,10 @@ def test_integration_link_prediction():
original_edge_ids=None,
original_column_node_ids=tensor([5, 4, 0, 1]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.6172, 0.7865]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=tensor([1., 1., 0., 0., 0., 0.]),
input_nodes=tensor([5, 4, 0, 1]),
indexes=tensor([0, 1, 0, 0, 1, 1]),
......@@ -230,9 +203,6 @@ def test_integration_link_prediction():
[0, 1],
[1, 2],
[1, 3]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=2)],
)"""
......@@ -248,8 +218,7 @@ def test_integration_node_classification():
indptr = torch.tensor([0, 0, 1, 3, 6, 8, 10])
indices = torch.tensor([5, 3, 3, 3, 3, 4, 4, 0, 5, 4])
matrix_a = dglsp.from_csc(indptr, indices)
node_pairs = torch.t(torch.stack(matrix_a.coo()))
seeds = torch.tensor([5, 1, 2, 4, 3, 0])
node_feature_data = torch.tensor(
[
[0.9634, 0.2294],
......@@ -275,7 +244,7 @@ def test_integration_node_classification():
]
)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_set = gb.ItemSet(seeds, names="seeds")
graph = gb.fused_csc_sampling_graph(indptr, indices)
node_feature = gb.TorchBasedFeature(node_feature_data)
......@@ -285,7 +254,7 @@ def test_integration_node_classification():
("edge", None, "feat"): edge_feature,
}
feature_store = gb.BasicFeatureStore(features)
datapipe = gb.ItemSampler(item_set, batch_size=4)
datapipe = gb.ItemSampler(item_set, batch_size=2)
fanouts = torch.LongTensor([1])
datapipe = datapipe.sample_neighbor(graph, [fanouts, fanouts], replace=True)
datapipe = datapipe.fetch_feature(
......@@ -296,138 +265,92 @@ def test_integration_node_classification():
)
expected = [
str(
"""MiniBatch(seeds=tensor([[5, 1],
[3, 2],
[3, 2],
[3, 3]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([4, 0, 2, 2], dtype=torch.int32),
"""MiniBatch(seeds=tensor([5, 1]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([2, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2, 4]),
original_row_node_ids=tensor([5, 1, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2]),
original_column_node_ids=tensor([5, 1]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4], dtype=torch.int32),
indices=tensor([0, 0, 2, 2], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 1, 3, 2]),
original_row_node_ids=tensor([5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 1, 3, 2]),
original_column_node_ids=tensor([5, 1]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.6172, 0.7865],
[0.8672, 0.2276],
[0.2109, 0.1089],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 1, 3, 2, 4]),
input_nodes=tensor([5, 1, 4]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=tensor([[0, 1],
[2, 3],
[2, 3],
[2, 2]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
compacted_seeds=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=tensor([[3, 3],
[4, 3],
[4, 4],
[0, 4]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
"""MiniBatch(seeds=tensor([2, 4]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 3], dtype=torch.int32),
indices=tensor([2, 1, 2], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0]),
original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
original_column_node_ids=tensor([2, 4, 3, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([2, 3], dtype=torch.int32),
),
original_row_node_ids=tensor([3, 4, 0]),
original_row_node_ids=tensor([2, 4, 3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
original_column_node_ids=tensor([2, 4]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.8672, 0.2276],
node_features={'feat': tensor([[0.2109, 0.1089],
[0.5503, 0.8223],
[0.8672, 0.2276],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([3, 4, 0]),
input_nodes=tensor([2, 4, 3, 0]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=tensor([[0, 0],
[1, 0],
[1, 1],
[2, 1]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
compacted_seeds=None,
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=3),
Block(num_src_nodes=4, num_dst_nodes=2, num_edges=2)],
)"""
),
str(
"""MiniBatch(seeds=tensor([[5, 5],
[4, 5]]),
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([0, 2], dtype=torch.int32),
"""MiniBatch(seeds=tensor([3, 0]),
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
indices=tensor([0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4, 0]),
original_row_node_ids=tensor([3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
original_column_node_ids=tensor([3, 0]),
),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2], dtype=torch.int32),
indices=tensor([1, 1], dtype=torch.int32),
SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1], dtype=torch.int32),
indices=tensor([0], dtype=torch.int32),
),
original_row_node_ids=tensor([5, 4]),
original_row_node_ids=tensor([3, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
original_column_node_ids=tensor([3, 0]),
)],
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None,
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
node_features={'feat': tensor([[0.8672, 0.2276],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 4, 0]),
input_nodes=tensor([3, 0]),
indexes=None,
edge_features=[{},
{}],
compacted_seeds=tensor([[0, 0],
[1, 0]]),
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
compacted_seeds=None,
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
)"""
),
]
for step, data in enumerate(dataloader):
assert expected[step] == str(data), print(data)
assert expected[step] == str(data), print(step, data)
import os
import re
import unittest
from collections import defaultdict
from sys import platform
import backend as F
......@@ -47,7 +48,7 @@ def test_ItemSampler_minibatcher():
# Default minibatcher is used if not specified.
# `MiniBatch` is returned if expected names are specified.
item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
item_set = gb.ItemSet(torch.arange(0, 10), names="seeds")
item_sampler = gb.ItemSampler(item_set, batch_size=4)
minibatch = next(iter(item_sampler))
assert isinstance(minibatch, gb.MiniBatch)
......@@ -78,7 +79,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
return iter(torch.arange(0, num_ids))
seed_nodes = gb.ItemSet(InvalidLength())
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_set = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -106,7 +107,7 @@ def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
def test_ItemSet_integer(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
item_set = gb.ItemSet(num_ids, names="seed_nodes")
item_set = gb.ItemSet(num_ids, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -135,7 +136,7 @@ def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 103
seed_nodes = torch.arange(0, num_ids)
item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
item_set = gb.ItemSet(seed_nodes, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -165,7 +166,7 @@ def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
num_ids = 103
seed_nodes = torch.arange(0, num_ids)
labels = torch.arange(0, num_ids)
item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -249,7 +250,7 @@ def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
item_set = gb.ItemSet(node_pairs, names="node_pairs")
item_set = gb.ItemSet(node_pairs, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -289,7 +290,7 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
num_ids = 103
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
labels = node_pairs[:, 0]
item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
......@@ -333,16 +334,26 @@ def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
def test_ItemSet_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
# Node pairs and negative destinations.
num_ids = 103
num_negs = 2
node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
neg_dsts = torch.arange(
2 * num_ids, 2 * num_ids + num_ids * num_negs
).reshape(-1, num_negs)
neg_srcs = node_pairs[:, 0].repeat_interleave(num_negs)
neg_dsts = torch.arange(2 * num_ids, 2 * num_ids + num_ids * num_negs)
neg_node_pairs = torch.cat((neg_srcs, neg_dsts)).reshape(2, -1).T
labels = torch.empty(num_ids * 3)
labels[:num_ids] = 1
labels[num_ids:] = 0
indexes = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
node_pairs = torch.cat((node_pairs, neg_node_pairs))
item_set = gb.ItemSet(
(node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
(node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
......@@ -350,6 +361,8 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src_ids = []
dst_ids = []
negs_ids = []
final_labels = []
final_indexes = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
......@@ -358,46 +371,43 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src, dst = minibatch.seeds.T
negs_src = src[~minibatch.labels.to(bool)]
negs_dst = dst[~minibatch.labels.to(bool)]
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
is_last = (i + 1) * batch_size >= num_ids * 3
if not is_last or num_ids * 3 % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
expected_batch_size = num_ids * 3 % batch_size
else:
assert False
assert len(src) == expected_batch_size * 3
assert len(dst) == expected_batch_size * 3
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert negs_src.dim() == 1
assert negs_dst.dim() == 1
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
expected_indexes = torch.arange(expected_batch_size)
expected_indexes = torch.cat(
(expected_indexes, expected_indexes.repeat_interleave(2))
)
assert torch.equal(minibatch.indexes, expected_indexes)
# Verify node pairs and negative destinations.
assert torch.equal(
src[minibatch.labels.to(bool)] + 1, dst[minibatch.labels.to(bool)]
)
assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
# Archive batch.
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs_dst)
final_labels.append(minibatch.labels)
final_indexes.append(minibatch.indexes)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
negs_ids = torch.cat(negs_ids)
final_labels = torch.cat(final_labels)
final_indexes = torch.cat(final_indexes)
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
assert torch.all(final_labels[:-1] >= final_labels[1:]) is not shuffle
if not drop_last:
assert final_labels.sum() == num_ids
assert torch.equal(final_indexes, indexes) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds(batch_size, shuffle, drop_last):
def test_ItemSet_hyperlink(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
......@@ -477,7 +487,7 @@ def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
def test_append_with_other_datapipes():
num_ids = 100
batch_size = 4
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
data_pipe = gb.ItemSampler(item_set, batch_size)
# torchdata.datapipes.iter.Enumerator
data_pipe = data_pipe.enumerate()
......@@ -500,8 +510,8 @@ def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
num_ids = 205
ids = {
"user": gb.ItemSet(IterableOnly(0, 99), names="seed_nodes"),
"item": gb.ItemSet(IterableOnly(99, num_ids), names="seed_nodes"),
"user": gb.ItemSet(IterableOnly(0, 99), names="seeds"),
"item": gb.ItemSet(IterableOnly(99, num_ids), names="seeds"),
}
chained_ids = []
for key, value in ids.items():
......@@ -539,8 +549,8 @@ def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
# Node IDs.
num_ids = 205
ids = {
"user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
"user": gb.ItemSet(torch.arange(0, 99), names="seeds"),
"item": gb.ItemSet(torch.arange(99, num_ids), names="seeds"),
}
chained_ids = []
for key, value in ids.items():
......@@ -580,11 +590,11 @@ def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
ids = {
"user": gb.ItemSet(
(torch.arange(0, 99), torch.arange(0, 99)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
"item": gb.ItemSet(
(torch.arange(99, num_ids), torch.arange(99, num_ids)),
names=("seed_nodes", "labels"),
names=("seeds", "labels"),
),
}
chained_ids = []
......@@ -638,8 +648,8 @@ def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
node_pairs_dict = {
"user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"),
"user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"),
"user:like:item": gb.ItemSet(node_pairs_like, names="seeds"),
"user:follow:user": gb.ItemSet(node_pairs_follow, names="seeds"),
}
item_set = gb.ItemSetDict(node_pairs_dict)
item_sampler = gb.ItemSampler(
......@@ -691,11 +701,11 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
node_pairs_dict = {
"user:like:item": gb.ItemSet(
(node_pairs_like, node_pairs_like[:, 0]),
names=("node_pairs", "labels"),
names=("seeds", "labels"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_follow, node_pairs_follow[:, 0]),
names=("node_pairs", "labels"),
names=("seeds", "labels"),
),
}
item_set = gb.ItemSetDict(node_pairs_dict)
......@@ -709,7 +719,6 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.negative_dsts is None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -749,27 +758,64 @@ def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
def test_ItemSetDict_node_pairs_labels_indexes(batch_size, shuffle, drop_last):
# Head, tail and negative tails.
num_ids = 103
total_ids = 2 * num_ids
total_ids = 6 * num_ids
num_negs = 2
node_paris_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
neg_dsts_like = torch.arange(
num_ids * 4, num_ids * 4 + num_ids * num_negs
).reshape(-1, num_negs)
neg_dsts_like = torch.arange(num_ids * 4, num_ids * 4 + num_ids * num_negs)
neg_node_pairs_like = (
torch.cat(
(node_pairs_like[:, 0].repeat_interleave(num_negs), neg_dsts_like)
)
.view(2, -1)
.T
)
all_node_pairs_like = torch.cat((node_pairs_like, neg_node_pairs_like))
labels_like = torch.empty(num_ids * 3)
labels_like[:num_ids] = 1
labels_like[num_ids:] = 0
indexes_like = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
neg_dsts_follow = torch.arange(
num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
).reshape(-1, num_negs)
)
neg_node_pairs_follow = (
torch.cat(
(
node_pairs_follow[:, 0].repeat_interleave(num_negs),
neg_dsts_follow,
)
)
.view(2, -1)
.T
)
all_node_pairs_follow = torch.cat(
(node_pairs_follow, neg_node_pairs_follow)
)
labels_follow = torch.empty(num_ids * 3)
labels_follow[:num_ids] = 1
labels_follow[num_ids:] = 0
indexes_follow = torch.cat(
(
torch.arange(0, num_ids),
torch.arange(0, num_ids).repeat_interleave(num_negs),
)
)
data_dict = {
"user:like:item": gb.ItemSet(
(node_paris_like, neg_dsts_like),
names=("node_pairs", "negative_dsts"),
(all_node_pairs_like, labels_like, indexes_like),
names=("seeds", "labels", "indexes"),
),
"user:follow:user": gb.ItemSet(
(node_pairs_follow, neg_dsts_follow),
names=("node_pairs", "negative_dsts"),
(all_node_pairs_follow, labels_follow, indexes_follow),
names=("seeds", "labels", "indexes"),
),
}
item_set = gb.ItemSetDict(data_dict)
......@@ -779,11 +825,13 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
src_ids = []
dst_ids = []
negs_ids = []
final_labels = defaultdict(list)
final_indexes = defaultdict(list)
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.negative_dsts is None
assert minibatch.indexes is not None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......@@ -800,24 +848,23 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert isinstance(seeds, torch.Tensor)
src_etype = seeds[:, 0]
dst_etype = seeds[:, 1]
src.append(src_etype[minibatch.labels[etype].to(bool)])
dst.append(dst_etype[minibatch.labels[etype].to(bool)])
src.append(src_etype)
dst.append(dst_etype)
negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
final_labels[etype].append(minibatch.labels[etype])
final_indexes[etype].append(minibatch.indexes[etype])
src = torch.cat(src)
dst = torch.cat(dst)
negs_src = torch.cat(negs_src)
negs_dst = torch.cat(negs_dst)
assert len(src) == expected_batch_size
assert len(dst) == expected_batch_size
assert len(negs_src) == expected_batch_size * 2
assert len(negs_dst) == expected_batch_size * 2
src_ids.append(src)
dst_ids.append(dst)
negs_ids.append(negs_dst)
assert negs_src.dim() == 1
assert negs_dst.dim() == 1
assert torch.equal(src + 1, dst)
assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
src_ids = torch.cat(src_ids)
dst_ids = torch.cat(dst_ids)
......@@ -825,12 +872,24 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
assert torch.all(negs_ids <= negs_ids) is not shuffle
for etype in data_dict.keys():
final_labels_etype = torch.cat(final_labels[etype])
final_indexes_etype = torch.cat(final_indexes[etype])
assert (
torch.all(final_labels_etype[:-1] >= final_labels_etype[1:])
is not shuffle
)
if not drop_last:
assert final_labels_etype.sum() == num_ids
assert (
torch.equal(final_indexes_etype, indexes_follow) is not shuffle
)
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
def test_ItemSetDict_hyperlink(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
total_pairs = 2 * num_ids
......@@ -876,7 +935,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last):
def test_ItemSetDict_hyperlink_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
total_ids = 2 * num_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment