Unverified Commit cda8b381 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `compacted_seeds` to minibatch. (#7033)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent b085224f
...@@ -156,6 +156,15 @@ class MiniBatch: ...@@ -156,6 +156,15 @@ class MiniBatch:
all node ids inside are compacted. all node ids inside are compacted.
""" """
compacted_seeds: Union[
torch.Tensor,
Dict[str, torch.Tensor],
] = None
"""
Representation of compacted seeds corresponding to 'seeds', where
all node ids inside are compacted.
"""
compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of compacted nodes corresponding to 'negative_srcs', where Representation of compacted nodes corresponding to 'negative_srcs', where
......
...@@ -72,6 +72,7 @@ def test_minibatch_representation_homo(): ...@@ -72,6 +72,7 @@ def test_minibatch_representation_homo():
input_nodes=None, input_nodes=None,
indexes=None, indexes=None,
edge_features=None, edge_features=None,
compacted_seeds=None,
compacted_node_pairs=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
compacted_negative_dsts=None, compacted_negative_dsts=None,
...@@ -142,6 +143,7 @@ def test_minibatch_representation_homo(): ...@@ -142,6 +143,7 @@ def test_minibatch_representation_homo():
indexes=None, indexes=None,
edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])}, edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
{'x': tensor([0, 2, 2])}], {'x': tensor([0, 2, 2])}],
compacted_seeds=None,
compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]), compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]), indices=tensor([3, 4, 5]),
), ),
...@@ -300,6 +302,7 @@ def test_minibatch_representation_hetero(): ...@@ -300,6 +302,7 @@ def test_minibatch_representation_hetero():
indexes=None, indexes=None,
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])}, edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}], {('A:r:B', 'x'): tensor([0, 6])}],
compacted_seeds=None,
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]), indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]), ), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
......
...@@ -106,6 +106,7 @@ def test_integration_link_prediction(): ...@@ -106,6 +106,7 @@ def test_integration_link_prediction():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 1]), compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])), tensor([2, 3, 3, 1])),
compacted_negative_srcs=None, compacted_negative_srcs=None,
...@@ -164,6 +165,7 @@ def test_integration_link_prediction(): ...@@ -164,6 +165,7 @@ def test_integration_link_prediction():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 2]), compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])), tensor([0, 0, 1, 1])),
compacted_negative_srcs=None, compacted_negative_srcs=None,
...@@ -214,6 +216,7 @@ def test_integration_link_prediction(): ...@@ -214,6 +216,7 @@ def test_integration_link_prediction():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1]), compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])), tensor([0, 0])),
compacted_negative_srcs=None, compacted_negative_srcs=None,
...@@ -316,6 +319,7 @@ def test_integration_node_classification(): ...@@ -316,6 +319,7 @@ def test_integration_node_classification():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 1]), compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])), tensor([2, 3, 3, 1])),
compacted_negative_srcs=None, compacted_negative_srcs=None,
...@@ -357,6 +361,7 @@ def test_integration_node_classification(): ...@@ -357,6 +361,7 @@ def test_integration_node_classification():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 2]), compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])), tensor([0, 0, 1, 1])),
compacted_negative_srcs=None, compacted_negative_srcs=None,
...@@ -398,6 +403,7 @@ def test_integration_node_classification(): ...@@ -398,6 +403,7 @@ def test_integration_node_classification():
indexes=None, indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1]), compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])), tensor([0, 0])),
compacted_negative_srcs=None, compacted_negative_srcs=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment