Unverified Commit 9a8aa8fa authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Grapbolt]Negative node pairs should be 2D (#6951)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-21-218.ap-northeast-1.compute.internal>
parent 90e57e74
......@@ -299,15 +299,15 @@ class MiniBatch:
# For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_negative_dsts.view(-1),
self.compacted_negative_srcs,
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_negative_dsts[etype].view(-1),
neg_src,
self.compacted_negative_dsts[etype],
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
......@@ -319,10 +319,10 @@ class MiniBatch:
if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1)
negative_node_pairs = (
self.compacted_negative_srcs.view(-1),
self.compacted_node_pairs[1].repeat_interleave(
negative_ratio
),
self.compacted_negative_srcs,
self.compacted_node_pairs[1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
# For heterogeneous graph.
else:
......@@ -331,10 +331,10 @@ class MiniBatch:
].size(1)
negative_node_pairs = {
etype: (
neg_src.view(-1),
self.compacted_node_pairs[etype][1].repeat_interleave(
negative_ratio
),
neg_src,
self.compacted_node_pairs[etype][1]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
)
for etype, neg_src in self.compacted_negative_srcs.items()
}
......@@ -346,10 +346,10 @@ class MiniBatch:
if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1)
negative_node_pairs = (
self.compacted_node_pairs[0].repeat_interleave(
negative_ratio
),
self.compacted_negative_dsts.view(-1),
self.compacted_node_pairs[0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
self.compacted_negative_dsts,
)
# For heterogeneous graph.
else:
......@@ -358,10 +358,10 @@ class MiniBatch:
].size(1)
negative_node_pairs = {
etype: (
self.compacted_node_pairs[etype][0].repeat_interleave(
negative_ratio
),
neg_dst.view(-1),
self.compacted_node_pairs[etype][0]
.repeat_interleave(negative_ratio)
.view(-1, negative_ratio),
neg_dst,
)
for etype, neg_dst in self.compacted_negative_dsts.items()
}
......@@ -396,6 +396,7 @@ class MiniBatch:
for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype]
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
......@@ -410,6 +411,7 @@ class MiniBatch:
# Homogeneous graph.
pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs = (
torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0),
......
......@@ -130,10 +130,16 @@ class SubgraphSampler(MiniBatchTransformer):
for etype, _ in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_srcs[etype] = compacted[src_type].pop(0)
compacted_negative_srcs[etype] = compacted_negative_srcs[
etype
].view(neg_src[etype].shape)
if has_neg_dst:
for etype, _ in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
compacted_negative_dsts[etype] = compacted_negative_dsts[
etype
].view(neg_dst[etype].shape)
else:
# Collect nodes from all types of input.
nodes = list(node_pairs)
......
......@@ -125,8 +125,12 @@ def test_minibatch_representation_homo():
negative_srcs=tensor([[8],
[1],
[6]]),
negative_node_pairs=(tensor([0, 1, 2]),
tensor([6, 0, 0])),
negative_node_pairs=(tensor([[0],
[1],
[2]]),
tensor([[6],
[0],
[0]])),
negative_dsts=tensor([[2],
[8],
[8]]),
......@@ -278,7 +282,11 @@ def test_minibatch_representation_hetero():
negative_srcs={'B': tensor([[8],
[1],
[6]])},
negative_node_pairs={'A:r:B': (tensor([0, 1, 2]), tensor([6, 0, 0]))},
negative_node_pairs={'A:r:B': (tensor([[0],
[1],
[2]]), tensor([[6],
[0],
[0]]))},
negative_dsts={'B': tensor([[2],
[8],
[8]])},
......@@ -773,12 +781,12 @@ def test_dgl_link_predication_homo(mode):
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1),
minibatch.compacted_negative_srcs,
)
if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal(
minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1),
minibatch.compacted_negative_dsts,
)
(
node_pairs,
......@@ -834,11 +842,11 @@ def test_dgl_link_predication_hetero(mode):
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][0],
src.view(-1),
src,
)
if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1),
minibatch.compacted_negative_dsts[etype],
)
......@@ -88,8 +88,14 @@ def test_integration_link_prediction():
[0.9634, 0.2294],
[0.5503, 0.8223]])},
negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 1, 1]),
tensor([4, 4, 1, 4, 0, 1, 1, 5])),
negative_node_pairs=(tensor([[0, 0],
[1, 1],
[1, 1],
[1, 1]]),
tensor([[4, 4],
[1, 4],
[0, 1],
[1, 5]])),
negative_dsts=tensor([[0, 0],
[3, 0],
[5, 3],
......@@ -138,8 +144,14 @@ def test_integration_link_prediction():
[0.5160, 0.2486],
[0.2109, 0.1089]])},
negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 2, 2]),
tensor([3, 4, 5, 4, 1, 0, 3, 4])),
negative_node_pairs=(tensor([[0, 0],
[1, 1],
[1, 1],
[2, 2]]),
tensor([[3, 4],
[5, 4],
[1, 0],
[3, 4]])),
negative_dsts=tensor([[1, 5],
[2, 5],
[4, 3],
......@@ -186,8 +198,10 @@ def test_integration_link_prediction():
[0.9634, 0.2294],
[0.6172, 0.7865]])},
negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1]),
tensor([2, 1, 2, 3])),
negative_node_pairs=(tensor([[0, 0],
[1, 1]]),
tensor([[2, 1],
[2, 3]])),
negative_dsts=tensor([[0, 4],
[0, 1]]),
labels=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment