Unverified commit 67cd09d7, authored by Ramon Zhou and committed by GitHub
Browse files

[GraphBolt] Support heterogeneous graph in `node_pairs_with_labels` (#6787)

parent 541f2ba4
...@@ -381,20 +381,42 @@ class MiniBatch: ...@@ -381,20 +381,42 @@ class MiniBatch:
@property @property
def node_pairs_with_labels(self): def node_pairs_with_labels(self):
"""Get a node pair tensor and a label tensor from MiniBatch. They are """Get a node pair tensor and a label tensor from MiniBatch. They are
used for evaluating or computing loss. It will return used for evaluating or computing loss. For homogeneous graph, it will
`(node_pairs, labels)` as result. return `(node_pairs, labels)` as result; for heterogeneous graph, the
`node_pairs` and `labels` will both be a dict with etype as the key.
- If it's a link prediction task, `node_pairs` will contain both - If it's a link prediction task, `node_pairs` will contain both
negative and positive node pairs and `labels` will consist of 0 and 1, negative and positive node pairs and `labels` will consist of 0 and 1,
indicating whether the corresponding node pair is negative or positive. indicating whether the corresponding node pair is negative or positive.
- If it's an edge classification task, this function will directly - If it's an edge classification task, this function will directly
return `compacted_node_pairs` and corresponding `labels`. return `compacted_node_pairs` for each etype and the corresponding
`labels`.
- Otherwise it will return None. - Otherwise it will return None.
""" """
if self.labels is None: if self.labels is None:
# Link prediction.
positive_node_pairs = self.positive_node_pairs positive_node_pairs = self.positive_node_pairs
negative_node_pairs = self.negative_node_pairs negative_node_pairs = self.negative_node_pairs
if positive_node_pairs is None or negative_node_pairs is None: if positive_node_pairs is None or negative_node_pairs is None:
return None return None
if isinstance(positive_node_pairs, Dict):
# Heterogeneous graph.
node_pairs_by_etype = {}
labels_by_etype = {}
for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype]
node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
)
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
labels_by_etype[etype] = torch.cat(
[pos_label, neg_label], dim=0
)
return (node_pairs_by_etype, labels_by_etype)
else:
# Homogeneous graph.
pos_src, pos_dst = positive_node_pairs pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs neg_src, neg_dst = negative_node_pairs
node_pairs = ( node_pairs = (
...@@ -405,8 +427,11 @@ class MiniBatch: ...@@ -405,8 +427,11 @@ class MiniBatch:
neg_label = torch.zeros_like(neg_src) neg_label = torch.zeros_like(neg_src)
labels = torch.cat([pos_label, neg_label], dim=0) labels = torch.cat([pos_label, neg_label], dim=0)
return (node_pairs, labels.float()) return (node_pairs, labels.float())
else: elif self.compacted_node_pairs is not None:
# Edge classification.
return (self.compacted_node_pairs, self.labels) return (self.compacted_node_pairs, self.labels)
else:
return None
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `MiniBatch` to the specified device using reflection.""" """Copy `MiniBatch` to the specified device using reflection."""
......
...@@ -786,6 +786,48 @@ def test_dgl_link_predication_hetero(mode): ...@@ -786,6 +786,48 @@ def test_dgl_link_predication_hetero(mode):
minibatch.negative_node_pairs[etype][1], minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1), minibatch.compacted_negative_dsts[etype].view(-1),
) )
node_pairs, labels = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 2, 0, 1, 2]),
torch.tensor([1, 0, 1, 1, 0, 0]),
),
"B:rr:A": (
torch.tensor([0, 1, 1, 2, 0, 2]),
torch.tensor([1, 0, 1, 1, 0, 0]),
),
}
elif mode == "neg_dst":
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 1, 1, 1, 1]),
torch.tensor([1, 0, 1, 3, 2, 1]),
),
"B:rr:A": (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 2, 1, 3, 1]),
),
}
else:
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 2, 0, 1, 2]),
torch.tensor([1, 0, 1, 3, 2, 1]),
),
"B:rr:A": (
torch.tensor([0, 1, 1, 2, 0, 2]),
torch.tensor([1, 0, 2, 1, 3, 1]),
),
}
expect_labels = {
"A:r:B": torch.tensor([1, 1, 0, 0, 0, 0]),
"B:rr:A": torch.tensor([1, 1, 0, 0, 0, 0]),
}
for etype in node_pairs:
assert torch.equal(node_pairs[etype][0], expect_node_pairs[etype][0])
assert torch.equal(node_pairs[etype][1], expect_node_pairs[etype][1])
assert torch.equal(labels[etype], expect_labels[etype])
def create_homo_minibatch_csc_format(): def create_homo_minibatch_csc_format():
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment