Unverified Commit cad7caeb authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[Graphbolt] Rewrite `to_dgl` to multiple `get` functions (#6735)

parent 70fdb69f
...@@ -98,11 +98,8 @@ def evaluate(model, dataset, device): ...@@ -98,11 +98,8 @@ def evaluate(model, dataset, device):
logits = [] logits = []
labels = [] labels = []
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
# Convert data to DGL format for computing. # Get node pairs with labels for loss calculation.
data = data.to_dgl() compacted_pairs, label = data.node_pairs_with_labels
# Unpack MiniBatch.
compacted_pairs, label = to_binary_link_dgl_computing_pack(data)
# The features of sampled nodes. # The features of sampled nodes.
x = data.node_features["feat"] x = data.node_features["feat"]
...@@ -140,11 +137,8 @@ def train(model, dataset, device): ...@@ -140,11 +137,8 @@ def train(model, dataset, device):
# mini-batches. # mini-batches.
######################################################################## ########################################################################
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
# Convert data to DGL format for computing. # Get node pairs with labels for loss calculation.
data = data.to_dgl() compacted_pairs, labels = data.node_pairs_with_labels
# Unpack MiniBatch.
compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
# The features of sampled nodes. # The features of sampled nodes.
x = data.node_features["feat"] x = data.node_features["feat"]
......
...@@ -57,7 +57,6 @@ def evaluate(model, dataset, itemset, device): ...@@ -57,7 +57,6 @@ def evaluate(model, dataset, itemset, device):
dataloader = create_dataloader(dataset, itemset, device) dataloader = create_dataloader(dataset, itemset, device)
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
data = data.to_dgl()
x = data.node_features["feat"] x = data.node_features["feat"]
y.append(data.labels) y.append(data.labels)
y_hats.append(model(data.blocks, x)) y_hats.append(model(data.blocks, x))
...@@ -84,9 +83,6 @@ def train(model, dataset, device): ...@@ -84,9 +83,6 @@ def train(model, dataset, device):
# mini-batches. # mini-batches.
######################################################################## ########################################################################
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
# Convert data to DGL format for computing.
data = data.to_dgl()
# The features of sampled nodes. # The features of sampled nodes.
x = data.node_features["feat"] x = data.node_features["feat"]
......
...@@ -363,9 +363,10 @@ class MiniBatch: ...@@ -363,9 +363,10 @@ class MiniBatch:
"""Set edge features.""" """Set edge features."""
self.edge_features = edge_features self.edge_features = edge_features
def _to_dgl_blocks(self): @property
"""Transforming a `MiniBatch` into DGL blocks necessitates constructing def blocks(self):
a graphical structure and ID mappings. """Extracts DGL blocks from `MiniBatch` to construct a graphical
structure and ID mappings.
""" """
if not self.sampled_subgraphs: if not self.sampled_subgraphs:
return None return None
...@@ -459,98 +460,135 @@ class MiniBatch: ...@@ -459,98 +460,135 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.original_edge_ids block.edata[dgl.EID] = subgraph.original_edge_ids
return blocks return blocks
def to_dgl(self): @property
"""Converting a `MiniBatch` into a DGL MiniBatch that contains def positive_node_pairs(self):
everything necessary for computation." """`positive_node_pairs` is a representation of positive graphs used for
evaluating or computing loss in link prediction tasks.
- If `positive_node_pairs` is a tuple: It indicates a homogeneous graph
containing two tensors representing source-destination node pairs.
- If `positive_node_pairs` is a dictionary: The keys should be edge type,
and the value should be a tuple of tensors representing node pairs of the
given type.
""" """
minibatch = DGLMiniBatch( return self.compacted_node_pairs
blocks=self._to_dgl_blocks(),
node_features=self.node_features, @property
edge_features=self.edge_features, def negative_node_pairs(self):
labels=self.labels, """`negative_node_pairs` is a representation of negative graphs used for
) evaluating or computing loss in link prediction tasks.
# Need input nodes to fetch feature. - If `negative_node_pairs` is a tuple: It indicates a homogeneous graph
if self.node_features is None: containing two tensors representing source-destination node pairs.
minibatch.input_nodes = self.input_nodes - If `negative_node_pairs` is a dictionary: The keys should be edge type,
# Need output nodes to fetch label. and the value should be a tuple of tensors representing node pairs of the
if self.labels is None: given type.
minibatch.output_nodes = self.seed_nodes """
assert ( # Build negative graph.
minibatch.blocks is not None if (
), "Sampled subgraphs for computation are missing." self.compacted_negative_srcs is not None
and self.compacted_negative_dsts is not None
# For link prediction tasks. ):
if self.compacted_node_pairs is not None: # For homogeneous graph.
minibatch.positive_node_pairs = self.compacted_node_pairs if isinstance(self.compacted_negative_srcs, torch.Tensor):
# Build negative graph. negative_node_pairs = (
if ( self.compacted_negative_srcs.view(-1),
self.compacted_negative_srcs is not None self.compacted_negative_dsts.view(-1),
and self.compacted_negative_dsts is not None )
): # For heterogeneous graph.
# For homogeneous graph. else:
if isinstance(self.compacted_negative_srcs, torch.Tensor): negative_node_pairs = {
minibatch.negative_node_pairs = ( etype: (
self.compacted_negative_srcs.view(-1), neg_src.view(-1),
self.compacted_negative_dsts.view(-1), self.compacted_negative_dsts[etype].view(-1),
) )
# For heterogeneous graph. for etype, neg_src in self.compacted_negative_srcs.items()
else: }
minibatch.negative_node_pairs = { elif (
etype: ( self.compacted_negative_srcs is not None
neg_src.view(-1), and self.compacted_node_pairs is not None
self.compacted_negative_dsts[etype].view(-1), ):
) # For homogeneous graph.
for etype, neg_src in self.compacted_negative_srcs.items() if isinstance(self.compacted_negative_srcs, torch.Tensor):
} negative_ratio = self.compacted_negative_srcs.size(1)
elif self.compacted_negative_srcs is not None: negative_node_pairs = (
# For homogeneous graph. self.compacted_negative_srcs.view(-1),
if isinstance(self.compacted_negative_srcs, torch.Tensor): self.compacted_node_pairs[1].repeat_interleave(
negative_ratio = self.compacted_negative_srcs.size(1) negative_ratio
minibatch.negative_node_pairs = ( ),
self.compacted_negative_srcs.view(-1), )
self.compacted_node_pairs[1].repeat_interleave( # For heterogeneous graph.
else:
negative_ratio = list(self.compacted_negative_srcs.values())[
0
].size(1)
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_node_pairs[etype][1].repeat_interleave(
negative_ratio negative_ratio
), ),
) )
# For heterogeneous graph. for etype, neg_src in self.compacted_negative_srcs.items()
else: }
negative_ratio = list( elif (
self.compacted_negative_srcs.values() self.compacted_negative_dsts is not None
)[0].size(1) and self.compacted_node_pairs is not None
minibatch.negative_node_pairs = { ):
etype: ( # For homogeneous graph.
neg_src.view(-1), if isinstance(self.compacted_negative_dsts, torch.Tensor):
self.compacted_node_pairs[etype][ negative_ratio = self.compacted_negative_dsts.size(1)
1 negative_node_pairs = (
].repeat_interleave(negative_ratio), self.compacted_node_pairs[0].repeat_interleave(
) negative_ratio
for etype, neg_src in self.compacted_negative_srcs.items() ),
} self.compacted_negative_dsts.view(-1),
elif self.compacted_negative_dsts is not None: )
# For homogeneous graph. # For heterogeneous graph.
if isinstance(self.compacted_negative_dsts, torch.Tensor): else:
negative_ratio = self.compacted_negative_dsts.size(1) negative_ratio = list(self.compacted_negative_dsts.values())[
minibatch.negative_node_pairs = ( 0
self.compacted_node_pairs[0].repeat_interleave( ].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0].repeat_interleave(
negative_ratio negative_ratio
), ),
self.compacted_negative_dsts.view(-1), neg_dst.view(-1),
) )
# For heterogeneous graph. for etype, neg_dst in self.compacted_negative_dsts.items()
else: }
negative_ratio = list( else:
self.compacted_negative_dsts.values() negative_node_pairs = None
)[0].size(1) return negative_node_pairs
minibatch.negative_node_pairs = {
etype: ( @property
self.compacted_node_pairs[etype][ def node_pairs_with_labels(self):
0 """Get a node pair tensor and a label tensor from MiniBatch. They are
].repeat_interleave(negative_ratio), used for evaluating or computing loss. It will return
neg_dst.view(-1), `(node_pairs, labels)` as result.
) - If it's a link prediction task, `node_pairs` will contain both
for etype, neg_dst in self.compacted_negative_dsts.items() negative and positive node pairs and `labels` will consist of 0 and 1,
} indicating whether the corresponding node pair is negative or positive.
return minibatch - If it's an edge classification task, this function will directly
return `compacted_node_pairs` and corresponding `labels`.
- Otherwise it will return None.
"""
if self.labels is None:
positive_node_pairs = self.positive_node_pairs
negative_node_pairs = self.negative_node_pairs
if positive_node_pairs is None or negative_node_pairs is None:
return None
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels = torch.cat([pos_label, neg_label], dim=0)
return (node_pairs, labels.float())
else:
return (self.compacted_node_pairs, self.labels)
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection.""" """Copy `MiniBatch` to the specified device using reflection."""
...@@ -561,13 +599,16 @@ class MiniBatch: ...@@ -561,13 +599,16 @@ class MiniBatch:
for attr in dir(self): for attr in dir(self):
# Only copy member variables. # Only copy member variables.
if not callable(getattr(self, attr)) and not attr.startswith("__"): if not callable(getattr(self, attr)) and not attr.startswith("__"):
setattr( try:
self, setattr(
attr, self,
recursive_apply( attr,
getattr(self, attr), lambda x: _to(x, device) recursive_apply(
), getattr(self, attr), lambda x: _to(x, device)
) ),
)
except AttributeError:
continue
return self return self
......
...@@ -8,7 +8,6 @@ from .minibatch import DGLMiniBatch, MiniBatch ...@@ -8,7 +8,6 @@ from .minibatch import DGLMiniBatch, MiniBatch
__all__ = [ __all__ = [
"MiniBatchTransformer", "MiniBatchTransformer",
"DGLMiniBatchConverter",
] ]
...@@ -41,22 +40,3 @@ class MiniBatchTransformer(Mapper): ...@@ -41,22 +40,3 @@ class MiniBatchTransformer(Mapper):
minibatch, (MiniBatch, DGLMiniBatch) minibatch, (MiniBatch, DGLMiniBatch)
), "The transformer output should be an instance of MiniBatch" ), "The transformer output should be an instance of MiniBatch"
return minibatch return minibatch
@functional_datapipe("to_dgl")
class DGLMiniBatchConverter(Mapper):
"""Convert a graphbolt mini-batch to a dgl mini-batch.
Functional name: :obj:`to_dgl`.
Parameters
----------
datapipe : DataPipe
The datapipe.
"""
def __init__(
self,
datapipe,
):
super().__init__(datapipe, MiniBatch.to_dgl)
...@@ -163,9 +163,12 @@ def test_minibatch_representation_homo(): ...@@ -163,9 +163,12 @@ def test_minibatch_representation_homo():
expect_result = str( expect_result = str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=None, sampled_subgraphs=None,
positive_node_pairs=None,
node_pairs_with_labels=None,
node_pairs=None, node_pairs=None,
node_features=None, node_features=None,
negative_srcs=None, negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None, negative_dsts=None,
labels=None, labels=None,
input_nodes=None, input_nodes=None,
...@@ -173,6 +176,7 @@ def test_minibatch_representation_homo(): ...@@ -173,6 +176,7 @@ def test_minibatch_representation_homo():
compacted_node_pairs=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None, compacted_negative_dsts=None,
blocks=None,
)""" )"""
) )
result = str(minibatch) result = str(minibatch)
...@@ -207,6 +211,13 @@ def test_minibatch_representation_homo(): ...@@ -207,6 +211,13 @@ def test_minibatch_representation_homo():
indices=tensor([1, 2, 0]), indices=tensor([1, 2, 0]),
), ),
)], )],
positive_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
node_pairs_with_labels=(CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
tensor([0., 1., 2.])),
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]), node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]), indices=tensor([0, 1, 2, 2, 1, 2]),
), ),
...@@ -217,6 +228,8 @@ def test_minibatch_representation_homo(): ...@@ -217,6 +228,8 @@ def test_minibatch_representation_homo():
negative_srcs=tensor([[8], negative_srcs=tensor([[8],
[1], [1],
[6]]), [6]]),
negative_node_pairs=(tensor([0, 1, 2]),
tensor([6, 0, 0])),
negative_dsts=tensor([[2], negative_dsts=tensor([[2],
[8], [8],
[8]]), [8]]),
...@@ -233,6 +246,8 @@ def test_minibatch_representation_homo(): ...@@ -233,6 +246,8 @@ def test_minibatch_representation_homo():
compacted_negative_dsts=tensor([[6], compacted_negative_dsts=tensor([[6],
[0], [0],
[0]]), [0]]),
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6),
Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)],
)""" )"""
) )
result = str(minibatch) result = str(minibatch)
...@@ -307,7 +322,7 @@ def test_minibatch_representation_hetero(): ...@@ -307,7 +322,7 @@ def test_minibatch_representation_hetero():
} }
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])} compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])} compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
# Test dglminibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])}, seed_nodes={"B": torch.tensor([10, 15])},
node_pairs=csc_formats, node_pairs=csc_formats,
...@@ -343,6 +358,17 @@ def test_minibatch_representation_hetero(): ...@@ -343,6 +358,17 @@ def test_minibatch_representation_hetero():
indices=tensor([1, 0]), indices=tensor([1, 0]),
)}, )},
)], )],
positive_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
node_pairs_with_labels=({'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
{'B': tensor([2, 5])}),
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]), indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]), ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
...@@ -355,6 +381,7 @@ def test_minibatch_representation_hetero(): ...@@ -355,6 +381,7 @@ def test_minibatch_representation_hetero():
negative_srcs={'B': tensor([[8], negative_srcs={'B': tensor([[8],
[1], [1],
[6]])}, [6]])},
negative_node_pairs={'A:r:B': (tensor([0, 1, 2]), tensor([6, 0, 0]))},
negative_dsts={'B': tensor([[2], negative_dsts={'B': tensor([[2],
[8], [8],
[8]])}, [8]])},
...@@ -373,13 +400,21 @@ def test_minibatch_representation_hetero(): ...@@ -373,13 +400,21 @@ def test_minibatch_representation_hetero():
compacted_negative_dsts={'A:r:B': tensor([[6], compacted_negative_dsts={'A:r:B': tensor([[6],
[0], [0],
[0]])}, [0]])},
blocks=[Block(num_src_nodes={'A': 4, 'B': 3},
num_dst_nodes={'A': 4, 'B': 3},
num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]),
Block(num_src_nodes={'A': 2, 'B': 2},
num_dst_nodes={'B': 2},
num_edges={('A', 'r', 'B'): 2},
metagraph=[('A', 'B', 'r')])],
)""" )"""
) )
result = str(minibatch) result = str(minibatch)
assert result == expect_result, print(result) assert result == expect_result, print(result)
def test_dgl_minibatch_representation_homo(): def test_get_dgl_blocks_homo():
node_pairs = [ node_pairs = [
( (
torch.tensor([0, 1, 2, 2, 2, 1]), torch.tensor([0, 1, 2, 2, 2, 1]),
...@@ -424,7 +459,7 @@ def test_dgl_minibatch_representation_homo(): ...@@ -424,7 +459,7 @@ def test_dgl_minibatch_representation_homo():
compacted_negative_srcs = torch.tensor([[0], [1], [2]]) compacted_negative_srcs = torch.tensor([[0], [1], [2]])
compacted_negative_dsts = torch.tensor([[6], [0], [0]]) compacted_negative_dsts = torch.tensor([[6], [0], [0]])
labels = torch.tensor([0.0, 1.0, 2.0]) labels = torch.tensor([0.0, 1.0, 2.0])
# Test dglminibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
node_pairs=node_pairs, node_pairs=node_pairs,
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
...@@ -438,31 +473,15 @@ def test_dgl_minibatch_representation_homo(): ...@@ -438,31 +473,15 @@ def test_dgl_minibatch_representation_homo():
compacted_negative_srcs=compacted_negative_srcs, compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts, compacted_negative_dsts=compacted_negative_dsts,
) )
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
expect_result = str( expect_result = str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 2]), """[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6), Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)]"""
tensor([3, 4, 5])),
output_nodes=None,
node_features={'x': tensor([7, 6, 2, 2])},
negative_node_pairs=(tensor([0, 1, 2]),
tensor([6, 0, 0])),
labels=tensor([0., 1., 2.]),
input_nodes=None,
edge_features=[{'x': tensor([[8],
[1],
[6]])},
{'x': tensor([[2],
[8],
[8]])}],
blocks=[Block(num_src_nodes=4, num_dst_nodes=4, num_edges=6),
Block(num_src_nodes=3, num_dst_nodes=2, num_edges=3)],
)"""
) )
result = str(dgl_minibatch) result = str(dgl_blocks)
assert result == expect_result, print(result) assert result == expect_result, print(result)
def test_dgl_minibatch_representation_hetero(): def test_get_dgl_blocks_hetero():
node_pairs = [ node_pairs = [
{ {
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])), relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
...@@ -516,7 +535,7 @@ def test_dgl_minibatch_representation_hetero(): ...@@ -516,7 +535,7 @@ def test_dgl_minibatch_representation_hetero():
} }
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])} compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])} compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
# Test dglminibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])}, seed_nodes={"B": torch.tensor([10, 15])},
node_pairs=node_pairs, node_pairs=node_pairs,
...@@ -534,30 +553,63 @@ def test_dgl_minibatch_representation_hetero(): ...@@ -534,30 +553,63 @@ def test_dgl_minibatch_representation_hetero():
compacted_negative_srcs=compacted_negative_srcs, compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts, compacted_negative_dsts=compacted_negative_dsts,
) )
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
expect_result = str( expect_result = str(
"""DGLMiniBatch(positive_node_pairs={'A:r:B': (tensor([0, 1, 2]), tensor([3, 4, 5])), 'B:rr:A': (tensor([0, 1, 2]), tensor([3, 4, 5]))}, """[Block(num_src_nodes={'A': 4, 'B': 3},
output_nodes=None, num_dst_nodes={'A': 4, 'B': 3},
node_features={('A', 'x'): tensor([6, 4, 0, 1])}, num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
negative_node_pairs={'A:r:B': (tensor([0, 1, 2]), tensor([6, 0, 0]))}, metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]), Block(num_src_nodes={'A': 2, 'B': 2},
labels={'B': tensor([2, 5])}, num_dst_nodes={'B': 2},
input_nodes=None, num_edges={('A', 'r', 'B'): 2},
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])}, metagraph=[('A', 'B', 'r')])]"""
{('A:r:B', 'x'): tensor([0, 6])}],
blocks=[Block(num_src_nodes={'A': 4, 'B': 3},
num_dst_nodes={'A': 4, 'B': 3},
num_edges={('A', 'r', 'B'): 3, ('B', 'rr', 'A'): 2},
metagraph=[('A', 'B', 'r'), ('B', 'A', 'rr')]),
Block(num_src_nodes={'A': 2, 'B': 2},
num_dst_nodes={'B': 2},
num_edges={('A', 'r', 'B'): 2},
metagraph=[('A', 'B', 'r')])],
)"""
) )
result = str(dgl_minibatch) result = str(dgl_blocks)
assert result == expect_result, print(result) assert result == expect_result, print(result)
@pytest.mark.parametrize(
"mode", ["neg_graph", "neg_src", "neg_dst", "edge_classification"]
)
def test_minibatch_node_pairs_with_labels(mode):
# Arrange
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
if mode == "edge_classification":
minibatch.labels = torch.tensor([0, 1]).long()
# Act
node_pairs, labels = minibatch.node_pairs_with_labels
# Assert
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
elif mode != "edge_classification":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
else:
expect_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
expect_labels = torch.tensor([0, 1]).long()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
def check_dgl_blocks_hetero(minibatch, blocks): def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation) etype = gb.etype_str_to_tuple(relation)
node_pairs = [ node_pairs = [
...@@ -607,61 +659,48 @@ def check_dgl_blocks_homo(minibatch, blocks): ...@@ -607,61 +659,48 @@ def check_dgl_blocks_homo(minibatch, blocks):
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0]) assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
def test_to_dgl_node_classification_without_feature(): def test_get_dgl_blocks_node_classification_without_feature():
# Arrange # Arrange
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
minibatch.node_features = None minibatch.node_features = None
minibatch.labels = None minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seed_nodes = torch.tensor([10, 15])
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert dgl_minibatch.node_features is None assert minibatch.node_features is None
assert minibatch.edge_features is dgl_minibatch.edge_features assert minibatch.labels is None
assert dgl_minibatch.labels is None check_dgl_blocks_homo(minibatch, dgl_blocks)
assert minibatch.input_nodes is dgl_minibatch.input_nodes
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_homo(): def test_get_dgl_blocks_node_classification_homo():
# Arrange # Arrange
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5]) minibatch.labels = torch.tensor([2, 5])
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_homo(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_hetero(): def test_to_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch() minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])} minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])} minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_hetero(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_hetero(minibatch, dgl_minibatch.blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_homo(mode): def test_dgl_link_predication_homo(mode):
# Arrange # Arrange
minibatch = create_homo_minibatch() minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = ( minibatch.compacted_node_pairs = (
...@@ -673,28 +712,40 @@ def test_to_dgl_link_predication_homo(mode): ...@@ -673,28 +712,40 @@ def test_to_dgl_link_predication_homo(mode):
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]]) minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_homo(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[0], minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1), minibatch.compacted_negative_srcs.view(-1),
) )
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[1], minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1), minibatch.compacted_negative_dsts.view(-1),
) )
node_pairs, labels = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
else:
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_hetero(mode): def test_dgl_link_predication_hetero(mode):
# Arrange # Arrange
minibatch = create_hetero_minibatch() minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = { minibatch.compacted_node_pairs = {
...@@ -718,24 +769,21 @@ def test_to_dgl_link_predication_hetero(mode): ...@@ -718,24 +769,21 @@ def test_to_dgl_link_predication_hetero(mode):
reverse_relation: torch.tensor([[2, 1], [3, 1]]), reverse_relation: torch.tensor([[2, 1], [3, 1]]),
} }
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_hetero(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
check_dgl_blocks_hetero(minibatch, dgl_minibatch.blocks)
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items(): for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[etype][0], minibatch.negative_node_pairs[etype][0],
src.view(-1), src.view(-1),
) )
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items(): for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[etype][1], minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1), minibatch.compacted_negative_dsts[etype].view(-1),
) )
...@@ -925,61 +973,49 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks): ...@@ -925,61 +973,49 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks):
), print(blocks[0].srcdata[dgl.NID]) ), print(blocks[0].srcdata[dgl.NID])
def test_to_dgl_node_classification_without_feature_csc_format(): def test_dgl_node_classification_without_feature_csc_format():
# Arrange # Arrange
minibatch = create_homo_minibatch_csc_format() minibatch = create_homo_minibatch_csc_format()
minibatch.node_features = None minibatch.node_features = None
minibatch.labels = None minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seed_nodes = torch.tensor([10, 15])
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert dgl_minibatch.node_features is None assert minibatch.node_features is None
assert minibatch.edge_features is dgl_minibatch.edge_features assert minibatch.labels is None
assert dgl_minibatch.labels is None check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks)
assert minibatch.input_nodes is dgl_minibatch.input_nodes
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
check_dgl_blocks_homo_csc_format(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_homo_csc_format(): def test_dgl_node_classification_homo_csc_format():
# Arrange # Arrange
minibatch = create_homo_minibatch_csc_format() minibatch = create_homo_minibatch_csc_format()
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5]) minibatch.labels = torch.tensor([2, 5])
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_homo_csc_format(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_hetero_csc_format(): def test_dgl_node_classification_hetero_csc_format():
minibatch = create_hetero_minibatch_csc_format() minibatch = create_hetero_minibatch_csc_format()
minibatch.labels = {"B": torch.tensor([2, 5])} minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])} minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
dgl_minibatch = minibatch.to_dgl() # Act
dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_hetero_csc_format(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_hetero_csc_format(minibatch, dgl_minibatch.blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_homo_csc_format(mode): def test_dgl_link_predication_homo_csc_format(mode):
# Arrange # Arrange
minibatch = create_homo_minibatch_csc_format() minibatch = create_homo_minibatch_csc_format()
minibatch.compacted_node_pairs = ( minibatch.compacted_node_pairs = (
...@@ -991,28 +1027,43 @@ def test_to_dgl_link_predication_homo_csc_format(mode): ...@@ -991,28 +1027,43 @@ def test_to_dgl_link_predication_homo_csc_format(mode):
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]]) minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
check_dgl_blocks_homo_csc_format(minibatch, dgl_minibatch.blocks)
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[0], minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1), minibatch.compacted_negative_srcs.view(-1),
) )
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[1], minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1), minibatch.compacted_negative_dsts.view(-1),
) )
(
node_pairs,
labels,
) = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
else:
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_hetero_csc_format(mode): def test_dgl_link_predication_hetero_csc_format(mode):
# Arrange # Arrange
minibatch = create_hetero_minibatch_csc_format() minibatch = create_hetero_minibatch_csc_format()
minibatch.compacted_node_pairs = { minibatch.compacted_node_pairs = {
...@@ -1036,23 +1087,20 @@ def test_to_dgl_link_predication_hetero_csc_format(mode): ...@@ -1036,23 +1087,20 @@ def test_to_dgl_link_predication_hetero_csc_format(mode):
reverse_relation: torch.tensor([[2, 1], [3, 1]]), reverse_relation: torch.tensor([[2, 1], [3, 1]]),
} }
# Act # Act
dgl_minibatch = minibatch.to_dgl() dgl_blocks = minibatch.blocks
# Assert # Assert
assert len(dgl_minibatch.blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features check_dgl_blocks_hetero_csc_format(minibatch, dgl_blocks)
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
check_dgl_blocks_hetero_csc_format(minibatch, dgl_minibatch.blocks)
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items(): for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[etype][0], minibatch.negative_node_pairs[etype][0],
src.view(-1), src.view(-1),
) )
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items(): for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal( assert torch.equal(
dgl_minibatch.negative_node_pairs[etype][1], minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1), minibatch.compacted_negative_dsts[etype].view(-1),
) )
...@@ -2086,7 +2086,6 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id): ...@@ -2086,7 +2086,6 @@ def test_OnDiskDataset_homogeneous(include_original_edge_id):
datapipe = datapipe.fetch_feature( datapipe = datapipe.fetch_feature(
dataset.feature, node_feature_keys=["feat"] dataset.feature, node_feature_keys=["feat"]
) )
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(datapipe) dataloader = gb.DataLoader(datapipe)
for _ in dataloader: for _ in dataloader:
pass pass
...@@ -2158,7 +2157,6 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id): ...@@ -2158,7 +2157,6 @@ def test_OnDiskDataset_heterogeneous(include_original_edge_id):
datapipe = datapipe.fetch_feature( datapipe = datapipe.fetch_feature(
dataset.feature, node_feature_keys={"user": ["feat"]} dataset.feature, node_feature_keys={"user": ["feat"]}
) )
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader(datapipe) dataloader = gb.DataLoader(datapipe)
for _ in dataloader: for _ in dataloader:
pass pass
......
...@@ -67,9 +67,6 @@ def test_CopyToWithMiniBatches(): ...@@ -67,9 +67,6 @@ def test_CopyToWithMiniBatches():
# Invoke CopyTo via functional form. # Invoke CopyTo via functional form.
test_data_device(datapipe.copy_to("cuda")) test_data_device(datapipe.copy_to("cuda"))
# Test for DGLMiniBatch.
datapipe = gb.DGLMiniBatchConverter(datapipe)
# Invoke CopyTo via class constructor. # Invoke CopyTo via class constructor.
test_data_device(gb.CopyTo(datapipe, "cuda")) test_data_device(gb.CopyTo(datapipe, "cuda"))
......
...@@ -14,10 +14,7 @@ class MiniBatchType(Enum): ...@@ -14,10 +14,7 @@ class MiniBatchType(Enum):
DGLMiniBatch = 2 DGLMiniBatch = 2
@pytest.mark.parametrize( def test_FeatureFetcher_invoke():
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_invoke(minibatch_type):
# Prepare graph and required datapipes. # Prepare graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor( a = torch.tensor(
...@@ -40,8 +37,6 @@ def test_FeatureFetcher_invoke(minibatch_type): ...@@ -40,8 +37,6 @@ def test_FeatureFetcher_invoke(minibatch_type):
# Invoke FeatureFetcher via class constructor. # Invoke FeatureFetcher via class constructor.
datapipe = gb.NeighborSampler(item_sampler, graph, fanouts) datapipe = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
datapipe = datapipe.to_dgl()
datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"]) datapipe = gb.FeatureFetcher(datapipe, feature_store, ["a"], ["b"])
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
...@@ -53,10 +48,7 @@ def test_FeatureFetcher_invoke(minibatch_type): ...@@ -53,10 +48,7 @@ def test_FeatureFetcher_invoke(minibatch_type):
assert len(list(datapipe)) == 5 assert len(list(datapipe)) == 5
@pytest.mark.parametrize( def test_FeatureFetcher_homo():
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_homo(minibatch_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor( a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)] [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
...@@ -76,17 +68,12 @@ def test_FeatureFetcher_homo(minibatch_type): ...@@ -76,17 +68,12 @@ def test_FeatureFetcher_homo(minibatch_type):
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts) sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
sampler_dp = sampler_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"]) fetcher_dp = gb.FeatureFetcher(sampler_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5 assert len(list(fetcher_dp)) == 5
@pytest.mark.parametrize( def test_FeatureFetcher_with_edges_homo():
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_homo(minibatch_type):
graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True) graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True)
a = torch.tensor( a = torch.tensor(
[[random.randint(0, 10)] for _ in range(graph.total_num_nodes)] [[random.randint(0, 10)] for _ in range(graph.total_num_nodes)]
...@@ -121,8 +108,6 @@ def test_FeatureFetcher_with_edges_homo(minibatch_type): ...@@ -121,8 +108,6 @@ def test_FeatureFetcher_with_edges_homo(minibatch_type):
itemset = gb.ItemSet(torch.arange(10)) itemset = gb.ItemSet(torch.arange(10))
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids) converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
if minibatch_type == MiniBatchType.DGLMiniBatch:
converter_dp = converter_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"]) fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
assert len(list(fetcher_dp)) == 5 assert len(list(fetcher_dp)) == 5
...@@ -155,10 +140,7 @@ def get_hetero_graph(): ...@@ -155,10 +140,7 @@ def get_hetero_graph():
) )
@pytest.mark.parametrize( def test_FeatureFetcher_hetero():
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_hetero(minibatch_type):
graph = get_hetero_graph() graph = get_hetero_graph()
a = torch.tensor([[random.randint(0, 10)] for _ in range(2)]) a = torch.tensor([[random.randint(0, 10)] for _ in range(2)])
b = torch.tensor([[random.randint(0, 10)] for _ in range(3)]) b = torch.tensor([[random.randint(0, 10)] for _ in range(3)])
...@@ -179,8 +161,6 @@ def test_FeatureFetcher_hetero(minibatch_type): ...@@ -179,8 +161,6 @@ def test_FeatureFetcher_hetero(minibatch_type):
num_layer = 2 num_layer = 2
fanouts = [torch.LongTensor([2]) for _ in range(num_layer)] fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts) sampler_dp = gb.NeighborSampler(item_sampler, graph, fanouts)
if minibatch_type == MiniBatchType.DGLMiniBatch:
sampler_dp = sampler_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher( fetcher_dp = gb.FeatureFetcher(
sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]} sampler_dp, feature_store, {"n1": ["a"], "n2": ["a"]}
) )
...@@ -188,10 +168,7 @@ def test_FeatureFetcher_hetero(minibatch_type): ...@@ -188,10 +168,7 @@ def test_FeatureFetcher_hetero(minibatch_type):
assert len(list(fetcher_dp)) == 3 assert len(list(fetcher_dp)) == 3
@pytest.mark.parametrize( def test_FeatureFetcher_with_edges_hetero():
"minibatch_type", [MiniBatchType.MiniBatch, MiniBatchType.DGLMiniBatch]
)
def test_FeatureFetcher_with_edges_hetero(minibatch_type):
a = torch.tensor([[random.randint(0, 10)] for _ in range(20)]) a = torch.tensor([[random.randint(0, 10)] for _ in range(20)])
b = torch.tensor([[random.randint(0, 10)] for _ in range(50)]) b = torch.tensor([[random.randint(0, 10)] for _ in range(50)])
...@@ -243,8 +220,6 @@ def test_FeatureFetcher_with_edges_hetero(minibatch_type): ...@@ -243,8 +220,6 @@ def test_FeatureFetcher_with_edges_hetero(minibatch_type):
) )
item_sampler_dp = gb.ItemSampler(itemset, batch_size=2) item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids) converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
if minibatch_type == MiniBatchType.DGLMiniBatch:
converter_dp = converter_dp.to_dgl()
fetcher_dp = gb.FeatureFetcher( fetcher_dp = gb.FeatureFetcher(
converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]} converter_dp, feature_store, {"n1": ["a"]}, {"n1:e1:n2": ["a"]}
) )
......
...@@ -55,65 +55,152 @@ def test_integration_link_prediction(): ...@@ -55,65 +55,152 @@ def test_integration_link_prediction():
datapipe = datapipe.fetch_feature( datapipe = datapipe.fetch_feature(
feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"] feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
) )
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader( dataloader = gb.DataLoader(
datapipe, datapipe,
) )
expected = [ expected = [
str( str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 1]), """MiniBatch(seed_nodes=None,
tensor([2, 3, 3, 1])), sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
output_nodes=None, original_edge_ids=None,
node_features={'feat': tensor([[0.5160, 0.2486], original_column_node_ids=tensor([5, 3, 1, 2, 0, 4]),
[0.8672, 0.2276], node_pairs=(tensor([5, 4]), tensor([0, 5])),
[0.6172, 0.7865], ),
[0.2109, 0.1089], FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 0, 4]),
[0.9634, 0.2294], original_edge_ids=None,
[0.5503, 0.8223]])}, original_column_node_ids=tensor([5, 3, 1, 2, 0]),
negative_node_pairs=(tensor([0, 1, 1, 1]), node_pairs=(tensor([5]), tensor([0])),
tensor([4, 4, 1, 4])), )],
labels=None, positive_node_pairs=(tensor([0, 1, 1, 1]),
input_nodes=None, tensor([2, 3, 3, 1])),
edge_features=[{}, node_pairs_with_labels=((tensor([0, 1, 1, 1, 0, 1, 1, 1]), tensor([2, 3, 3, 1, 4, 4, 1, 4])),
{}], tensor([1., 1., 1., 1., 0., 0., 0., 0.])),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2), node_pairs=(tensor([5, 3, 3, 3]),
Block(num_src_nodes=6, num_dst_nodes=5, num_edges=1)], tensor([1, 2, 2, 3])),
)""" node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865],
[0.2109, 0.1089],
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=tensor([[5],
[3],
[3],
[3]]),
negative_node_pairs=(tensor([0, 1, 1, 1]),
tensor([4, 4, 1, 4])),
negative_dsts=tensor([[0],
[0],
[3],
[0]]),
labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_negative_srcs=tensor([[0],
[1],
[1],
[1]]),
compacted_negative_dsts=tensor([[4],
[4],
[1],
[4]]),
blocks=[Block(num_src_nodes=6, num_dst_nodes=6, num_edges=2),
Block(num_src_nodes=6, num_dst_nodes=5, num_edges=1)],
)"""
), ),
str( str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 2]), """MiniBatch(seed_nodes=None,
tensor([0, 0, 1, 1])), sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
output_nodes=None, original_edge_ids=None,
node_features={'feat': tensor([[0.8672, 0.2276], original_column_node_ids=tensor([3, 4, 0, 5, 1]),
[0.5503, 0.8223], node_pairs=(tensor([1, 3]), tensor([3, 4])),
[0.9634, 0.2294], ),
[0.5160, 0.2486], FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0, 5, 1]),
[0.6172, 0.7865]])}, original_edge_ids=None,
negative_node_pairs=(tensor([0, 1, 1, 2]), original_column_node_ids=tensor([3, 4, 0, 5, 1]),
tensor([1, 1, 3, 4])), node_pairs=(tensor([1, 3]), tensor([3, 4])),
labels=None, )],
input_nodes=None, positive_node_pairs=(tensor([0, 1, 1, 2]),
edge_features=[{}, tensor([0, 0, 1, 1])),
{}], node_pairs_with_labels=((tensor([0, 1, 1, 2, 0, 1, 1, 2]), tensor([0, 0, 1, 1, 1, 1, 3, 4])),
blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2), tensor([1., 1., 1., 1., 0., 0., 0., 0.])),
Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)], node_pairs=(tensor([3, 4, 4, 0]),
)""" tensor([3, 3, 4, 4])),
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294],
[0.5160, 0.2486],
[0.6172, 0.7865]])},
negative_srcs=tensor([[3],
[4],
[4],
[0]]),
negative_node_pairs=(tensor([0, 1, 1, 2]),
tensor([1, 1, 3, 4])),
negative_dsts=tensor([[4],
[4],
[5],
[1]]),
labels=None,
input_nodes=tensor([3, 4, 0, 5, 1]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_negative_srcs=tensor([[0],
[1],
[1],
[2]]),
compacted_negative_dsts=tensor([[1],
[1],
[3],
[4]]),
blocks=[Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2),
Block(num_src_nodes=5, num_dst_nodes=5, num_edges=2)],
)"""
), ),
str( str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1]), """MiniBatch(seed_nodes=None,
tensor([0, 0])), sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
output_nodes=None, original_edge_ids=None,
node_features={'feat': tensor([[0.5160, 0.2486], original_column_node_ids=tensor([5, 4]),
[0.5503, 0.8223]])}, node_pairs=(tensor([1]), tensor([1])),
negative_node_pairs=(tensor([0, 1]), ),
tensor([0, 0])), FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
labels=None, original_edge_ids=None,
input_nodes=None, original_column_node_ids=tensor([5, 4]),
edge_features=[{}, node_pairs=(tensor([1]), tensor([1])),
{}], )],
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1), positive_node_pairs=(tensor([0, 1]),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)], tensor([0, 0])),
)""" node_pairs_with_labels=((tensor([0, 1, 0, 1]), tensor([0, 0, 0, 0])),
tensor([1., 1., 0., 0.])),
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223]])},
negative_srcs=tensor([[5],
[4]]),
negative_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
negative_dsts=tensor([[5],
[5]]),
labels=None,
input_nodes=tensor([5, 4]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_negative_srcs=tensor([[0],
[1]]),
compacted_negative_dsts=tensor([[0],
[0]]),
blocks=[Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=1)],
)"""
), ),
] ]
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
...@@ -169,60 +256,116 @@ def test_integration_node_classification(): ...@@ -169,60 +256,116 @@ def test_integration_node_classification():
datapipe = datapipe.fetch_feature( datapipe = datapipe.fetch_feature(
feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"] feature_store, node_feature_keys=["feat"], edge_feature_keys=["feat"]
) )
datapipe = datapipe.to_dgl()
dataloader = gb.DataLoader( dataloader = gb.DataLoader(
datapipe, datapipe,
) )
expected = [ expected = [
str( str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 1]), """MiniBatch(seed_nodes=None,
tensor([2, 3, 3, 1])), sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2, 4]),
output_nodes=None, original_edge_ids=None,
node_features={'feat': tensor([[0.5160, 0.2486], original_column_node_ids=tensor([5, 3, 1, 2]),
[0.8672, 0.2276], node_pairs=(tensor([4, 1, 0, 1]), tensor([0, 1, 2, 3])),
[0.6172, 0.7865], ),
[0.2109, 0.1089], FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 3, 1, 2]),
[0.5503, 0.8223]])}, original_edge_ids=None,
negative_node_pairs=None, original_column_node_ids=tensor([5, 3, 1, 2]),
labels=None, node_pairs=(tensor([0, 1, 0, 1]), tensor([0, 1, 2, 3])),
input_nodes=None, )],
edge_features=[{}, positive_node_pairs=(tensor([0, 1, 1, 1]),
{}], tensor([2, 3, 3, 1])),
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4), node_pairs_with_labels=None,
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)], node_pairs=(tensor([5, 3, 3, 3]),
)""" tensor([1, 2, 2, 3])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.8672, 0.2276],
[0.6172, 0.7865],
[0.2109, 0.1089],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 3, 1, 2, 4]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=5, num_dst_nodes=4, num_edges=4),
Block(num_src_nodes=4, num_dst_nodes=4, num_edges=4)],
)"""
), ),
str( str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 1, 2]), """MiniBatch(seed_nodes=None,
tensor([0, 0, 1, 1])), sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
output_nodes=None, original_edge_ids=None,
node_features={'feat': tensor([[0.8672, 0.2276], original_column_node_ids=tensor([3, 4, 0]),
[0.5503, 0.8223], node_pairs=(tensor([0, 2]), tensor([0, 1])),
[0.9634, 0.2294]])}, ),
negative_node_pairs=None, FusedSampledSubgraphImpl(original_row_node_ids=tensor([3, 4, 0]),
labels=None, original_edge_ids=None,
input_nodes=None, original_column_node_ids=tensor([3, 4, 0]),
edge_features=[{}, node_pairs=(tensor([0, 2]), tensor([0, 1])),
{}], )],
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2), positive_node_pairs=(tensor([0, 1, 1, 2]),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)], tensor([0, 0, 1, 1])),
)""" node_pairs_with_labels=None,
node_pairs=(tensor([3, 4, 4, 0]),
tensor([3, 3, 4, 4])),
node_features={'feat': tensor([[0.8672, 0.2276],
[0.5503, 0.8223],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([3, 4, 0]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2),
Block(num_src_nodes=3, num_dst_nodes=3, num_edges=2)],
)"""
), ),
str( str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1]), """MiniBatch(seed_nodes=None,
tensor([0, 0])), sampled_subgraphs=[FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4, 0]),
output_nodes=None, original_edge_ids=None,
node_features={'feat': tensor([[0.5160, 0.2486], original_column_node_ids=tensor([5, 4]),
[0.5503, 0.8223], node_pairs=(tensor([0, 2]), tensor([0, 1])),
[0.9634, 0.2294]])}, ),
negative_node_pairs=None, FusedSampledSubgraphImpl(original_row_node_ids=tensor([5, 4]),
labels=None, original_edge_ids=None,
input_nodes=None, original_column_node_ids=tensor([5, 4]),
edge_features=[{}, node_pairs=(tensor([1, 1]), tensor([0, 1])),
{}], )],
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2), positive_node_pairs=(tensor([0, 1]),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)], tensor([0, 0])),
)""" node_pairs_with_labels=None,
node_pairs=(tensor([5, 4]),
tensor([5, 5])),
node_features={'feat': tensor([[0.5160, 0.2486],
[0.5503, 0.8223],
[0.9634, 0.2294]])},
negative_srcs=None,
negative_node_pairs=None,
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 4, 0]),
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_negative_srcs=None,
compacted_negative_dsts=None,
blocks=[Block(num_src_nodes=3, num_dst_nodes=2, num_edges=2),
Block(num_src_nodes=2, num_dst_nodes=2, num_edges=2)],
)"""
), ),
] ]
for step, data in enumerate(dataloader): for step, data in enumerate(dataloader):
......
import dgl.graphbolt as gb
import torch
from . import gb_test_utils
def test_dgl_minibatch_converter():
N = 32
B = 4
itemset = gb.ItemSet(torch.arange(N), names="seed_nodes")
graph = gb_test_utils.rand_csc_graph(200, 0.15, bidirection_edge=True)
features = {}
keys = [("node", None, "a"), ("node", None, "b")]
features[keys[0]] = gb.TorchBasedFeature(torch.randn(200, 4))
features[keys[1]] = gb.TorchBasedFeature(torch.randn(200, 4))
feature_store = gb.BasicFeatureStore(features)
item_sampler = gb.ItemSampler(itemset, batch_size=B)
subgraph_sampler = gb.NeighborSampler(
item_sampler,
graph,
fanouts=[torch.LongTensor([2]) for _ in range(2)],
)
feature_fetcher = gb.FeatureFetcher(
subgraph_sampler,
feature_store,
["a"],
)
dgl_converter = gb.DGLMiniBatchConverter(feature_fetcher)
dataloader = gb.DataLoader(dgl_converter)
assert len(list(dataloader)) == N // B
minibatch = next(iter(dataloader))
assert isinstance(minibatch, gb.DGLMiniBatch)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment