Unverified commit d2497448, authored by yxy235 and committed by GitHub
Browse files

[GraphBolt] Add `indexes` to MiniBatch. (#6989)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d67dae1d
...@@ -80,6 +80,19 @@ class MiniBatch: ...@@ -80,6 +80,19 @@ class MiniBatch:
or other graph components depending on the specific context. or other graph components depending on the specific context.
""" """
indexes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Indexes associated with seed nodes / node pairs in the graph, which
indicate the query to which each seed node / node pair belongs.
- If `indexes` is a tensor: It indicates the graph is homogeneous. Each
value should be the query corresponding to the given 'seed_nodes' or
'node_pairs'.
- If `indexes` is a dictionary: It indicates the graph is
heterogeneous. The keys should be node or edge types and each value should
be the query corresponding to the given 'seed_nodes' or 'node_pairs'. For
each key, indexes are consecutive integers starting from zero.
"""
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of negative samples for the head nodes in the link Representation of negative samples for the head nodes in the link
......
...@@ -70,6 +70,7 @@ def test_minibatch_representation_homo(): ...@@ -70,6 +70,7 @@ def test_minibatch_representation_homo():
negative_dsts=None, negative_dsts=None,
labels=None, labels=None,
input_nodes=None, input_nodes=None,
indexes=None,
edge_features=None, edge_features=None,
compacted_node_pairs=None, compacted_node_pairs=None,
compacted_negative_srcs=None, compacted_negative_srcs=None,
...@@ -138,6 +139,7 @@ def test_minibatch_representation_homo(): ...@@ -138,6 +139,7 @@ def test_minibatch_representation_homo():
[8]]), [8]]),
labels=tensor([0., 1., 2.]), labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]), input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
indexes=None,
edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])}, edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
{'x': tensor([0, 2, 2])}], {'x': tensor([0, 2, 2])}],
compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]), compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
...@@ -295,6 +297,7 @@ def test_minibatch_representation_hetero(): ...@@ -295,6 +297,7 @@ def test_minibatch_representation_hetero():
[8]])}, [8]])},
labels={'B': tensor([2, 5])}, labels={'B': tensor([2, 5])},
input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])}, input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
indexes=None,
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])}, edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}], {('A:r:B', 'x'): tensor([0, 6])}],
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]), compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
......
...@@ -103,6 +103,7 @@ def test_integration_link_prediction(): ...@@ -103,6 +103,7 @@ def test_integration_link_prediction():
[3, 4]]), [3, 4]]),
labels=None, labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]), input_nodes=tensor([5, 3, 1, 2, 0, 4]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_node_pairs=(tensor([0, 1, 1, 1]), compacted_node_pairs=(tensor([0, 1, 1, 1]),
...@@ -160,6 +161,7 @@ def test_integration_link_prediction(): ...@@ -160,6 +161,7 @@ def test_integration_link_prediction():
[1, 5]]), [1, 5]]),
labels=None, labels=None,
input_nodes=tensor([3, 4, 0, 1, 5, 2]), input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_node_pairs=(tensor([0, 1, 1, 2]), compacted_node_pairs=(tensor([0, 1, 1, 2]),
...@@ -209,6 +211,7 @@ def test_integration_link_prediction(): ...@@ -209,6 +211,7 @@ def test_integration_link_prediction():
[0, 1]]), [0, 1]]),
labels=None, labels=None,
input_nodes=tensor([5, 4, 0, 1]), input_nodes=tensor([5, 4, 0, 1]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_node_pairs=(tensor([0, 1]), compacted_node_pairs=(tensor([0, 1]),
...@@ -310,6 +313,7 @@ def test_integration_node_classification(): ...@@ -310,6 +313,7 @@ def test_integration_node_classification():
negative_dsts=None, negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([5, 3, 1, 2, 4]), input_nodes=tensor([5, 3, 1, 2, 4]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_node_pairs=(tensor([0, 1, 1, 1]), compacted_node_pairs=(tensor([0, 1, 1, 1]),
...@@ -350,6 +354,7 @@ def test_integration_node_classification(): ...@@ -350,6 +354,7 @@ def test_integration_node_classification():
negative_dsts=None, negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([3, 4, 0]), input_nodes=tensor([3, 4, 0]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_node_pairs=(tensor([0, 1, 1, 2]), compacted_node_pairs=(tensor([0, 1, 1, 2]),
...@@ -390,6 +395,7 @@ def test_integration_node_classification(): ...@@ -390,6 +395,7 @@ def test_integration_node_classification():
negative_dsts=None, negative_dsts=None,
labels=None, labels=None,
input_nodes=tensor([5, 4, 0]), input_nodes=tensor([5, 4, 0]),
indexes=None,
edge_features=[{}, edge_features=[{},
{}], {}],
compacted_node_pairs=(tensor([0, 1]), compacted_node_pairs=(tensor([0, 1]),
......
...@@ -825,6 +825,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last): ...@@ -825,6 +825,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None assert minibatch.seeds is not None
assert minibatch.labels is None assert minibatch.labels is None
assert minibatch.indexes is None
is_last = (i + 1) * batch_size >= total_pairs is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0: if not is_last or total_pairs % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
...@@ -877,6 +878,7 @@ def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last): ...@@ -877,6 +878,7 @@ def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch) assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None assert minibatch.seeds is not None
assert minibatch.labels is not None assert minibatch.labels is not None
assert minibatch.indexes is None
is_last = (i + 1) * batch_size >= total_ids is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0: if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size expected_batch_size = batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment