Unverified Commit 9a8aa8fa authored by peizhou001's avatar peizhou001 Committed by GitHub
Browse files

[Grapbolt]Negative node pairs should be 2D (#6951)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-21-218.ap-northeast-1.compute.internal>
parent 90e57e74
...@@ -299,15 +299,15 @@ class MiniBatch: ...@@ -299,15 +299,15 @@ class MiniBatch:
# For homogeneous graph. # For homogeneous graph.
if isinstance(self.compacted_negative_srcs, torch.Tensor): if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_node_pairs = ( negative_node_pairs = (
self.compacted_negative_srcs.view(-1), self.compacted_negative_srcs,
self.compacted_negative_dsts.view(-1), self.compacted_negative_dsts,
) )
# For heterogeneous graph. # For heterogeneous graph.
else: else:
negative_node_pairs = { negative_node_pairs = {
etype: ( etype: (
neg_src.view(-1), neg_src,
self.compacted_negative_dsts[etype].view(-1), self.compacted_negative_dsts[etype],
) )
for etype, neg_src in self.compacted_negative_srcs.items() for etype, neg_src in self.compacted_negative_srcs.items()
} }
...@@ -319,10 +319,10 @@ class MiniBatch: ...@@ -319,10 +319,10 @@ class MiniBatch:
if isinstance(self.compacted_negative_srcs, torch.Tensor): if isinstance(self.compacted_negative_srcs, torch.Tensor):
negative_ratio = self.compacted_negative_srcs.size(1) negative_ratio = self.compacted_negative_srcs.size(1)
negative_node_pairs = ( negative_node_pairs = (
self.compacted_negative_srcs.view(-1), self.compacted_negative_srcs,
self.compacted_node_pairs[1].repeat_interleave( self.compacted_node_pairs[1]
negative_ratio .repeat_interleave(negative_ratio)
), .view(-1, negative_ratio),
) )
# For heterogeneous graph. # For heterogeneous graph.
else: else:
...@@ -331,10 +331,10 @@ class MiniBatch: ...@@ -331,10 +331,10 @@ class MiniBatch:
].size(1) ].size(1)
negative_node_pairs = { negative_node_pairs = {
etype: ( etype: (
neg_src.view(-1), neg_src,
self.compacted_node_pairs[etype][1].repeat_interleave( self.compacted_node_pairs[etype][1]
negative_ratio .repeat_interleave(negative_ratio)
), .view(-1, negative_ratio),
) )
for etype, neg_src in self.compacted_negative_srcs.items() for etype, neg_src in self.compacted_negative_srcs.items()
} }
...@@ -346,10 +346,10 @@ class MiniBatch: ...@@ -346,10 +346,10 @@ class MiniBatch:
if isinstance(self.compacted_negative_dsts, torch.Tensor): if isinstance(self.compacted_negative_dsts, torch.Tensor):
negative_ratio = self.compacted_negative_dsts.size(1) negative_ratio = self.compacted_negative_dsts.size(1)
negative_node_pairs = ( negative_node_pairs = (
self.compacted_node_pairs[0].repeat_interleave( self.compacted_node_pairs[0]
negative_ratio .repeat_interleave(negative_ratio)
), .view(-1, negative_ratio),
self.compacted_negative_dsts.view(-1), self.compacted_negative_dsts,
) )
# For heterogeneous graph. # For heterogeneous graph.
else: else:
...@@ -358,10 +358,10 @@ class MiniBatch: ...@@ -358,10 +358,10 @@ class MiniBatch:
].size(1) ].size(1)
negative_node_pairs = { negative_node_pairs = {
etype: ( etype: (
self.compacted_node_pairs[etype][0].repeat_interleave( self.compacted_node_pairs[etype][0]
negative_ratio .repeat_interleave(negative_ratio)
), .view(-1, negative_ratio),
neg_dst.view(-1), neg_dst,
) )
for etype, neg_dst in self.compacted_negative_dsts.items() for etype, neg_dst in self.compacted_negative_dsts.items()
} }
...@@ -396,6 +396,7 @@ class MiniBatch: ...@@ -396,6 +396,7 @@ class MiniBatch:
for etype in positive_node_pairs: for etype in positive_node_pairs:
pos_src, pos_dst = positive_node_pairs[etype] pos_src, pos_dst = positive_node_pairs[etype]
neg_src, neg_dst = negative_node_pairs[etype] neg_src, neg_dst = negative_node_pairs[etype]
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs_by_etype[etype] = ( node_pairs_by_etype[etype] = (
torch.cat((pos_src, neg_src), dim=0), torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0), torch.cat((pos_dst, neg_dst), dim=0),
...@@ -410,6 +411,7 @@ class MiniBatch: ...@@ -410,6 +411,7 @@ class MiniBatch:
# Homogeneous graph. # Homogeneous graph.
pos_src, pos_dst = positive_node_pairs pos_src, pos_dst = positive_node_pairs
neg_src, neg_dst = negative_node_pairs neg_src, neg_dst = negative_node_pairs
neg_src, neg_dst = neg_src.view(-1), neg_dst.view(-1)
node_pairs = ( node_pairs = (
torch.cat((pos_src, neg_src), dim=0), torch.cat((pos_src, neg_src), dim=0),
torch.cat((pos_dst, neg_dst), dim=0), torch.cat((pos_dst, neg_dst), dim=0),
......
...@@ -130,10 +130,16 @@ class SubgraphSampler(MiniBatchTransformer): ...@@ -130,10 +130,16 @@ class SubgraphSampler(MiniBatchTransformer):
for etype, _ in neg_src.items(): for etype, _ in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype) src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_srcs[etype] = compacted[src_type].pop(0) compacted_negative_srcs[etype] = compacted[src_type].pop(0)
compacted_negative_srcs[etype] = compacted_negative_srcs[
etype
].view(neg_src[etype].shape)
if has_neg_dst: if has_neg_dst:
for etype, _ in neg_dst.items(): for etype, _ in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype) _, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_dsts[etype] = compacted[dst_type].pop(0) compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
compacted_negative_dsts[etype] = compacted_negative_dsts[
etype
].view(neg_dst[etype].shape)
else: else:
# Collect nodes from all types of input. # Collect nodes from all types of input.
nodes = list(node_pairs) nodes = list(node_pairs)
......
...@@ -125,8 +125,12 @@ def test_minibatch_representation_homo(): ...@@ -125,8 +125,12 @@ def test_minibatch_representation_homo():
negative_srcs=tensor([[8], negative_srcs=tensor([[8],
[1], [1],
[6]]), [6]]),
negative_node_pairs=(tensor([0, 1, 2]), negative_node_pairs=(tensor([[0],
tensor([6, 0, 0])), [1],
[2]]),
tensor([[6],
[0],
[0]])),
negative_dsts=tensor([[2], negative_dsts=tensor([[2],
[8], [8],
[8]]), [8]]),
...@@ -278,7 +282,11 @@ def test_minibatch_representation_hetero(): ...@@ -278,7 +282,11 @@ def test_minibatch_representation_hetero():
negative_srcs={'B': tensor([[8], negative_srcs={'B': tensor([[8],
[1], [1],
[6]])}, [6]])},
negative_node_pairs={'A:r:B': (tensor([0, 1, 2]), tensor([6, 0, 0]))}, negative_node_pairs={'A:r:B': (tensor([[0],
[1],
[2]]), tensor([[6],
[0],
[0]]))},
negative_dsts={'B': tensor([[2], negative_dsts={'B': tensor([[2],
[8], [8],
[8]])}, [8]])},
...@@ -773,12 +781,12 @@ def test_dgl_link_predication_homo(mode): ...@@ -773,12 +781,12 @@ def test_dgl_link_predication_homo(mode):
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
assert torch.equal( assert torch.equal(
minibatch.negative_node_pairs[0], minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1), minibatch.compacted_negative_srcs,
) )
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal( assert torch.equal(
minibatch.negative_node_pairs[1], minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1), minibatch.compacted_negative_dsts,
) )
( (
node_pairs, node_pairs,
...@@ -834,11 +842,11 @@ def test_dgl_link_predication_hetero(mode): ...@@ -834,11 +842,11 @@ def test_dgl_link_predication_hetero(mode):
for etype, src in minibatch.compacted_negative_srcs.items(): for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal( assert torch.equal(
minibatch.negative_node_pairs[etype][0], minibatch.negative_node_pairs[etype][0],
src.view(-1), src,
) )
if mode == "neg_graph" or mode == "neg_dst": if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items(): for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal( assert torch.equal(
minibatch.negative_node_pairs[etype][1], minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1), minibatch.compacted_negative_dsts[etype],
) )
...@@ -88,8 +88,14 @@ def test_integration_link_prediction(): ...@@ -88,8 +88,14 @@ def test_integration_link_prediction():
[0.9634, 0.2294], [0.9634, 0.2294],
[0.5503, 0.8223]])}, [0.5503, 0.8223]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 1, 1]), negative_node_pairs=(tensor([[0, 0],
tensor([4, 4, 1, 4, 0, 1, 1, 5])), [1, 1],
[1, 1],
[1, 1]]),
tensor([[4, 4],
[1, 4],
[0, 1],
[1, 5]])),
negative_dsts=tensor([[0, 0], negative_dsts=tensor([[0, 0],
[3, 0], [3, 0],
[5, 3], [5, 3],
...@@ -138,8 +144,14 @@ def test_integration_link_prediction(): ...@@ -138,8 +144,14 @@ def test_integration_link_prediction():
[0.5160, 0.2486], [0.5160, 0.2486],
[0.2109, 0.1089]])}, [0.2109, 0.1089]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1, 1, 1, 2, 2]), negative_node_pairs=(tensor([[0, 0],
tensor([3, 4, 5, 4, 1, 0, 3, 4])), [1, 1],
[1, 1],
[2, 2]]),
tensor([[3, 4],
[5, 4],
[1, 0],
[3, 4]])),
negative_dsts=tensor([[1, 5], negative_dsts=tensor([[1, 5],
[2, 5], [2, 5],
[4, 3], [4, 3],
...@@ -186,8 +198,10 @@ def test_integration_link_prediction(): ...@@ -186,8 +198,10 @@ def test_integration_link_prediction():
[0.9634, 0.2294], [0.9634, 0.2294],
[0.6172, 0.7865]])}, [0.6172, 0.7865]])},
negative_srcs=None, negative_srcs=None,
negative_node_pairs=(tensor([0, 0, 1, 1]), negative_node_pairs=(tensor([[0, 0],
tensor([2, 1, 2, 3])), [1, 1]]),
tensor([[2, 1],
[2, 3]])),
negative_dsts=tensor([[0, 4], negative_dsts=tensor([[0, 4],
[0, 1]]), [0, 1]]),
labels=None, labels=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment