import os
import re
import unittest
from sys import platform

import backend as F

import dgl
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from dgl import graphbolt as gb


def test_ItemSampler_minibatcher():
    """ItemSampler falls back to the default minibatcher and warns when item
    names are missing or unrecognized; a customized minibatcher overrides it."""
    # Default minibatcher is used if not specified.
    # Warning message is raised if names are not specified.
    item_set = gb.ItemSet(torch.arange(0, 10))
    item_sampler = gb.ItemSampler(item_set, batch_size=4)
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "Failed to map item list to `MiniBatch` as the names of items are "
            "not provided. Please provide a customized `MiniBatcher`. The "
            "item list is returned as is."
        ),
    ):
        minibatch = next(iter(item_sampler))
    assert not isinstance(minibatch, gb.MiniBatch)

    # Default minibatcher is used if not specified.
    # Warning message is raised if unrecognized names are specified.
    item_set = gb.ItemSet(torch.arange(0, 10), names="unknown_name")
    item_sampler = gb.ItemSampler(item_set, batch_size=4)
    with pytest.warns(
        UserWarning,
        match=re.escape(
            "Unknown item name 'unknown_name' is detected and added into "
            "`MiniBatch`. You probably need to provide a customized "
            "`MiniBatcher`."
        ),
    ):
        minibatch = next(iter(item_sampler))
    assert isinstance(minibatch, gb.MiniBatch)
    assert minibatch.unknown_name is not None

    # Default minibatcher is used if not specified.
    # `MiniBatch` is returned if expected names are specified.
    item_set = gb.ItemSet(torch.arange(0, 10), names="seed_nodes")
    item_sampler = gb.ItemSampler(item_set, batch_size=4)
    minibatch = next(iter(item_sampler))
    assert isinstance(minibatch, gb.MiniBatch)
    assert minibatch.seeds is not None
    assert len(minibatch.seeds) == 4

    # Customized minibatcher is used if specified.
    def minibatcher(batch, names):
        return gb.MiniBatch(seeds=batch)

    item_sampler = gb.ItemSampler(
        item_set, batch_size=4, minibatcher=minibatcher
    )
    minibatch = next(iter(item_sampler))
    assert isinstance(minibatch, gb.MiniBatch)
    assert minibatch.seeds is not None
    assert len(minibatch.seeds) == 4


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_Iterable_Only(batch_size, shuffle, drop_last):
    """ItemSampler over an ItemSet wrapping an iterable without __len__."""
    num_ids = 103

    class InvalidLength:
        # Iterable-only: deliberately provides no __len__.
        def __iter__(self):
            return iter(torch.arange(0, num_ids))

    seed_nodes = gb.ItemSet(InvalidLength())
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                # The final partial batch is kept only when drop_last=False.
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                assert False
        minibatch_ids.append(minibatch.seeds)
    minibatch_ids = torch.cat(minibatch_ids)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_integer(batch_size, shuffle, drop_last):
    """ItemSet constructed from a bare integer yields IDs 0..num_ids-1."""
    # Node IDs.
    num_ids = 103
    item_set = gb.ItemSet(num_ids, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                assert False
        minibatch_ids.append(minibatch.seeds)
    minibatch_ids = torch.cat(minibatch_ids)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes(batch_size, shuffle, drop_last):
    """ItemSet of a node-ID tensor batches into MiniBatch.seeds."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                assert False
        minibatch_ids.append(minibatch.seeds)
    minibatch_ids = torch.cat(minibatch_ids)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seed_nodes_labels(batch_size, shuffle, drop_last):
    """ItemSet of (seed_nodes, labels) batches both fields consistently."""
    # Node IDs.
    num_ids = 103
    seed_nodes = torch.arange(0, num_ids)
    labels = torch.arange(0, num_ids)
    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        assert len(minibatch.seeds) == len(minibatch.labels)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            assert len(minibatch.seeds) == batch_size
        else:
            if not drop_last:
                assert len(minibatch.seeds) == num_ids % batch_size
            else:
                assert False
        minibatch_ids.append(minibatch.seeds)
        minibatch_labels.append(minibatch.labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
    assert (
        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle
    )
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_graphs(batch_size, shuffle, drop_last):
    """ItemSet of DGLGraphs batches via a customized minibatcher."""
    # Graphs.
    num_graphs = 103
    num_nodes = 10
    num_edges = 20
    graphs = [
        dgl.rand_graph(num_nodes * (i + 1), num_edges * (i + 1))
        for i in range(num_graphs)
    ]
    item_set = gb.ItemSet(graphs, names="graphs")
    # DGLGraph is not supported in gb.MiniBatch yet. Let's use a customized
    # minibatcher to return the original graphs.
    customized_minibatcher = lambda batch, names: batch
    item_sampler = gb.ItemSampler(
        item_set,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        minibatcher=customized_minibatcher,
    )
    minibatch_num_nodes = []
    minibatch_num_edges = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_graphs
        if not is_last or num_graphs % batch_size == 0:
            assert minibatch.batch_size == batch_size
        else:
            if not drop_last:
                assert minibatch.batch_size == num_graphs % batch_size
            else:
                assert False
        minibatch_num_nodes.append(minibatch.batch_num_nodes())
        minibatch_num_edges.append(minibatch.batch_num_edges())
    minibatch_num_nodes = torch.cat(minibatch_num_nodes)
    minibatch_num_edges = torch.cat(minibatch_num_edges)
    assert (
        torch.all(minibatch_num_nodes[:-1] <= minibatch_num_nodes[1:])
        is not shuffle
    )
    assert (
        torch.all(minibatch_num_edges[:-1] <= minibatch_num_edges[1:])
        is not shuffle
    )


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs(batch_size, shuffle, drop_last):
    """ItemSet of (src, dst) node pairs batches into a 2-column seeds tensor."""
    # Node pairs.
    num_ids = 103
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    item_set = gb.ItemSet(node_pairs, names="node_pairs")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is None
        src, dst = minibatch.seeds.T
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        # Verify src and dst IDs match.
        assert torch.equal(src + 1, dst)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_labels(batch_size, shuffle, drop_last):
    """ItemSet of (node_pairs, labels) batches pairs and labels together."""
    # Node pairs and labels
    num_ids = 103
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    labels = node_pairs[:, 0]
    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is not None
        src, dst = minibatch.seeds.T
        label = minibatch.labels
        assert len(src) == len(dst)
        assert len(src) == len(label)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        # Verify src/dst IDs and labels match.
        assert torch.equal(src + 1, dst)
        assert torch.equal(src, label)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        labels.append(label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    labels = torch.cat(labels)
    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
    """ItemSet of (node_pairs, negative_dsts): negatives are folded into
    seeds/labels/indexes, so each batch carries 3x rows (1 pos + 2 neg)."""
    # Node pairs and negative destinations.
    num_ids = 103
    num_negs = 2
    node_pairs = torch.arange(0, 2 * num_ids).reshape(-1, 2)
    neg_dsts = torch.arange(
        2 * num_ids, 2 * num_ids + num_ids * num_negs
    ).reshape(-1, num_negs)
    item_set = gb.ItemSet(
        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
    )
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    negs_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is not None
        assert minibatch.indexes is not None
        src, dst = minibatch.seeds.T
        # labels==1 marks positive pairs; the rest are negatives.
        negs_src = src[~minibatch.labels.to(bool)]
        negs_dst = dst[~minibatch.labels.to(bool)]
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert len(src) == expected_batch_size * 3
        assert len(dst) == expected_batch_size * 3
        assert negs_src.dim() == 1
        assert negs_dst.dim() == 1
        assert len(negs_src) == expected_batch_size * 2
        assert len(negs_dst) == expected_batch_size * 2
        expected_indexes = torch.arange(expected_batch_size)
        expected_indexes = torch.cat(
            (expected_indexes, expected_indexes.repeat_interleave(2))
        )
        assert torch.equal(minibatch.indexes, expected_indexes)
        # Verify node pairs and negative destinations.
        assert torch.equal(
            src[minibatch.labels.to(bool)] + 1, dst[minibatch.labels.to(bool)]
        )
        assert torch.equal((negs_dst - 2 * num_ids) // 2 * 2, negs_src)
        # Archive batch.
        src_ids.append(src)
        dst_ids.append(dst)
        negs_ids.append(negs_dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    negs_ids = torch.cat(negs_ids)
    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
    assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds(batch_size, shuffle, drop_last):
    """ItemSet of 3-column seeds batches into a (batch, 3) seeds tensor."""
    # Node pairs.
    num_ids = 103
    seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
    item_set = gb.ItemSet(seeds, names="seeds")
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert minibatch.seeds.shape == (expected_batch_size, 3)
        # Verify seeds match.
        assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])
        assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])
        # Archive batch.
        seeds_ids.append(minibatch.seeds)
    seeds_ids = torch.cat(seeds_ids)
    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSet_seeds_labels(batch_size, shuffle, drop_last):
    """ItemSet of (seeds, labels) batches both fields with matching sizes."""
    # Node pairs and labels
    num_ids = 103
    seeds = torch.arange(0, 3 * num_ids).reshape(-1, 3)
    labels = seeds[:, 0]
    item_set = gb.ItemSet((seeds, labels), names=("seeds", "labels"))
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert minibatch.seeds is not None
        assert isinstance(minibatch.seeds, torch.Tensor)
        assert minibatch.labels is not None
        label = minibatch.labels
        assert len(minibatch.seeds) == len(label)
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert minibatch.seeds.shape == (expected_batch_size, 3)
        assert len(label) == expected_batch_size
        # Verify seeds and labels match.
        assert torch.equal(minibatch.seeds[:, 0] + 1, minibatch.seeds[:, 1])
        assert torch.equal(minibatch.seeds[:, 1] + 1, minibatch.seeds[:, 2])
        # Archive batch.
        seeds_ids.append(minibatch.seeds)
        labels.append(label)
    seeds_ids = torch.cat(seeds_ids)
    labels = torch.cat(labels)
    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
    assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle
    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle


def test_append_with_other_datapipes():
    """ItemSampler composes with downstream torchdata datapipes."""
    num_ids = 100
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
    data_pipe = gb.ItemSampler(item_set, batch_size)
    # torchdata.datapipes.iter.Enumerator
    data_pipe = data_pipe.enumerate()
    for i, (idx, data) in enumerate(data_pipe):
        assert i == idx
        assert len(data.seeds) == batch_size


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_iterable_only(batch_size, shuffle, drop_last):
    """ItemSetDict over iterable-only (no __len__) per-type item sets."""

    class IterableOnly:
        def __init__(self, start, stop):
            self._start = start
            self._stop = stop

        def __iter__(self):
            return iter(torch.arange(self._start, self._stop))

    num_ids = 205
    ids = {
        "user": gb.ItemSet(IterableOnly(0, 99), names="seed_nodes"),
        "item": gb.ItemSet(IterableOnly(99, num_ids), names="seed_nodes"),
    }
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        ids = []
        for _, v in minibatch.seeds.items():
            ids.append(v)
        ids = torch.cat(ids)
        assert len(ids) == expected_batch_size
        minibatch_ids.append(ids)
    minibatch_ids = torch.cat(minibatch_ids)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes(batch_size, shuffle, drop_last):
    """ItemSetDict of per-node-type seed IDs batches into a seeds dict."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(torch.arange(0, 99), names="seed_nodes"),
        "item": gb.ItemSet(torch.arange(99, num_ids), names="seed_nodes"),
    }
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    for i, minibatch in enumerate(item_sampler):
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        ids = []
        for _, v in minibatch.seeds.items():
            ids.append(v)
        ids = torch.cat(ids)
        assert len(ids) == expected_batch_size
        minibatch_ids.append(ids)
    minibatch_ids = torch.cat(minibatch_ids)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seed_nodes_labels(batch_size, shuffle, drop_last):
    """ItemSetDict of per-type (seed_nodes, labels) batches both dicts."""
    # Node IDs.
    num_ids = 205
    ids = {
        "user": gb.ItemSet(
            (torch.arange(0, 99), torch.arange(0, 99)),
            names=("seed_nodes", "labels"),
        ),
        "item": gb.ItemSet(
            (torch.arange(99, num_ids), torch.arange(99, num_ids)),
            names=("seed_nodes", "labels"),
        ),
    }
    chained_ids = []
    for key, value in ids.items():
        chained_ids += [(key, v) for v in value]
    item_set = gb.ItemSetDict(ids)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    minibatch_ids = []
    minibatch_labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        is_last = (i + 1) * batch_size >= num_ids
        if not is_last or num_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = num_ids % batch_size
            else:
                assert False
        ids = []
        for _, v in minibatch.seeds.items():
            ids.append(v)
        ids = torch.cat(ids)
        assert len(ids) == expected_batch_size
        minibatch_ids.append(ids)
        labels = []
        for _, v in minibatch.labels.items():
            labels.append(v)
        labels = torch.cat(labels)
        assert len(labels) == expected_batch_size
        minibatch_labels.append(labels)
    minibatch_ids = torch.cat(minibatch_ids)
    minibatch_labels = torch.cat(minibatch_labels)
    assert torch.all(minibatch_ids[:-1] <= minibatch_ids[1:]) is not shuffle
    assert (
        torch.all(minibatch_labels[:-1] <= minibatch_labels[1:]) is not shuffle
    )
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs(batch_size, shuffle, drop_last):
    """ItemSetDict of per-edge-type node pairs batches into a seeds dict."""
    # Node pairs.
    num_ids = 103
    total_pairs = 2 * num_ids
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(node_pairs_like, names="node_pairs"),
        "user:follow:user": gb.ItemSet(node_pairs_follow, names="node_pairs"),
    }
    item_set = gb.ItemSetDict(node_pairs_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        is_last = (i + 1) * batch_size >= total_pairs
        if not is_last or total_pairs % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_pairs % batch_size
            else:
                assert False
        src = []
        dst = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            src.append(seeds[:, 0])
            dst.append(seeds[:, 1])
        src = torch.cat(src)
        dst = torch.cat(dst)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        assert torch.equal(src + 1, dst)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_labels(batch_size, shuffle, drop_last):
    """ItemSetDict of per-edge-type (node_pairs, labels) batches both dicts."""
    # Node pairs and labels
    num_ids = 103
    total_ids = 2 * num_ids
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    # NOTE: labels are derived from the source column of each pair set below;
    # the previous `labels = torch.arange(0, num_ids)` was dead code.
    node_pairs_dict = {
        "user:like:item": gb.ItemSet(
            (node_pairs_like, node_pairs_like[:, 0]),
            names=("node_pairs", "labels"),
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs_follow, node_pairs_follow[:, 0]),
            names=("node_pairs", "labels"),
        ),
    }
    item_set = gb.ItemSetDict(node_pairs_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    labels = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        assert minibatch.negative_dsts is None
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_ids % batch_size
            else:
                assert False
        src = []
        dst = []
        label = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            src.append(seeds[:, 0])
            dst.append(seeds[:, 1])
        for _, v_label in minibatch.labels.items():
            label.append(v_label)
        src = torch.cat(src)
        dst = torch.cat(dst)
        label = torch.cat(label)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(label) == expected_batch_size
        src_ids.append(src)
        dst_ids.append(dst)
        labels.append(label)
        assert torch.equal(src + 1, dst)
        assert torch.equal(src, label)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    labels = torch.cat(labels)
    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
    assert torch.all(labels[:-1] <= labels[1:]) is not shuffle
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_node_pairs_negative_dsts(batch_size, shuffle, drop_last):
    """ItemSetDict of per-edge-type (node_pairs, negative_dsts): negatives
    are folded into per-type seeds/labels, two negatives per positive."""
    # Head, tail and negative tails.
    num_ids = 103
    total_ids = 2 * num_ids
    num_negs = 2
    # Renamed from `node_paris_like` (typo).
    node_pairs_like = torch.arange(0, num_ids * 2).reshape(-1, 2)
    node_pairs_follow = torch.arange(num_ids * 2, num_ids * 4).reshape(-1, 2)
    neg_dsts_like = torch.arange(
        num_ids * 4, num_ids * 4 + num_ids * num_negs
    ).reshape(-1, num_negs)
    neg_dsts_follow = torch.arange(
        num_ids * 4 + num_ids * num_negs, num_ids * 4 + num_ids * num_negs * 2
    ).reshape(-1, num_negs)
    data_dict = {
        "user:like:item": gb.ItemSet(
            (node_pairs_like, neg_dsts_like),
            names=("node_pairs", "negative_dsts"),
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs_follow, neg_dsts_follow),
            names=("node_pairs", "negative_dsts"),
        ),
    }
    item_set = gb.ItemSetDict(data_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    src_ids = []
    dst_ids = []
    negs_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is not None
        assert minibatch.negative_dsts is None
        is_last = (i + 1) * batch_size >= total_ids
        if not is_last or total_ids % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_ids % batch_size
            else:
                assert False
        src = []
        dst = []
        negs_src = []
        negs_dst = []
        for etype, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            src_etype = seeds[:, 0]
            dst_etype = seeds[:, 1]
            # labels==1 marks positive pairs per edge type.
            src.append(src_etype[minibatch.labels[etype].to(bool)])
            dst.append(dst_etype[minibatch.labels[etype].to(bool)])
            negs_src.append(src_etype[~minibatch.labels[etype].to(bool)])
            negs_dst.append(dst_etype[~minibatch.labels[etype].to(bool)])
        src = torch.cat(src)
        dst = torch.cat(dst)
        negs_src = torch.cat(negs_src)
        negs_dst = torch.cat(negs_dst)
        assert len(src) == expected_batch_size
        assert len(dst) == expected_batch_size
        assert len(negs_src) == expected_batch_size * 2
        assert len(negs_dst) == expected_batch_size * 2
        src_ids.append(src)
        dst_ids.append(dst)
        negs_ids.append(negs_dst)
        assert negs_src.dim() == 1
        assert negs_dst.dim() == 1
        assert torch.equal(src + 1, dst)
        assert torch.equal(negs_src, (negs_dst - num_ids * 4) // 2 * 2)
    src_ids = torch.cat(src_ids)
    dst_ids = torch.cat(dst_ids)
    negs_ids = torch.cat(negs_ids)
    assert torch.all(src_ids[:-1] <= src_ids[1:]) is not shuffle
    assert torch.all(dst_ids[:-1] <= dst_ids[1:]) is not shuffle
    # BUGFIX: was `torch.all(negs_ids <= negs_ids)` — a tensor compared with
    # itself is always all-True, so the check was vacuous. Use the same
    # sortedness pattern as every other test in this file.
    assert torch.all(negs_ids[:-1] <= negs_ids[1:]) is not shuffle


@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("shuffle", [True, False])
@pytest.mark.parametrize("drop_last", [True, False])
def test_ItemSetDict_seeds(batch_size, shuffle, drop_last):
    """ItemSetDict of per-edge-type 3-column seeds batches into a dict."""
    # Node pairs.
    num_ids = 103
    total_pairs = 2 * num_ids
    seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3)
    seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3)
    seeds_dict = {
        "user:like:item": gb.ItemSet(seeds_like, names="seeds"),
        "user:follow:user": gb.ItemSet(seeds_follow, names="seeds"),
    }
    item_set = gb.ItemSetDict(seeds_dict)
    item_sampler = gb.ItemSampler(
        item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last
    )
    seeds_ids = []
    for i, minibatch in enumerate(item_sampler):
        assert isinstance(minibatch, gb.MiniBatch)
        assert minibatch.seeds is not None
        assert minibatch.labels is None
        assert minibatch.indexes is None
        is_last = (i + 1) * batch_size >= total_pairs
        if not is_last or total_pairs % batch_size == 0:
            expected_batch_size = batch_size
        else:
            if not drop_last:
                expected_batch_size = total_pairs % batch_size
            else:
                assert False
        seeds_lst = []
        for _, seeds in minibatch.seeds.items():
            assert isinstance(seeds, torch.Tensor)
            seeds_lst.append(seeds)
        seeds_lst = torch.cat(seeds_lst)
        assert seeds_lst.shape == (expected_batch_size, 3)
        seeds_ids.append(seeds_lst)
        assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1])
        assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2])
    seeds_ids = torch.cat(seeds_ids)
    assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle
    assert torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle
seeds_ids[1:, 1]) is not shuffle assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle @pytest.mark.parametrize("batch_size", [1, 4]) @pytest.mark.parametrize("shuffle", [True, False]) @pytest.mark.parametrize("drop_last", [True, False]) def test_ItemSetDict_seeds_labels(batch_size, shuffle, drop_last): # Node pairs and labels num_ids = 103 total_ids = 2 * num_ids seeds_like = torch.arange(0, num_ids * 3).reshape(-1, 3) seeds_follow = torch.arange(num_ids * 3, num_ids * 6).reshape(-1, 3) seeds_dict = { "user:like:item": gb.ItemSet( (seeds_like, seeds_like[:, 0]), names=("seeds", "labels"), ), "user:follow:user": gb.ItemSet( (seeds_follow, seeds_follow[:, 0]), names=("seeds", "labels"), ), } item_set = gb.ItemSetDict(seeds_dict) item_sampler = gb.ItemSampler( item_set, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last ) seeds_ids = [] labels = [] for i, minibatch in enumerate(item_sampler): assert isinstance(minibatch, gb.MiniBatch) assert minibatch.seeds is not None assert minibatch.labels is not None assert minibatch.indexes is None is_last = (i + 1) * batch_size >= total_ids if not is_last or total_ids % batch_size == 0: expected_batch_size = batch_size else: if not drop_last: expected_batch_size = total_ids % batch_size else: assert False seeds_lst = [] label = [] for _, seeds in minibatch.seeds.items(): assert isinstance(seeds, torch.Tensor) seeds_lst.append(seeds) for _, v_label in minibatch.labels.items(): label.append(v_label) seeds_lst = torch.cat(seeds_lst) label = torch.cat(label) assert seeds_lst.shape == (expected_batch_size, 3) assert len(label) == expected_batch_size seeds_ids.append(seeds_lst) labels.append(label) assert torch.equal(seeds_lst[:, 0] + 1, seeds_lst[:, 1]) assert torch.equal(seeds_lst[:, 1] + 1, seeds_lst[:, 2]) assert torch.equal(seeds_lst[:, 0], label) seeds_ids = torch.cat(seeds_ids) labels = torch.cat(labels) assert torch.all(seeds_ids[:-1, 0] <= seeds_ids[1:, 0]) is not shuffle assert 
torch.all(seeds_ids[:-1, 1] <= seeds_ids[1:, 1]) is not shuffle assert torch.all(seeds_ids[:-1, 2] <= seeds_ids[1:, 2]) is not shuffle assert torch.all(labels[:-1] <= labels[1:]) is not shuffle def distributed_item_sampler_subprocess( proc_id, nprocs, item_set, num_ids, num_workers, batch_size, drop_last, drop_uneven_inputs, ): # On Windows, the init method can only be file. init_method = ( f"file:///{os.path.join(os.getcwd(), 'dis_tempfile')}" if platform == "win32" else "tcp://127.0.0.1:12345" ) dist.init_process_group( backend="gloo", # Use Gloo backend for CPU multiprocessing init_method=init_method, world_size=nprocs, rank=proc_id, ) # Create a DistributedItemSampler. item_sampler = gb.DistributedItemSampler( item_set, batch_size=batch_size, shuffle=True, drop_last=drop_last, drop_uneven_inputs=drop_uneven_inputs, ) feature_fetcher = gb.FeatureFetcher( item_sampler, gb.BasicFeatureStore({}), [], ) data_loader = gb.DataLoader(feature_fetcher, num_workers=num_workers) # Count the numbers of items and batches. num_items = 0 sampled_count = torch.zeros(num_ids, dtype=torch.int32) for i in data_loader: # Count how many times each item is sampled. sampled_count[i.seeds] += 1 if drop_last: assert i.seeds.size(0) == batch_size num_items += i.seeds.size(0) num_batches = len(list(item_sampler)) if drop_uneven_inputs: num_batches_tensor = torch.tensor(num_batches) dist.broadcast(num_batches_tensor, 0) # Test if the number of batches are the same for all processes. assert num_batches_tensor == num_batches # Add up results from all processes. dist.reduce(sampled_count, 0) try: # Make sure no item is sampled more than once. 
@pytest.mark.parametrize(
    "params",
    [
        ((24, 4, 0, 4, False, False), [(8, 8), (8, 8), (4, 4), (4, 4)]),
        ((30, 4, 0, 4, False, False), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, False), [(8, 8), (8, 8), (8, 8), (6, 4)]),
        ((30, 4, 0, 4, False, True), [(8, 8), (8, 8), (8, 8), (6, 6)]),
        ((30, 4, 0, 4, True, True), [(8, 4), (8, 4), (8, 4), (6, 4)]),
        (
            (53, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (5, 5), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, False, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (53, 4, 2, 4, True, True),
            [(10, 8), (6, 4), (9, 8), (4, 4), (8, 8), (4, 4), (8, 8), (4, 4)],
        ),
        (
            (63, 4, 2, 4, False, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, False),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (10, 8), (5, 4)],
        ),
        (
            (63, 4, 2, 4, False, True),
            [(8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (7, 7)],
        ),
        (
            (63, 4, 2, 4, True, True),
            [
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (6, 4),
                (10, 8),
                (5, 4),
            ],
        ),
        (
            (65, 4, 2, 4, False, False),
            [(9, 9), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
        (
            (65, 4, 2, 4, True, True),
            [(9, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8), (8, 8)],
        ),
    ],
)
def test_RangeCalculation(params):
    """calculate_range partitions [0, total) contiguously across all
    (replica, worker) pairs and matches the expected (count, batches) key."""
    (
        (
            total,
            num_replicas,
            num_workers,
            batch_size,
            drop_last,
            drop_uneven_inputs,
        ),
        key,
    ) = params
    answer = []
    # Running start offset; each range must begin where the previous ended.
    # (Renamed from `sum`, which shadowed the builtin.)
    offset = 0
    for rank in range(num_replicas):
        for worker_id in range(max(num_workers, 1)):
            result = gb.internal.calculate_range(
                True,
                total,
                num_replicas,
                rank,
                num_workers,
                worker_id,
                batch_size,
                drop_last,
                drop_uneven_inputs,
            )
            assert offset == result[0]
            offset += result[1]
            answer.append((result[1], result[2]))
    assert key == answer


@unittest.skipIf(F._default_context_str != "cpu", reason="GPU not required.")
@pytest.mark.parametrize("num_ids", [24, 30, 32, 34, 36])
@pytest.mark.parametrize("num_workers", [0, 2])
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("drop_uneven_inputs", [False, True])
def test_DistributedItemSampler(
    num_ids, num_workers, drop_last, drop_uneven_inputs
):
    """Spawn 4 CPU processes and run the distributed sampling subprocess."""
    nprocs = 4
    batch_size = 4
    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")

    # On Windows, if the process group initialization file already exists,
    # the program may hang. So we need to delete it if it exists.
    if platform == "win32":
        try:
            os.remove(os.path.join(os.getcwd(), "dis_tempfile"))
        except FileNotFoundError:
            pass

    mp.spawn(
        distributed_item_sampler_subprocess,
        args=(
            nprocs,
            item_set,
            num_ids,
            num_workers,
            batch_size,
            drop_last,
            drop_uneven_inputs,
        ),
        nprocs=nprocs,
        join=True,
    )