Unverified Commit d2497448 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `indexes` to MiniBatch. (#6989)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d67dae1d
......@@ -80,6 +80,19 @@ class MiniBatch:
or other graph components depending on the specific context.
"""
indexes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Indexes associated with seed nodes / node pairs in the graph, which
indicates to which query a seed node / node pair belongs.
- If `indexes` is a tensor: It indicates the graph is homogeneous. The
value should be corresponding query to given 'seed_nodes' or
'node_pairs'.
- If `indexes` is a dictionary: It indicates the graph is
heterogeneous. The keys should be node or edge type and the value should
be corresponding query to given 'seed_nodes' or 'node_pairs'. For each
key, indexes are consecutive integers starting from zero.
"""
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
......
......@@ -70,6 +70,7 @@ def test_minibatch_representation_homo():
negative_dsts=None,
labels=None,
input_nodes=None,
indexes=None,
edge_features=None,
compacted_node_pairs=None,
compacted_negative_srcs=None,
......@@ -138,6 +139,7 @@ def test_minibatch_representation_homo():
[8]]),
labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
indexes=None,
edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
{'x': tensor([0, 2, 2])}],
compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
......@@ -295,6 +297,7 @@ def test_minibatch_representation_hetero():
[8]])},
labels={'B': tensor([2, 5])},
input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
indexes=None,
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}],
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
......
......@@ -103,6 +103,7 @@ def test_integration_link_prediction():
[3, 4]]),
labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]),
indexes=None,
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 1]),
......@@ -160,6 +161,7 @@ def test_integration_link_prediction():
[1, 5]]),
labels=None,
input_nodes=tensor([3, 4, 0, 1, 5, 2]),
indexes=None,
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 2]),
......@@ -209,6 +211,7 @@ def test_integration_link_prediction():
[0, 1]]),
labels=None,
input_nodes=tensor([5, 4, 0, 1]),
indexes=None,
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1]),
......@@ -310,6 +313,7 @@ def test_integration_node_classification():
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 3, 1, 2, 4]),
indexes=None,
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 1]),
......@@ -350,6 +354,7 @@ def test_integration_node_classification():
negative_dsts=None,
labels=None,
input_nodes=tensor([3, 4, 0]),
indexes=None,
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1, 1, 2]),
......@@ -390,6 +395,7 @@ def test_integration_node_classification():
negative_dsts=None,
labels=None,
input_nodes=tensor([5, 4, 0]),
indexes=None,
edge_features=[{},
{}],
compacted_node_pairs=(tensor([0, 1]),
......
......@@ -825,6 +825,7 @@ def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is None
assert minibatch.indexes is None
is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0:
expected_batch_size = batch_size
......@@ -877,6 +878,7 @@ def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is not None
assert minibatch.indexes is None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment