Unverified commit 86f739b3, authored by peizhou001 and committed by GitHub
Browse files

[Graphbolt] Change data_block to mini_batch (#6256)

parent f281959a
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from .._ffi import libinfo from .._ffi import libinfo
from .base import * from .base import *
from .data_block import * from .minibatch import *
from .data_format import * from .data_format import *
from .dataloader import * from .dataloader import *
from .dataset import * from .dataset import *
......
...@@ -53,9 +53,9 @@ class NeighborSampler(SubgraphSampler): ...@@ -53,9 +53,9 @@ class NeighborSampler(SubgraphSampler):
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper >>> from torchdata.datapipes.iter import Mapper
>>> def to_link_block(data): >>> def minibatch_link_collator(data):
... block = gb.LinkPredictionBlock(node_pair=data) ... minibatch = gb.MiniBatch(node_pair=data)
... return block ... return minibatch
... ...
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
...@@ -67,9 +67,10 @@ class NeighborSampler(SubgraphSampler): ...@@ -67,9 +67,10 @@ class NeighborSampler(SubgraphSampler):
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> data_block_converter = Mapper(item_sampler, to_link_block) >>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...data_block_converter, 2, data_format, graph) ...minibatch_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]), >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])] ...torch.LongTensor([15])]
>>> subgraph_sampler = gb.NeighborSampler( >>> subgraph_sampler = gb.NeighborSampler(
...@@ -164,9 +165,9 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -164,9 +165,9 @@ class LayerNeighborSampler(NeighborSampler):
------- -------
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper >>> from torchdata.datapipes.iter import Mapper
>>> def to_link_block(data): >>> def minibatch_link_collator(data):
... block = gb.LinkPredictionBlock(node_pair=data) ... minibatch = gb.MiniBatch(node_pair=data)
... return block ... return minibatch
... ...
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
>>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8]) >>> indptr = torch.LongTensor([0, 2, 4, 5, 6, 7 ,8])
...@@ -178,9 +179,10 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -178,9 +179,10 @@ class LayerNeighborSampler(NeighborSampler):
>>> item_sampler = gb.ItemSampler( >>> item_sampler = gb.ItemSampler(
...item_set, batch_size=1, ...item_set, batch_size=1,
...) ...)
>>> data_block_converter = Mapper(item_sampler, to_link_block) >>> minibatch_converter = Mapper(item_sampler,
...minibatch_link_collator)
>>> neg_sampler = gb.UniformNegativeSampler( >>> neg_sampler = gb.UniformNegativeSampler(
...data_block_converter, 2, data_format, graph) ...minibatch_converter, 2, data_format, graph)
>>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]), >>> fanouts = [torch.LongTensor([5]), torch.LongTensor([10]),
...torch.LongTensor([15])] ...torch.LongTensor([15])]
>>> subgraph_sampler = gb.LayerNeighborSampler( >>> subgraph_sampler = gb.LayerNeighborSampler(
......
...@@ -7,11 +7,11 @@ import torch ...@@ -7,11 +7,11 @@ import torch
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
__all__ = ["DataBlock", "NodeClassificationBlock", "LinkPredictionBlock"] __all__ = ["MiniBatch"]
@dataclass @dataclass
class DataBlock: class MiniBatch:
r"""A composite data class for data structure in the graphbolt. It is r"""A composite data class for data structure in the graphbolt. It is
designed to facilitate the exchange of data among different components designed to facilitate the exchange of data among different components
involved in processing data. The purpose of this class is to unify the involved in processing data. The purpose of this class is to unify the
...@@ -52,12 +52,6 @@ class DataBlock: ...@@ -52,12 +52,6 @@ class DataBlock:
value should be corresponding heterogeneous node id. value should be corresponding heterogeneous node id.
""" """
@dataclass
class NodeClassificationBlock(DataBlock):
r"""A subclass of 'UnifiedDataStruct', specialized for handling node level
tasks."""
seed_node: Union[torch.Tensor, Dict[str, torch.Tensor]] = None seed_node: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of seed nodes used for sampling in the graph. Representation of seed nodes used for sampling in the graph.
...@@ -69,17 +63,12 @@ class NodeClassificationBlock(DataBlock): ...@@ -69,17 +63,12 @@ class NodeClassificationBlock(DataBlock):
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Labels associated with seed nodes in the graph. Labels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous. - If `label` is a tensor: It indicates the graph is homogeneous. The value
- If `label` is a dictionary: The keys should be node type and the should be corresponding labels to given 'seed_node' or 'node_pair'.
value should be corresponding node labels to given 'seed_node'. - If `label` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_node' or 'node_pair'.
""" """
@dataclass
class LinkPredictionBlock(DataBlock):
r"""A subclass of 'UnifiedDataStruct', specialized for handling edge level
tasks."""
node_pair: Union[ node_pair: Union[
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
...@@ -93,15 +82,6 @@ class LinkPredictionBlock(DataBlock): ...@@ -93,15 +82,6 @@ class LinkPredictionBlock(DataBlock):
type. type.
""" """
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with the link prediction task.
- If `label` is a tensor: It indicates a homogeneous graph. The values are
edge labels corresponding to given 'node_pair'.
- If `label` is a dictionary: The keys should be edge type, and the value
should correspond to given 'node_pair'.
"""
negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of negative samples for the head nodes in the link Representation of negative samples for the head nodes in the link
......
...@@ -30,49 +30,50 @@ class NegativeSampler(Mapper): ...@@ -30,49 +30,50 @@ class NegativeSampler(Mapper):
negative_ratio : int negative_ratio : int
The proportion of negative samples to positive samples. The proportion of negative samples to positive samples.
output_format : LinkPredictionEdgeFormat output_format : LinkPredictionEdgeFormat
Determines the edge format of the output data. Determines the edge format of the output minibatch.
""" """
super().__init__(datapipe, self._sample) super().__init__(datapipe, self._sample)
assert negative_ratio > 0, "Negative_ratio should be positive Integer." assert negative_ratio > 0, "Negative_ratio should be positive Integer."
self.negative_ratio = negative_ratio self.negative_ratio = negative_ratio
self.output_format = output_format self.output_format = output_format
def _sample(self, data): def _sample(self, minibatch):
""" """
Generate a mix of positive and negative samples. Generate a mix of positive and negative samples.
Parameters Parameters
---------- ----------
data : LinkPredictionBlock minibatch : MiniBatch
An instance of 'LinkPredictionBlock' class requires the 'node_pair' An instance of 'MiniBatch' class requires the 'node_pair' field.
field. This function is responsible for generating negative edges This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In corresponding to the positive edges defined by the 'node_pair'. In
cases where negative edges already exist, this function will cases where negative edges already exist, this function will
overwrite them. overwrite them.
Returns Returns
------- -------
LinkPredictionBlock MiniBatch
An instance of 'LinkPredictionBlock' encompasses both positive and An instance of 'MiniBatch' encompasses both positive and negative
negative samples. samples.
""" """
node_pairs = data.node_pair node_pairs = minibatch.node_pair
assert node_pairs is not None
if isinstance(node_pairs, Mapping): if isinstance(node_pairs, Mapping):
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
data.label = {} minibatch.label = {}
else: else:
data.negative_head, data.negative_tail = {}, {} minibatch.negative_head, minibatch.negative_tail = {}, {}
for etype, pos_pairs in node_pairs.items(): for etype, pos_pairs in node_pairs.items():
self._collate( self._collate(
data, self._sample_with_etype(pos_pairs, etype), etype minibatch, self._sample_with_etype(pos_pairs, etype), etype
) )
if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED: if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
data.negative_tail = None minibatch.negative_tail = None
if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED: if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
data.negative_head = None minibatch.negative_head = None
else: else:
self._collate(data, self._sample_with_etype(node_pairs)) self._collate(minibatch, self._sample_with_etype(node_pairs))
return data return minibatch
def _sample_with_etype(self, node_pairs, etype=None): def _sample_with_etype(self, node_pairs, etype=None):
"""Generate negative pairs for a given etype from positive pairs """Generate negative pairs for a given etype from positive pairs
...@@ -94,13 +95,13 @@ class NegativeSampler(Mapper): ...@@ -94,13 +95,13 @@ class NegativeSampler(Mapper):
""" """
raise NotImplementedError raise NotImplementedError
def _collate(self, data, neg_pairs, etype=None): def _collate(self, minibatch, neg_pairs, etype=None):
"""Collates positive and negative samples into data. """Collates positive and negative samples into minibatch.
Parameters Parameters
---------- ----------
data : LinkPredictionBlock minibatch : MiniBatch
The input data, which contains positive node pairs, will be filled The input minibatch, which contains positive node pairs, will be filled
with negative information in this function. with negative information in this function.
neg_pairs : Tuple[Tensor, Tensor] neg_pairs : Tuple[Tensor, Tensor]
A tuple of tensors represents source-destination node pairs of A tuple of tensors represents source-destination node pairs of
...@@ -110,7 +111,9 @@ class NegativeSampler(Mapper): ...@@ -110,7 +111,9 @@ class NegativeSampler(Mapper):
Canonical edge type. Canonical edge type.
""" """
pos_src, pos_dst = ( pos_src, pos_dst = (
data.node_pair[etype] if etype is not None else data.node_pair minibatch.node_pair[etype]
if etype is not None
else minibatch.node_pair
) )
neg_src, neg_dst = neg_pairs neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
...@@ -120,11 +123,11 @@ class NegativeSampler(Mapper): ...@@ -120,11 +123,11 @@ class NegativeSampler(Mapper):
dst = torch.cat([pos_dst, neg_dst]) dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label]) label = torch.cat([pos_label, neg_label])
if etype is not None: if etype is not None:
data.node_pair[etype] = (src, dst) minibatch.node_pair[etype] = (src, dst)
data.label[etype] = label minibatch.label[etype] = label
else: else:
data.node_pair = (src, dst) minibatch.node_pair = (src, dst)
data.label = label minibatch.label = label
else: else:
if self.output_format == LinkPredictionEdgeFormat.CONDITIONED: if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
neg_src = neg_src.view(-1, self.negative_ratio) neg_src = neg_src.view(-1, self.negative_ratio)
...@@ -144,8 +147,8 @@ class NegativeSampler(Mapper): ...@@ -144,8 +147,8 @@ class NegativeSampler(Mapper):
f"Unsupported output format {self.output_format}." f"Unsupported output format {self.output_format}."
) )
if etype is not None: if etype is not None:
data.negative_head[etype] = neg_src minibatch.negative_head[etype] = neg_src
data.negative_tail[etype] = neg_dst minibatch.negative_tail[etype] = neg_dst
else: else:
data.negative_head = neg_src minibatch.negative_head = neg_src
data.negative_tail = neg_dst minibatch.negative_tail = neg_dst
...@@ -6,7 +6,6 @@ from typing import Dict ...@@ -6,7 +6,6 @@ from typing import Dict
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
from .base import etype_str_to_tuple from .base import etype_str_to_tuple
from .data_block import LinkPredictionBlock, NodeClassificationBlock
from .utils import unique_and_compact from .utils import unique_and_compact
...@@ -28,24 +27,30 @@ class SubgraphSampler(Mapper): ...@@ -28,24 +27,30 @@ class SubgraphSampler(Mapper):
""" """
super().__init__(datapipe, self._sample) super().__init__(datapipe, self._sample)
def _sample(self, data): def _sample(self, minibatch):
if isinstance(data, LinkPredictionBlock): if minibatch.node_pair is not None:
( (
seeds, seeds,
data.compacted_node_pair, minibatch.compacted_node_pair,
data.compacted_negative_head, minibatch.compacted_negative_head,
data.compacted_negative_tail, minibatch.compacted_negative_tail,
) = self._link_prediction_preprocess(data) ) = self._node_pair_preprocess(minibatch)
elif isinstance(data, NodeClassificationBlock): elif minibatch.seed_node is not None:
seeds = data.seed_node seeds = minibatch.seed_node
else: else:
raise TypeError(f"Unsupported type of data {data}.") raise ValueError(
data.input_nodes, data.sampled_subgraphs = self._sample_subgraphs(seeds) f"Invalid minibatch {minibatch}: Either 'node_pair' or \
return data 'seed_node' should have a value."
)
(
minibatch.input_nodes,
minibatch.sampled_subgraphs,
) = self._sample_subgraphs(seeds)
return minibatch
def _link_prediction_preprocess(self, data): def _node_pair_preprocess(self, minibatch):
node_pair = data.node_pair node_pair = minibatch.node_pair
neg_src, neg_dst = data.negative_head, data.negative_tail neg_src, neg_dst = minibatch.negative_head, minibatch.negative_tail
has_neg_src = neg_src is not None has_neg_src = neg_src is not None
has_neg_dst = neg_dst is not None has_neg_dst = neg_dst is not None
is_heterogeneous = isinstance(node_pair, Dict) is_heterogeneous = isinstance(node_pair, Dict)
......
...@@ -8,9 +8,14 @@ import scipy.sparse as sp ...@@ -8,9 +8,14 @@ import scipy.sparse as sp
import torch import torch
def to_node_block(data): def minibatch_node_collator(data):
block = gb.NodeClassificationBlock(seed_node=data) minibatch = gb.MiniBatch(seed_node=data)
return block return minibatch
def minibatch_link_collator(data):
minibatch = gb.MiniBatch(node_pair=data)
return minibatch
def rand_csc_graph(N, density): def rand_csc_graph(N, density):
......
...@@ -5,10 +5,6 @@ import torch ...@@ -5,10 +5,6 @@ import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
def to_data_block(data):
return gb.LinkPredictionBlock(node_pair=data)
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
def test_NegativeSampler_Independent_Format(negative_ratio): def test_NegativeSampler_Independent_Format(negative_ratio):
# Construct CSCSamplingGraph. # Construct CSCSamplingGraph.
...@@ -22,10 +18,12 @@ def test_NegativeSampler_Independent_Format(negative_ratio): ...@@ -22,10 +18,12 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block) minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
data_block_converter, minibatch_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.INDEPENDENT, gb.LinkPredictionEdgeFormat.INDEPENDENT,
graph, graph,
...@@ -55,10 +53,12 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio): ...@@ -55,10 +53,12 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block) minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
data_block_converter, minibatch_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.CONDITIONED, gb.LinkPredictionEdgeFormat.CONDITIONED,
graph, graph,
...@@ -91,10 +91,12 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio): ...@@ -91,10 +91,12 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block) minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
data_block_converter, minibatch_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED, gb.LinkPredictionEdgeFormat.HEAD_CONDITIONED,
graph, graph,
...@@ -125,10 +127,12 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio): ...@@ -125,10 +127,12 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
) )
batch_size = 10 batch_size = 10
item_sampler = gb.ItemSampler(item_set, batch_size=batch_size) item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)
data_block_converter = Mapper(item_sampler, to_data_block) minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_link_collator
)
# Construct NegativeSampler. # Construct NegativeSampler.
negative_sampler = gb.UniformNegativeSampler( negative_sampler = gb.UniformNegativeSampler(
data_block_converter, minibatch_converter,
negative_ratio, negative_ratio,
gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED, gb.LinkPredictionEdgeFormat.TAIL_CONDITIONED,
graph, graph,
...@@ -166,11 +170,6 @@ def get_hetero_graph(): ...@@ -166,11 +170,6 @@ def get_hetero_graph():
) )
def to_link_block(data):
block = gb.LinkPredictionBlock(node_pair=data)
return block
@pytest.mark.parametrize( @pytest.mark.parametrize(
"format", "format",
[ [
...@@ -200,8 +199,10 @@ def test_NegativeSampler_Hetero_Data(format): ...@@ -200,8 +199,10 @@ def test_NegativeSampler_Hetero_Data(format):
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
data_block_converter = Mapper(item_sampler_dp, to_link_block) minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler( negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph minibatch_converter, 1, format, graph
) )
assert len(list(negative_dp)) == 5 assert len(list(negative_dp)) == 5
...@@ -19,8 +19,10 @@ def test_FeatureFetcher_homo(): ...@@ -19,8 +19,10 @@ def test_FeatureFetcher_homo():
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block) minibatch_converter = Mapper(
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts) item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"]) fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5 assert len(list(fetcher_dp)) == 5
...@@ -40,9 +42,7 @@ def test_FeatureFetcher_with_edges_homo(): ...@@ -40,9 +42,7 @@ def test_FeatureFetcher_with_edges_homo():
reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)), reverse_edge_ids=torch.randint(0, graph.num_edges, (10,)),
) )
) )
data = gb.NodeClassificationBlock( data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
input_nodes=seeds, sampled_subgraphs=subgraphs
)
return data return data
features = {} features = {}
...@@ -106,8 +106,10 @@ def test_FeatureFetcher_hetero(): ...@@ -106,8 +106,10 @@ def test_FeatureFetcher_hetero():
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, gb_test_utils.to_node_block) minibatch_converter = Mapper(
sampler_dp = gb.NeighborSampler(data_block_converter, graph, fanouts) item_sampler_dp, gb_test_utils.minibatch_node_collator
)
sampler_dp = gb.NeighborSampler(minibatch_converter, graph, fanouts)
fetcher_dp = gb.FeatureFetcher( fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]} sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
) )
...@@ -132,9 +134,7 @@ def test_FeatureFetcher_with_edges_hetero(): ...@@ -132,9 +134,7 @@ def test_FeatureFetcher_with_edges_hetero():
reverse_edge_ids=reverse_edge_ids, reverse_edge_ids=reverse_edge_ids,
) )
) )
data = gb.NodeClassificationBlock( data = gb.MiniBatch(input_nodes=seeds, sampled_subgraphs=subgraphs)
input_nodes=seeds, sampled_subgraphs=subgraphs
)
return data return data
features = {} features = {}
......
import os import os
import unittest import unittest
from functools import partial
import backend as F import backend as F
...@@ -23,9 +22,11 @@ def test_DataLoader(): ...@@ -23,9 +22,11 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features) feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B) item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
block_converter = Mapper(item_sampler, gb_test_utils.to_node_block) minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler( subgraph_sampler = dgl.graphbolt.NeighborSampler(
block_converter, minibatch_converter,
graph, graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)], fanouts=[torch.LongTensor([2]) for _ in range(2)],
) )
......
...@@ -7,11 +7,6 @@ import torch ...@@ -7,11 +7,6 @@ import torch
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
def to_node_block(data):
block = dgl.graphbolt.NodeClassificationBlock(seed_node=data)
return block
def test_DataLoader(): def test_DataLoader():
N = 32 N = 32
B = 4 B = 4
...@@ -25,9 +20,11 @@ def test_DataLoader(): ...@@ -25,9 +20,11 @@ def test_DataLoader():
feature_store = dgl.graphbolt.BasicFeatureStore(features) feature_store = dgl.graphbolt.BasicFeatureStore(features)
item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B) item_sampler = dgl.graphbolt.ItemSampler(itemset, batch_size=B)
block_converter = Mapper(item_sampler, to_node_block) minibatch_converter = Mapper(
item_sampler, gb_test_utils.minibatch_node_collator
)
subgraph_sampler = dgl.graphbolt.NeighborSampler( subgraph_sampler = dgl.graphbolt.NeighborSampler(
block_converter, minibatch_converter,
graph, graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)], fanouts=[torch.LongTensor([2]) for _ in range(2)],
) )
......
...@@ -6,11 +6,6 @@ import torchdata.datapipes as dp ...@@ -6,11 +6,6 @@ import torchdata.datapipes as dp
from torchdata.datapipes.iter import Mapper from torchdata.datapipes.iter import Mapper
def to_node_block(data):
block = gb.NodeClassificationBlock(seed_node=data)
return block
@pytest.mark.parametrize("labor", [False, True]) @pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_Node(labor): def test_SubgraphSampler_Node(labor):
graph = gb_test_utils.rand_csc_graph(20, 0.15) graph = gb_test_utils.rand_csc_graph(20, 0.15)
...@@ -18,14 +13,16 @@ def test_SubgraphSampler_Node(labor): ...@@ -18,14 +13,16 @@ def test_SubgraphSampler_Node(labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_node_block) minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_node_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
sampler_dp = Sampler(data_block_converter, graph, fanouts) sampler_dp = Sampler(minibatch_converter, graph, fanouts)
assert len(list(sampler_dp)) == 5 assert len(list(sampler_dp)) == 5
def to_link_block(data): def to_link_batch(data):
block = gb.LinkPredictionBlock(node_pair=data) block = gb.MiniBatch(node_pair=data)
return block return block
...@@ -41,9 +38,11 @@ def test_SubgraphSampler_Link(labor): ...@@ -41,9 +38,11 @@ def test_SubgraphSampler_Link(labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block) minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts) neighbor_dp = Sampler(minibatch_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
...@@ -68,9 +67,11 @@ def test_SubgraphSampler_Link_With_Negative(format, labor): ...@@ -68,9 +67,11 @@ def test_SubgraphSampler_Link_With_Negative(format, labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block) minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler( negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph minibatch_converter, 1, format, graph
) )
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) neighbor_dp = Sampler(negative_dp, graph, fanouts)
...@@ -122,9 +123,11 @@ def test_SubgraphSampler_Link_Hetero(labor): ...@@ -122,9 +123,11 @@ def test_SubgraphSampler_Link_Hetero(labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block) minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(data_block_converter, graph, fanouts) neighbor_dp = Sampler(minibatch_converter, graph, fanouts)
assert len(list(neighbor_dp)) == 5 assert len(list(neighbor_dp)) == 5
...@@ -160,9 +163,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor): ...@@ -160,9 +163,11 @@ def test_SubgraphSampler_Link_Hetero_With_Negative(format, labor):
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
data_block_converter = Mapper(item_sampler_dp, to_link_block) minibatch_converter = Mapper(
item_sampler_dp, gb_test_utils.minibatch_link_collator
)
negative_dp = gb.UniformNegativeSampler( negative_dp = gb.UniformNegativeSampler(
data_block_converter, 1, format, graph minibatch_converter, 1, format, graph
) )
Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
neighbor_dp = Sampler(negative_dp, graph, fanouts) neighbor_dp = Sampler(negative_dp, graph, fanouts)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment