Unverified Commit cda8b381 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `compacted_seeds` to minibatch. (#7033)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent b085224f
......@@ -156,6 +156,15 @@ class MiniBatch:
all node ids inside are compacted.
"""
compacted_seeds: Union[
torch.Tensor,
Dict[str, torch.Tensor],
] = None
"""
Representation of compacted seeds corresponding to 'seeds', where
all node ids inside are compacted.
"""
compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_srcs', where
......
......@@ -72,6 +72,7 @@ def test_minibatch_representation_homo():
input_nodes=None,
indexes=None,
edge_features=None,
compacted_seeds=None,
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
......@@ -142,6 +143,7 @@ def test_minibatch_representation_homo():
indexes=None,
edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
{'x': tensor([0, 2, 2])}],
compacted_seeds=None,
compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([3, 4, 5]),
),
......@@ -300,6 +302,7 @@ def test_minibatch_representation_hetero():
indexes=None,
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}],
compacted_seeds=None,
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
......
......@@ -106,6 +106,7 @@ def test_integration_link_prediction():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_negative_srcs=None,
......@@ -164,6 +165,7 @@ def test_integration_link_prediction():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_negative_srcs=None,
......@@ -214,6 +216,7 @@ def test_integration_link_prediction():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_negative_srcs=None,
......@@ -316,6 +319,7 @@ def test_integration_node_classification():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 1]),
tensor([2, 3, 3, 1])),
compacted_negative_srcs=None,
......@@ -357,6 +361,7 @@ def test_integration_node_classification():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1, 1, 2]),
tensor([0, 0, 1, 1])),
compacted_negative_srcs=None,
......@@ -398,6 +403,7 @@ def test_integration_node_classification():
indexes=None,
edge_features=[{},
{}],
compacted_seeds=None,
compacted_node_pairs=(tensor([0, 1]),
tensor([0, 0])),
compacted_negative_srcs=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment