Unverified Commit cad7caeb authored by Ramon Zhou, committed by GitHub

[Graphbolt] Rewrite `to_dgl` to multiple `get` functions (#6735)

parent 70fdb69f
@@ -98,11 +98,8 @@ def evaluate(model, dataset, device):
logits = []
labels = []
for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# Unpack MiniBatch.
compacted_pairs, label = to_binary_link_dgl_computing_pack(data)
# Get node pairs with labels for loss calculation.
compacted_pairs, label = data.node_pairs_with_labels
# The features of sampled nodes.
x = data.node_features["feat"]
@@ -140,11 +137,8 @@ def train(model, dataset, device):
# mini-batches.
########################################################################
for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# Unpack MiniBatch.
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
# Get node pairs with labels for loss calculation.
compacted_pairs, labels = data.node_pairs_with_labels
# The features of sampled nodes.
x = data.node_features["feat"]
......
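A minimal sketch of the loop shape after this change (hedged: `model` and `dataloader` stand in for the example's own objects; this is not code from the diff):

for step, data in enumerate(dataloader):
    # One property replaces the former to_dgl() call plus the
    # to_binary_link_dgl_computing_pack helper.
    compacted_pairs, labels = data.node_pairs_with_labels
    # The features of sampled nodes.
    x = data.node_features["feat"]
    # blocks is also a property now, so no conversion step is needed.
    y = model(data.blocks, x)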
@@ -57,7 +57,6 @@ def evaluate(model, dataset, itemset, device):
dataloader = create_dataloader(dataset, itemset, device)
for step, data in enumerate(dataloader):
data = data.to_dgl()
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
@@ -84,9 +83,6 @@ def train(model, dataset, device):
# mini-batches.
########################################################################
for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# The features of sampled nodes.
x = data.node_features["feat"]
......
@@ -363,9 +363,10 @@ class MiniBatch:
"""Set edge features."""
self.edge_features = edge_features
def _to_dgl_blocks(self):
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing
a graphical structure and ID mappings.
@property
def blocks(self):
"""Extracts DGL blocks from `MiniBatch` to construct a graphical
structure and ID mappings.
"""
if not self.sampled_subgraphs:
return None
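A hedged inspection sketch (assuming `data` is a `MiniBatch` yielded by a dataloader): each block keeps the original edge IDs, as the `block.edata[dgl.EID]` assignment in the next hunk shows.

import dgl

blocks = data.blocks  # list of DGL Block objects, or None
for block in blocks:
    # Set when original edge IDs were recorded during sampling;
    # otherwise this entry may be absent.
    if dgl.EID in block.edata:
        edge_ids = block.edata[dgl.EID]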
@@ -459,29 +460,28 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks
def to_dgl(self):
"""Converting a `MiniBatch` into a DGL MiniBatch that contains
everything necessary for computation."""
@property
def positive_node_pairs(self):
"""`positive_node_pairs` is a representation of positive graphs used for
evaluating or computing loss in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge types,
and the values should be tuples of tensors representing node pairs of the
given type.
"""
minibatch = DGLMiniBatch(
blocks=self._to_dgl_blocks(),
node_features=self.node_features,
edge_features=self.edge_features,
labels=self.labels,
)
# Need input nodes to fetch feature.
if self.node_features is None:
minibatch.input_nodes = self.input_nodes
# Need output nodes to fetch label.
if self.labels is None:
minibatch.output_nodes = self.seed_nodes
assert (
minibatch.blocks is not None
), "Sampled subgraphs for computation are missing."
return self.compacted_node_pairs
# For link prediction tasks.
if self.compacted_node_pairs is not None:
minibatch.positive_node_pairs = self.compacted_node_pairs
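A short consumption sketch covering both shapes the docstring describes (illustrative only):

pairs = minibatch.positive_node_pairs  # i.e. compacted_node_pairs
if isinstance(pairs, tuple):
    # Homogeneous graph: one (src, dst) tensor pair.
    src, dst = pairs
elif pairs is not None:
    # Heterogeneous graph: dict keyed by edge type.
    for etype, (src, dst) in pairs.items():
        pass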
@property
def negative_node_pairs(self):
"""`negative_node_pairs` is a representation of negative graphs used for
evaluating or computing loss in link prediction tasks.
- If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `negative_node_pairs` is a dictionary: The keys should be edge types,
and the values should be tuples of tensors representing node pairs of the
given type.
"""
# Build negative graph.
if (
self.compacted_negative_srcs is not None
@@ -489,24 +489,27 @@
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
minibatch.negative_node_pairs = (
negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_negative_dsts.view(-1),
)
# For heterogeneous graph.
else:
minibatch.negative_node_pairs = {
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_negative_dsts[etype].view(-1),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif self.compacted_negative_srcs is not None:
elif (
self.compacted_negative_srcs is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
minibatch.negative_node_pairs = (
negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_node_pairs[1].repeat_interleave(
negative_ratio
@@ -514,23 +517,26 @@
)
# For heterogeneous graph.
else:
negative_ratio = list(
self.compacted_negative_srcs.values()
)[0].size(1)
minibatch.negative_node_pairs = {
negative_ratio = list(self.compacted_negative_srcs.values())[
0
].size(1)
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_node_pairs[etype][
1
].repeat_interleave(negative_ratio),
self.compacted_node_pairs[etype][1].repeat_interleave(
negative_ratio
),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
elif self.compacted_negative_dsts is not None:
elif (
self.compacted_negative_dsts is not None
and self.compacted_node_pairs is not None
):
# For homogeneous graph.
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
minibatch.negative_node_pairs = (
negative_node_pairs = (
self.compacted_node_pairs[0].repeat_interleave(
negative_ratio
),
@@ -538,19 +544,51 @@
)
# For heterogeneous graph.
else:
negative_ratio = list(
self.compacted_negative_dsts.values()
)[0].size(1)
minibatch.negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][
negative_ratio = list(self.compacted_negative_dsts.values())[
0
].repeat_interleave(negative_ratio),
].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0].repeat_interleave(
negative_ratio
),
neg_dst.view(-1),
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
return minibatch
else:
negative_node_pairs = None
return negative_node_pairs
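A toy walk-through of the flattening above for the homogeneous `compacted_negative_dsts` branch, with illustrative values and `negative_ratio = 2`:

import torch

compacted_node_pairs = (torch.tensor([0, 1]), torch.tensor([2, 3]))
compacted_negative_dsts = torch.tensor([[4, 5], [6, 7]])  # shape (2, 2)

negative_ratio = compacted_negative_dsts.size(1)  # 2
neg_src = compacted_node_pairs[0].repeat_interleave(negative_ratio)
neg_dst = compacted_negative_dsts.view(-1)
# neg_src: tensor([0, 0, 1, 1]); neg_dst: tensor([4, 5, 6, 7])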
@property
def node_pairs_with_labels(self):
"""Get a node pair tensor and a label tensor from MiniBatch. They are
used for evaluating or computing loss. It will return
`(node_pairs, labels)` as result.
- If it's a link prediction task, `node_pairs` will contain both
negative and positive node pairs and `labels` will consist of 0 and 1,
indicating whether the corresponding node pair is negative or positive.
- If it's an edge classification task, this function will directly
return `compacted_node_pairs` and corresponding `labels`.
- Otherwise it will return None.
"""
if self.labels is None:
positive_node_pairs = self.positive_node_pairs
negative_node_pairs = self.negative_node_pairs
if positive_node_pairs is None or negative_node_pairs is None:
return None
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels = torch.cat([pos_label, neg_label], dim=0)
return (node_pairs, labels.float())
else:
return (self.compacted_node_pairs, self.labels)
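A toy check of the link-prediction branch with two positive and two negative pairs (values illustrative):

import torch

pos_src, pos_dst = torch.tensor([0, 1]), torch.tensor([2, 3])
neg_src, neg_dst = torch.tensor([0, 1]), torch.tensor([4, 4])

node_pairs = (torch.cat((pos_src, neg_src)), torch.cat((pos_dst, neg_dst)))
labels = torch.cat((torch.ones_like(pos_src), torch.zeros_like(neg_src)))
# node_pairs: (tensor([0, 1, 0, 1]), tensor([2, 3, 4, 4]))
# labels.float(): tensor([1., 1., 0., 0.])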
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection."""
@@ -561,6 +599,7 @@ class MiniBatch:
for attr in dir(self):
# Only copy member variables.
if not callable(getattr(self, attr)) and not attr.startswith("__"):
try:
setattr(
self,
attr,
@@ -568,6 +607,8 @@ class MiniBatch:
getattr(self, attr), lambda x: _to(x, device)
),
)
except AttributeError:
continue
return self
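This reflective copy backs the `copy_to` datapipe exercised in the tests below; a brief sketch of both forms:

import torch

# Move a single minibatch directly.
data = data.to(torch.device("cuda"))

# Or move every batch in the pipeline (functional form, as in the tests).
datapipe = datapipe.copy_to("cuda")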
......
@@ -8,7 +8,6 @@ from .minibatch import DGLMiniBatch, MiniBatch
__all__ = [
"MiniBatchTransformer",
"DGLMiniBatchConverter",
]
@@ -41,22 +40,3 @@ class MiniBatchTransformer(Mapper):
minibatch, (MiniBatch, DGLMiniBatch)
), "The transformer output should be an instance of MiniBatch"
return minibatch
@functional_datapipe("to_dgl")
class DGLMiniBatchConverter(Mapper):
"""Convert a graphbolt mini-batch to a dgl mini-batch.
Functional name: :obj:`to_dgl`.
Parameters
----------
datapipe : DataPipe
The datapipe.
"""
def __init__(
self,
datapipe,
):
super().__init__(datapipe, MiniBatch.to_dgl)
@@ -2086,7 +2086,6 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id):
datapipe = datapipe.fetch_feature(
dataset.feature, node_feature_keys=["feat"]
)
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(datapipe)
for _ in dataloader:
pass
@@ -2158,7 +2157,6 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id):
datapipe = datapipe.fetch_feature(
dataset.feature, node_feature_keys={"user": ["feat"]}
)
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(datapipe)
for _ in dataloader:
pass
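With the converter removed, the pipeline built in these tests runs straight from feature fetching into the loader; a condensed sketch using the same graphbolt calls (`itemset`, `graph`, and `fanouts` as set up earlier in each test):

datapipe = gb.ItemSampler(itemset, batch_size=4)
datapipe = gb.NeighborSampler(datapipe, graph, fanouts)
datapipe = datapipe.fetch_feature(dataset.feature, node_feature_keys=["feat"])
dataloader = gb.DataLoader(datapipe)  # no .to_dgl() step anymore
for _ in dataloader:
    pass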
......
@@ -67,9 +67,6 @@ def test_CopyToWithMiniBatches():
# Invoke CopyTo via functional form.
test_data_device(datapipe.copy_to("cuda"))
# Test for DGLMiniBatch.
datapipe = gb.DGLMiniBatchConverter(datapipe)
# Invoke CopyTo via class constructor.
test_data_device(gb.CopyTo(datapipe, "cuda"))
......
@@ -14,10 +14,7 @@ class MiniBatchType(Enum):
DGLMiniBatch = 2
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_invoke(minibatch_type):
def test_FeatureFetcher_invoke():
# Prepare graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor(
@@ -40,8 +37,6 @@ def test_FeatureFetcher_invoke(minibatch_type):
# Invoke FeatureFetcher via class constructor.
datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
datapipe = datapipe.to_dgl()
datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
assert len(list(datapipe)) == 5
@@ -53,10 +48,7 @@
assert len(list(datapipe)) == 5
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_homo(minibatch_type):
def test_FeatureFetcher_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
@@ -76,17 +68,12 @@ def test_FeatureFetcher_homo(minibatch_type):
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
sampler_dp = sampler_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_homo(minibatch_type):
def test_FeatureFetcher_with_edges_homo():
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
@@ -121,8 +108,6 @@ def test_FeatureFetcher_with_edges_homo(minibatch_type):
itemset = gb.ItemSet(torch.arange(10))
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
if minibatch_type == MiniBatchType.DGLMiniBatch:
converter_dp = converter_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
@@ -155,10 +140,7 @@ def get_hetero_graph():
)
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_hetero(minibatch_type):
def test_FeatureFetcher_hetero():
graph = get_hetero_graph()
a = torch.tensor([[random.randint(0, 10)] for _ in range(2)])
b = torch.tensor([[random.randint(0, 10)] for _ in range(3)])
@@ -179,8 +161,6 @@ def test_FeatureFetcher_hetero(minibatch_type):
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
sampler_dp = sampler_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
)
@@ -188,10 +168,7 @@ def test_FeatureFetcher_hetero(minibatch_type):
assert len(list(fetcher_dp)) == 3
@pytest.mark.parametrize(
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_hetero(minibatch_type):
def test_FeatureFetcher_with_edges_hetero():
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])
@@ -243,8 +220,6 @@ def test_FeatureFetcher_with_edges_hetero(minibatch_type):
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
if minibatch_type == MiniBatchType.DGLMiniBatch:
converter_dp = converter_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
)
......
@@ -55,62 +55,149 @@ def test_integration_link_prediction():
datapipe = datapipe.fetch_feature(
feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
)
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(
datapipe,
)
expected = [
str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 1]),
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
node_pairs=(tensor([5, 4]), tensor([0, 5])),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2, 0]),
node_pairs=(tensor([5]), tensor([0])),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
output_nodes=None,
node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4])),
tensor([1., 1., 1., 1., 0., 0., 0., 0.])),
node_pairs=(tensor([5, 3, 3, 3]),
tensor([1, 2, 2, 3])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865],
[0.2109, 0.1089],
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=tensor([[5],
[3],
[3],
[3]]),
negative_node_pairs=(tensor([0, 1, 1, 1]),
tensor([4, 4, 1, 4])),
negative_dsts=tensor([[0],
[0],
[3],
[0]]),
labels=None,
input_nodes=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_negative_srcs=tensor([[0],
[1],
[1],
[1]]),
compacted_negative_dsts=tensor([[4],
[4],
[1],
[4]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=5, num_edges=1)],
)"""
),
str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 2]),
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=(tensor([1, 3]), tensor([3, 4])),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0, 5, 1]),
node_pairs=(tensor([1, 3]), tensor([3, 4])),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
output_nodes=None,
node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 1, 1, 2]), tensor([0, 0, 1, 1, 1, 1, 3, 4])),
tensor([1., 1., 1., 1., 0., 0., 0., 0.])),
node_pairs=(tensor([3, 4, 4, 0]),
tensor([3, 3, 4, 4])),
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.5160, 0.2486],
[0.6172, 0.7865]])},
negative_srcs=tensor([[3],
[4],
[4],
[0]]),
negative_node_pairs=(tensor([0, 1, 1, 2]),
tensor([1, 1, 3, 4])),
negative_dsts=tensor([[4],
[4],
[5],
[1]]),
labels=None,
input_nodes=None,
input_nodes=tensor([3, 4, 0, 5, 1]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_negative_srcs=tensor([[0],
[1],
[1],
[2]]),
compacted_negative_dsts=tensor([[1],
[1],
[3],
[4]]),
blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2),
Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)],
)"""
),
str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1]),
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1]), tensor([1])),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1]), tensor([1])),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
output_nodes=None,
node_pairs_with_labels=((tensor([0, 1, 0, 1]), tensor([0, 0, 0, 0])),
tensor([1., 1., 0., 0.])),
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223]])},
negative_srcs=tensor([[5],
[4]]),
negative_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
negative_dsts=tensor([[5],
[5]]),
labels=None,
input_nodes=None,
input_nodes=tensor([5, 4]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_negative_srcs=tensor([[0],
[1]]),
compacted_negative_dsts=tensor([[0],
[0]]),
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
)"""
@@ -169,57 +256,113 @@ def test_integration_node_classification():
datapipe = datapipe.fetch_feature(
feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
)
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(
datapipe,
)
expected = [
str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 1]),
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=(tensor([4, 1, 0, 1]), tensor([0, 1, 2, 3])),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 3, 1, 2]),
node_pairs=(tensor([0, 1, 0, 1]), tensor([0, 1, 2, 3])),
)],
positive_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
output_nodes=None,
node_pairs_with_labels=None,
node_pairs=(tensor([5, 3, 3, 3]),
tensor([1, 2, 2, 3])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865],
[0.2109, 0.1089],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=None,
input_nodes=tensor([5, 3, 1, 2, 4]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
)"""
),
str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 2]),
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
node_pairs=(tensor([0, 2]), tensor([0, 1])),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([3, 4, 0]),
node_pairs=(tensor([0, 2]), tensor([0, 1])),
)],
positive_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
output_nodes=None,
node_pairs_with_labels=None,
node_pairs=(tensor([3, 4, 4, 0]),
tensor([3, 3, 4, 4])),
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=None,
input_nodes=tensor([3, 4, 0]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
)"""
),
str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1]),
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([0, 2]), tensor([0, 1])),
),
FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
original_edge_ids=None,
original_column_node_ids=tensor([5, 4]),
node_pairs=(tensor([1, 1]), tensor([0, 1])),
)],
positive_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
output_nodes=None,
node_pairs_with_labels=None,
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=None,
input_nodes=tensor([5, 4, 0]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
)"""
......
import dgl.graphbolt as gb
import torch
from . import gb_test_utils
def test_dgl_minibatch_converter():
N = 32
B = 4
itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
features = {}
keys = [("node", None, "a"), ("node", None, "b")]
features[keys[0]] = gb.TorchBasedFeature(torch.randn(200, 4))
features[keys[1]] = gb.TorchBasedFeature(torch.randn(200, 4))
feature_store = gb.BasicFeatureStore(features)
item_sampler = gb.ItemSampler(itemset, batch_size=B)
subgraph_sampler = gb.NeighborSampler(
item_sampler,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
feature_fetcher = gb.FeatureFetcher(
subgraph_sampler,
feature_store,
["a"],
)
dgl_converter = gb.DGLMiniBatchConverter(feature_fetcher)
dataloader = gb.DataLoader(dgl_converter)
assert len(list(dataloader)) == N // B
minibatch = next(iter(dataloader))
assert isinstance(minibatch, gb.DGLMiniBatch)