"projects/vscode:/vscode.git/clone" did not exist on "4ecc9ea89d55b51c6ad66996ff0edd013ded0815"
Unverified Commit 86f739b3 authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Graphbolt] Change data_block to mini_batch (#6256)

parent f281959a
......@@ -6,7 +6,7 @@ import torch
from .._ffi import libinfo
from .base import *
from .data_block import *
from .minibatch import *
from .data_format import *
from .dataloader import *
from .dataset import *
......
......@@ -53,9 +53,9 @@ class NeighborSampler(SubgraphSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def to_link_block(data):
... block = gb.LinkPredictionBlock(node_pair=data)
... return block
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
......@@ -67,9 +67,10 @@ class NeighborSampler(SubgraphSampler):
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> data_block_converter = Mapper(item_sampler, to_link_block)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...data_block_converter, 2, data_format, graph)
...minibatch_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler(
......@@ -164,9 +165,9 @@ class LayerNeighborSampler(NeighborSampler):
-------
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def to_link_block(data):
... block = gb.LinkPredictionBlock(node_pair=data)
... return block
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
......@@ -178,9 +179,10 @@ class LayerNeighborSampler(NeighborSampler):
>>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1,
...)
>>> data_block_converter = Mapper(item_sampler, to_link_block)
>>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler(
...data_block_converter, 2, data_format, graph)
...minibatch_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])]
>>> subgraph_sampler = gb.LayerNeighborSampler(
......
......@@ -7,11 +7,11 @@ import torch
from .sampled_subgraph import SampledSubgraph
__all__ = ["DataBlock", "NodeClassificationBlock", "LinkPredictionBlock"]
__all__ = ["MiniBatch"]
@dataclass
class DataBlock:
class MiniBatch:
r"""A composite data class for the data structures used in GraphBolt. It is
designed to facilitate the exchange of data among different components
involved in processing data. The purpose of this class is to unify the
......@@ -52,12 +52,6 @@ class DataBlock:
value should be corresponding heterogeneous node id.
"""
@dataclass
class NodeClassificationBlock(DataBlock):
r"""A subclass of 'UnifiedDataStruct', specialized for handling node level
tasks."""
seed_node: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of seed nodes used for sampling in the graph.
......@@ -69,17 +63,12 @@ class NodeClassificationBlock(DataBlock):
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous.
- If `label` is a dictionary: The keys should be node type and the
value should be corresponding node labels to given 'seed_node'.
- If `label` is a tensor: It indicates the graph is homogeneous. The value
should be the labels corresponding to the given 'seed_node' or 'node_pair'.
- If `label` is a dictionary: The keys should be node or edge types and the
values should be the labels corresponding to the given 'seed_node' or
'node_pair'.
"""
@dataclass
class LinkPredictionBlock(DataBlock):
r"""A subclass of 'UnifiedDataStruct', specialized for handling edge level
tasks."""
node_pair: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
......@@ -93,15 +82,6 @@ class LinkPredictionBlock(DataBlock):
type.
"""
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with the link prediction task.
- If `label` is a tensor: It indicates a homogeneous graph. The value are
edge labels corresponding to given 'node_pair'.
- If `label` is a dictionary: The keys should be edge type, and the value
should correspond to given 'node_pair'.
"""
negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
......
......@@ -30,49 +30,50 @@ class NegativeSampler(Mapper):
negative_ratio : int
The proportion of negative samples to positive samples.
output_format : LinkPredictionEdgeFormat
Determines the edge format of the output data.
Determines the edge format of the output minibatch.
"""
super().__init__(datapipe, self._sample)
assert negative_ratio > 0, "Negative_ratio should be positive Integer."
self.negative_ratio = negative_ratio
self.output_format = output_format
def _sample(self, data):
def _sample(self, minibatch):
"""
Generate a mix of positive and negative samples.
Parameters
----------
data : LinkPredictionBlock
An instance of 'LinkPredictionBlock' class requires the 'node_pair'
field. This function is responsible for generating negative edges
minibatch : MiniBatch
An instance of the 'MiniBatch' class; its 'node_pair' field is required.
This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
cases where negative edges already exist, this function will
overwrite them.
Returns
-------
LinkPredictionBlock
An instance of 'LinkPredictionBlock' encompasses both positive and
negative samples.
MiniBatch
An instance of 'MiniBatch' that encompasses both positive and negative
samples.
"""
node_pairs = data.node_pair
node_pairs = minibatch.node_pair
assert node_pairs is not None
if isinstance(node_pairs, Mapping):
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
data.label = {}
minibatch.label = {}
else:
data.negative_head, data.negative_tail = {}, {}
minibatch.negative_head, minibatch.negative_tail = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
data, self._sample_with_etype(pos_pairs, etype), etype
minibatch, self._sample_with_etype(pos_pairs, etype), etype
)
if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
data.negative_tail = None
minibatch.negative_tail = None
if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
data.negative_head = None
minibatch.negative_head = None
else:
self._collate(data, self._sample_with_etype(node_pairs))
return data
self._collate(minibatch, self._sample_with_etype(node_pairs))
return minibatch
def _sample_with_etype(self, node_pairs, etype=None):
"""Generate negative pairs for a given etype from positive pairs
......@@ -94,13 +95,13 @@ class NegativeSampler(Mapper):
"""
raise NotImplementedError
def _collate(self, data, neg_pairs, etype=None):
"""Collates positive and negative samples into data.
def _collate(self, minibatch, neg_pairs, etype=None):
"""Collates positive and negative samples into minibatch.
Parameters
----------
data : LinkPredictionBlock
The input data, which contains positive node pairs, will be filled
minibatch : MiniBatch
The input minibatch, which contains positive node pairs, will be filled
with negative information in this function.
neg_pairs : Tuple[Tensor, Tensor]
A tuple of tensors represents source-destination node pairs of
......@@ -110,7 +111,9 @@ class NegativeSampler(Mapper):
Canonical edge type.
"""
pos_src, pos_dst = (
data.node_pair[etype] if etype is not None else data.node_pair
minibatch.node_pair[etype]
if etype is not None
else minibatch.node_pair
)
neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
......@@ -120,11 +123,11 @@ class NegativeSampler(Mapper):
dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label])
if etype is not None:
data.node_pair[etype] = (src, dst)
data.label[etype] = label
minibatch.node_pair[etype] = (src, dst)
minibatch.label[etype] = label
else:
data.node_pair = (src, dst)
data.label = label
minibatch.node_pair = (src, dst)
minibatch.label = label
else:
if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
neg_src = neg_src.view(-1, self.negative_ratio)
......@@ -144,8 +147,8 @@ class NegativeSampler(Mapper):
f"Unsupported output format {self.output_format}."
)
if etype is not None:
data.negative_head[etype] = neg_src
data.negative_tail[etype] = neg_dst
minibatch.negative_head[etype] = neg_src
minibatch.negative_tail[etype] = neg_dst
else:
data.negative_head = neg_src
data.negative_tail = neg_dst
minibatch.negative_head = neg_src
minibatch.negative_tail = neg_dst
......@@ -6,7 +6,6 @@ from typing import Dict
from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple
from .data_block import LinkPredictionBlock, NodeClassificationBlock
from .utils import unique_and_compact
......@@ -28,24 +27,30 @@ class SubgraphSampler(Mapper):
"""
super().__init__(datapipe, self._sample)
def _sample(self, data):
if isinstance(data, LinkPredictionBlock):
def _sample(self, minibatch):
if minibatch.node_pair is not None:
(
seeds,
data.compacted_node_pair,
data.compacted_negative_head,
data.compacted_negative_tail,
) = self._link_prediction_preprocess(data)
elif isinstance(data, NodeClassificationBlock):
seeds = data.seed_node
minibatch.compacted_node_pair,
minibatch.compacted_negative_head,
minibatch.compacted_negative_tail,
) = self._node_pair_preprocess(minibatch)
elif minibatch.seed_node is not None:
seeds = minibatch.seed_node
else:
raise TypeError(f"Unsupported type of data {data}.")
data.input_nodes, data.sampled_subgraphs = self._sample_subgraphs(seeds)
return data
raise ValueError(
f"Invalid minibatch {minibatch}: Either 'node_pair' or \
'seed_node' should have a value."
)
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self._sample_subgraphs(seeds)
return minibatch
def _link_prediction_preprocess(self, data):
node_pair = data.node_pair
neg_src, neg_dst = data.negative_head, data.negative_tail
def _node_pair_preprocess(self, minibatch):
node_pair = minibatch.node_pair
neg_src, neg_dst = minibatch.negative_head, minibatch.negative_tail
has_neg_src = neg_src is not None
has_neg_dst = neg_dst is not None
is_heterogeneous = isinstance(node_pair, Dict)
......
......@@ -8,9 +8,14 @@ import scipy.sparse as sp
import torch
def to_node_block(data):
block = gb.NodeClassificationBlock(seed_node=data)
return block
def minibatch_node_collator(data):
minibatch = gb.MiniBatch(seed_node=data)
return minibatch
def minibatch_link_collator(data):
minibatch = gb.MiniBatch(node_pair=data)
return minibatch
def rand_csc_graph(N, density):
......
......@@ -5,10 +5,6 @@ import torch
from torchdata.datapipes.iter import Mapper
def to_data_block(data):
return gb.LinkPredictionBlock(node_pair=data)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_NegativeSampler_Independent_Format(negative_ratio):
# Construct CSCSamplingGraph.
......@@ -22,10 +18,12 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
minibatch_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph,
......@@ -55,10 +53,12 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
minibatch_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.CONDITIONED,
graph,
......@@ -91,10 +91,12 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
minibatch_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
graph,
......@@ -125,10 +127,12 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler(
data_block_converter,
minibatch_converter,
negative_ratio,
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
graph,
......@@ -166,11 +170,6 @@ def get_hetero_graph():
)
def to_link_block(data):
block = gb.LinkPredictionBlock(node_pair=data)
return block
@pytest.mark.parametrize(
"format",
[
......@@ -200,8 +199,10 @@ def test_NegativeSampler_Hetero_Data(format):
)
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
data_block_converter = Mapper(item_sampler_dp, to_link_block)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
minibatch_converter, 1, format, graph
)
assert len(list(negative_dp)) == 5
......@@ -19,8 +19,10 @@ def test_FeatureFetcher_homo():
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5
......@@ -40,9 +42,7 @@ def test_FeatureFetcher_with_edges_homo():
reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)),
)
)
data = gb.NodeClassificationBlock(
input_nodes=seeds, sampled_subgraphs=subgraphs
)
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
return data
features = {}
......@@ -106,8 +106,10 @@ def test_FeatureFetcher_hetero():
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block)
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
)
......@@ -132,9 +134,7 @@ def test_FeatureFetcher_with_edges_hetero():
reverse_edge_ids=reverse_edge_ids,
)
)
data = gb.NodeClassificationBlock(
input_nodes=seeds, sampled_subgraphs=subgraphs
)
data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
return data
features = {}
......
import os
import unittest
from functools import partial
import backend as F
......@@ -23,9 +22,11 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
block_converter = Mapper(item_sampler, gb_test_utils.to_node_block)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler(
block_converter,
minibatch_converter,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
......
......@@ -7,11 +7,6 @@ import torch
from torchdata.datapipes.iter import Mapper
def to_node_block(data):
block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
return block
def test_DataLoader():
N = 32
B = 4
......@@ -25,9 +20,11 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
block_converter = Mapper(item_sampler, to_node_block)
minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler(
block_converter,
minibatch_converter,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
......
......@@ -6,11 +6,6 @@ import torchdata.datapipes as dp
from torchdata.datapipes.iter import Mapper
def to_node_block(data):
block = gb.NodeClassificationBlock(seed_node=data)
return block
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15)
......@@ -18,14 +13,16 @@ def test_SubgraphSampler_Node(labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_node_block)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(data_block_converter, graph, fanouts)
sampler_dp = Sampler(minibatch_converter, graph, fanouts)
assert len(list(sampler_dp)) == 5
def to_link_block(data):
block = gb.LinkPredictionBlock(node_pair=data)
def to_link_batch(data):
block = gb.MiniBatch(node_pair=data)
return block
......@@ -41,9 +38,11 @@ def test_SubgraphSampler_Link(labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts)
neighbor_dp = Sampler(minibatch_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -68,9 +67,11 @@ def test_SubgraphSampler_Link_With_Negative(format, labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
minibatch_converter, 1, format, graph
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
......@@ -122,9 +123,11 @@ def test_SubgraphSampler_Link_Hetero(labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts)
neighbor_dp = Sampler(minibatch_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5
......@@ -160,9 +163,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block)
minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph
minibatch_converter, 1, format, graph
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment