"src/vscode:/vscode.git/clone" did not exist on "673eb60f1c4d971e1a577bed767053e50578b461"
Unverified Commit 173257b3 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Add `seeds` to MiniBatch. (#6968)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent c864c910
......@@ -54,6 +54,32 @@ class MiniBatch:
value should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
"""
seeds: Union[
torch.Tensor,
Dict[str, torch.Tensor],
] = None
"""
Representation of seed items utilized in node classification tasks, link
prediction tasks and hyperlinks tasks.
- If `seeds` is a tensor: it indicates that the seeds originate from a
homogeneous graph. It can be either a 1-dimensional or 2-dimensional
tensor:
- 1-dimensional tensor: Each element directly represents a seed node
within the graph.
- 2-dimensional tensor: Each row designates a seed item, which can
encompass various entities such as edges, hyperlinks, or other graph
components depending on the specific context.
- If `seeds` is a dictionary: it indicates that the seeds originate from a
heterogeneous graph. The keys should be edge or node type, and the value
should be a tensor, which can be either a 1-dimensional or 2-dimensional
tensor:
- 1-dimensional tensor: Each element directly represents a seed node
of the given type within the graph.
- 2-dimensional tensor: Each row designates a seed item of the given
type, which can encompass various entities such as edges, hyperlinks,
or other graph components depending on the specific context.
"""
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
......
......@@ -58,7 +58,8 @@ def test_minibatch_representation_homo():
# Test minibatch without data.
minibatch = gb.MiniBatch()
expect_result = str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=None,
positive_node_pairs=None,
node_pairs_with_labels=None,
......@@ -77,7 +78,7 @@ def test_minibatch_representation_homo():
)"""
)
result = str(minibatch)
assert result == expect_result, print(len(expect_result), len(result))
assert result == expect_result, print(expect_result, result)
# Test minibatch with all attributes.
minibatch = gb.MiniBatch(
node_pairs=csc_formats,
......@@ -93,7 +94,8 @@ def test_minibatch_representation_homo():
compacted_negative_dsts=compacted_negative_dsts,
)
expect_result = str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
),
......@@ -242,7 +244,8 @@ def test_minibatch_representation_hetero():
compacted_negative_dsts=compacted_negative_dsts,
)
expect_result = str(
"""MiniBatch(seed_nodes={'B': tensor([10, 15])},
"""MiniBatch(seeds=None,
seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
......
......@@ -60,7 +60,8 @@ def test_integration_link_prediction():
)
expected = [
str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 1, 1, 1, 1, 2]),
indices=tensor([0, 4]),
),
......@@ -116,7 +117,8 @@ def test_integration_link_prediction():
)"""
),
str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 0, 0, 1, 2, 3]),
indices=tensor([4, 1, 0]),
),
......@@ -172,7 +174,8 @@ def test_integration_link_prediction():
)"""
),
str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 0, 1, 1, 2]),
indices=tensor([1, 0]),
),
......@@ -276,7 +279,8 @@ def test_integration_node_classification():
)
expected = [
str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 3, 4]),
indices=tensor([4, 1, 0, 1]),
),
......@@ -317,7 +321,8 @@ def test_integration_node_classification():
)"""
),
str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2, 2]),
indices=tensor([0, 2]),
),
......@@ -356,7 +361,8 @@ def test_integration_node_classification():
)"""
),
str(
"""MiniBatch(seed_nodes=None,
"""MiniBatch(seeds=None,
seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(sampled_csc=CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([0, 2]),
),
......
......@@ -376,6 +376,86 @@ def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert torch.all(negs_ids[:-1, 1] <= negs_ids[1:, 1]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
item_set = gb.ItemSet(seeds, names="seeds")
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
seeds_ids = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert minibatch.seeds.shape == (expected_batch_size, 3)
# Verify seeds match.
assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])
assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])
# Archive batch.
seeds_ids.append(minibatch.seeds)
seeds_ids = torch.cat(seeds_ids)
assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
labels = seeds[:, 0]
item_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
seeds_ids = []
labels = []
for i, minibatch in enumerate(item_sampler):
assert minibatch.seeds is not None
assert isinstance(minibatch.seeds, torch.Tensor)
assert minibatch.labels is not None
label = minibatch.labels
assert len(minibatch.seeds) == len(label)
is_last = (i + 1) * batch_size >= num_ids
if not is_last or num_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = num_ids % batch_size
else:
assert False
assert minibatch.seeds.shape == (expected_batch_size, 3)
assert len(label) == expected_batch_size
# Verify seeds and labels match.
assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])
assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])
# Archive batch.
seeds_ids.append(minibatch.seeds)
labels.append(label)
seeds_ids = torch.cat(seeds_ids)
labels = torch.cat(labels)
assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
assert torch.all(labels[:-1] <= labels[1:]) is not shuffle
def test_append_with_other_datapipes():
num_ids = 100
batch_size = 4
......@@ -723,6 +803,112 @@ def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
# Node pairs.
num_ids = 103
total_pairs = 2 * num_ids
seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
seeds_dict = {
"user:like:item": gb.ItemSet(seeds_like, names="seeds"),
"user:follow:user": gb.ItemSet(seeds_follow, names="seeds"),
}
item_set = gb.ItemSetDict(seeds_dict)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
seeds_ids = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is None
is_last = (i + 1) * batch_size >= total_pairs
if not is_last or total_pairs % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = total_pairs % batch_size
else:
assert False
seeds_lst = []
for _, (seeds) in minibatch.seeds.items():
assert isinstance(seeds, torch.Tensor)
seeds_lst.append(seeds)
seeds_lst = torch.cat(seeds_lst)
assert seeds_lst.shape == (expected_batch_size, 3)
seeds_ids.append(seeds_lst)
assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
seeds_ids = torch.cat(seeds_ids)
assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last):
# Node pairs and labels
num_ids = 103
total_ids = 2 * num_ids
seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
seeds_dict = {
"user:like:item": gb.ItemSet(
(seeds_like, seeds_like[:, 0]),
names=("seeds", "labels"),
),
"user:follow:user": gb.ItemSet(
(seeds_follow, seeds_follow[:, 0]),
names=("seeds", "labels"),
),
}
item_set = gb.ItemSetDict(seeds_dict)
item_sampler = gb.ItemSampler(
item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
)
seeds_ids = []
labels = []
for i, minibatch in enumerate(item_sampler):
assert isinstance(minibatch, gb.MiniBatch)
assert minibatch.seeds is not None
assert minibatch.labels is not None
is_last = (i + 1) * batch_size >= total_ids
if not is_last or total_ids % batch_size == 0:
expected_batch_size = batch_size
else:
if not drop_last:
expected_batch_size = total_ids % batch_size
else:
assert False
seeds_lst = []
label = []
for _, seeds in minibatch.seeds.items():
assert isinstance(seeds, torch.Tensor)
seeds_lst.append(seeds)
for _, v_label in minibatch.labels.items():
label.append(v_label)
seeds_lst = torch.cat(seeds_lst)
label = torch.cat(label)
assert seeds_lst.shape == (expected_batch_size, 3)
assert len(label) == expected_batch_size
seeds_ids.append(seeds_lst)
labels.append(label)
assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
assert torch.equal(seeds_lst[:, 0], label)
seeds_ids = torch.cat(seeds_ids)
labels = torch.cat(labels)
assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
assert torch.all(labels[:-1] <= labels[1:]) is not shuffle
def distributed_item_sampler_subprocess(
proc_id,
nprocs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment