Unverified commit 6133cecd authored by Xinyu Yao, committed by GitHub

[GraphBolt] Update tests related to `seeds` refactor. (#7352)


Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ba9c1521
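
Note on the change pattern in the hunks below: the refactor collapses the old per-task item names (`seed_nodes` for node tasks, `node_pairs` plus `negative_dsts` for link tasks) into the single name `seeds`, with link-prediction supervision carried by companion `labels` and `indexes` entries. A minimal before/after sketch, using only the `gb.ItemSet` API that these tests exercise:

    import torch
    from dgl import graphbolt as gb

    # Before this PR, the item name depended on the task:
    #   gb.ItemSet(torch.arange(5), names="seed_nodes")
    #   gb.ItemSet(torch.arange(10).reshape(-1, 2), names="node_pairs")

    # After this PR, one name covers both: a 1-D tensor holds node seeds
    # and an (N, 2) tensor holds (src, dst) pairs.
    node_items = gb.ItemSet(torch.arange(5), names="seeds")
    edge_items = gb.ItemSet(torch.arange(10).reshape(-1, 2), names="seeds")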
@@ -81,7 +81,7 @@ def test_InSubgraphSampler_homo():
     graph = gb.fused_csc_sampling_graph(indptr, indices).to(F.ctx())
     seed_nodes = torch.LongTensor([0, 5, 3])
-    item_set = gb.ItemSet(seed_nodes, names="seed_nodes")
+    item_set = gb.ItemSet(seed_nodes, names="seeds")
     batch_size = 1
     item_sampler = gb.ItemSampler(item_set, batch_size=batch_size).copy_to(
         F.ctx()
@@ -162,8 +162,8 @@ def test_InSubgraphSampler_hetero():
     item_set = gb.ItemSetDict(
         {
-            "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seed_nodes"),
-            "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seed_nodes"),
+            "N0": gb.ItemSet(torch.LongTensor([1, 0, 2]), names="seeds"),
+            "N1": gb.ItemSet(torch.LongTensor([0, 2, 1]), names="seeds"),
         }
     )
     batch_size = 2
...
@@ -47,7 +47,7 @@ def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
         items = torch.arange(3)
     else:
         items = torch.tensor([2, 0, 1])
-    names = "seed_nodes"
+    names = "seeds"
    itemset = gb.ItemSet(items, names=names)
     graph = get_hetero_graph().to(F.ctx())
     if hetero:
@@ -94,9 +94,7 @@ def test_labor_dependent_minibatching(layer_dependency, overlap_graph_fetch):
     ).to(F.ctx())
     torch.random.set_rng_state(torch.manual_seed(123).get_state())
     batch_dependency = 100
-    itemset = gb.ItemSet(
-        torch.zeros(batch_dependency + 1).int(), names="seed_nodes"
-    )
+    itemset = gb.ItemSet(torch.zeros(batch_dependency + 1).int(), names="seeds")
     datapipe = gb.ItemSampler(itemset, batch_size=1).copy_to(F.ctx())
     fanouts = [5, 5]
     datapipe = datapipe.sample_layer_neighbor(
...
@@ -96,7 +96,7 @@ def test_OnDiskDataset_multiple_tasks():
     train_set:
       - type: null
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {train_ids_path}
@@ -112,7 +112,7 @@ def test_OnDiskDataset_multiple_tasks():
     train_set:
       - type: null
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {train_ids_path}
@@ -140,7 +140,7 @@ def test_OnDiskDataset_multiple_tasks():
     for i, (id, label, _) in enumerate(train_set):
         assert id == train_ids[i]
         assert label == train_labels[i]
-    assert train_set.names == ("seed_nodes", "labels", None)
+    assert train_set.names == ("seeds", "labels", None)
     train_set = None
     dataset = None
@@ -162,7 +162,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
     train_set:
       - type: null
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {train_ids_path}
@@ -183,7 +183,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
     for i, (id, label, _) in enumerate(train_set):
         assert id == train_ids[i]
         assert label == train_labels[i]
-    assert train_set.names == ("seed_nodes", "labels", None)
+    assert train_set.names == ("seeds", "labels", None)
     train_set = None
@@ -204,7 +204,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
     train_set:
       - type: "author:writes:paper"
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {train_ids_path}
@@ -228,7 +228,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
         id, label, _ = item["author:writes:paper"]
         assert id == train_ids[i]
         assert label == train_labels[i]
-    assert train_set.names == ("seed_nodes", "labels", None)
+    assert train_set.names == ("seeds", "labels", None)
     train_set = None
@@ -267,7 +267,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
     train_set:
       - type: null
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {train_ids_path}
@@ -277,7 +277,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
             path: {train_labels_path}
     validation_set:
       - data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {validation_ids_path}
@@ -288,7 +288,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
     test_set:
       - type: null
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {test_ids_path}
@@ -311,7 +311,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
     for i, (id, label) in enumerate(train_set):
         assert id == train_ids[i]
         assert label == train_labels[i]
-    assert train_set.names == ("seed_nodes", "labels")
+    assert train_set.names == ("seeds", "labels")
     train_set = None
     # Verify validation set.
@@ -321,7 +321,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
     for i, (id, label) in enumerate(validation_set):
         assert id == validation_ids[i]
         assert label == validation_labels[i]
-    assert validation_set.names == ("seed_nodes", "labels")
+    assert validation_set.names == ("seeds", "labels")
     validation_set = None
     # Verify test set.
@@ -331,7 +331,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
     for i, (id, label) in enumerate(test_set):
         assert id == test_ids[i]
         assert label == test_labels[i]
-    assert test_set.names == ("seed_nodes", "labels")
+    assert test_set.names == ("seeds", "labels")
     test_set = None
     dataset = None
@@ -355,25 +355,23 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
 def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():
     """Test TVTSet which returns ItemSet with node pairs and labels."""
     with tempfile.TemporaryDirectory() as test_dir:
-        train_node_pairs = np.arange(2000).reshape(1000, 2)
-        train_node_pairs_path = os.path.join(test_dir, "train_node_pairs.npy")
-        np.save(train_node_pairs_path, train_node_pairs)
+        train_seeds = np.arange(2000).reshape(1000, 2)
+        train_seeds_path = os.path.join(test_dir, "train_seeds.npy")
+        np.save(train_seeds_path, train_seeds)
         train_labels = np.random.randint(0, 10, size=1000)
         train_labels_path = os.path.join(test_dir, "train_labels.npy")
         np.save(train_labels_path, train_labels)
-        validation_node_pairs = np.arange(2000, 4000).reshape(1000, 2)
-        validation_node_pairs_path = os.path.join(
-            test_dir, "validation_node_pairs.npy"
-        )
-        np.save(validation_node_pairs_path, validation_node_pairs)
+        validation_seeds = np.arange(2000, 4000).reshape(1000, 2)
+        validation_seeds_path = os.path.join(test_dir, "validation_seeds.npy")
+        np.save(validation_seeds_path, validation_seeds)
         validation_labels = np.random.randint(0, 10, size=1000)
         validation_labels_path = os.path.join(test_dir, "validation_labels.npy")
         np.save(validation_labels_path, validation_labels)
-        test_node_pairs = np.arange(4000, 6000).reshape(1000, 2)
-        test_node_pairs_path = os.path.join(test_dir, "test_node_pairs.npy")
-        np.save(test_node_pairs_path, test_node_pairs)
+        test_seeds = np.arange(4000, 6000).reshape(1000, 2)
+        test_seeds_path = os.path.join(test_dir, "test_seeds.npy")
+        np.save(test_seeds_path, test_seeds)
         test_labels = np.random.randint(0, 10, size=1000)
         test_labels_path = os.path.join(test_dir, "test_labels.npy")
         np.save(test_labels_path, test_labels)
@@ -384,20 +382,20 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():
     train_set:
       - type: null
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {train_node_pairs_path}
+            path: {train_seeds_path}
           - name: labels
             format: numpy
             in_memory: true
             path: {train_labels_path}
     validation_set:
       - data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {validation_node_pairs_path}
+            path: {validation_seeds_path}
           - name: labels
             format: numpy
             in_memory: true
@@ -405,10 +403,10 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():
     test_set:
       - type: null
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {test_node_pairs_path}
+            path: {test_seeds_path}
           - name: labels
             format: numpy
             in_memory: true
@@ -421,10 +419,10 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():
     assert len(train_set) == 1000
     assert isinstance(train_set, gb.ItemSet)
     for i, (node_pair, label) in enumerate(train_set):
-        assert node_pair[0] == train_node_pairs[i][0]
-        assert node_pair[1] == train_node_pairs[i][1]
+        assert node_pair[0] == train_seeds[i][0]
+        assert node_pair[1] == train_seeds[i][1]
         assert label == train_labels[i]
-    assert train_set.names == ("node_pairs", "labels")
+    assert train_set.names == ("seeds", "labels")
     train_set = None
     # Verify validation set.
@@ -432,10 +430,10 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():
     assert len(validation_set) == 1000
     assert isinstance(validation_set, gb.ItemSet)
     for i, (node_pair, label) in enumerate(validation_set):
-        assert node_pair[0] == validation_node_pairs[i][0]
-        assert node_pair[1] == validation_node_pairs[i][1]
+        assert node_pair[0] == validation_seeds[i][0]
+        assert node_pair[1] == validation_seeds[i][1]
         assert label == validation_labels[i]
-    assert validation_set.names == ("node_pairs", "labels")
+    assert validation_set.names == ("seeds", "labels")
     validation_set = None
     # Verify test set.
@@ -443,43 +441,69 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels():
     assert len(test_set) == 1000
     assert isinstance(test_set, gb.ItemSet)
     for i, (node_pair, label) in enumerate(test_set):
-        assert node_pair[0] == test_node_pairs[i][0]
-        assert node_pair[1] == test_node_pairs[i][1]
+        assert node_pair[0] == test_seeds[i][0]
+        assert node_pair[1] == test_seeds[i][1]
         assert label == test_labels[i]
-    assert test_set.names == ("node_pairs", "labels")
+    assert test_set.names == ("seeds", "labels")
     test_set = None
     dataset = None

-def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_negs():
+def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_labels_indexes():
     """Test TVTSet which returns ItemSet with node pairs and negative ones."""
     with tempfile.TemporaryDirectory() as test_dir:
-        train_node_pairs = np.arange(2000).reshape(1000, 2)
-        train_node_pairs_path = os.path.join(test_dir, "train_node_pairs.npy")
-        np.save(train_node_pairs_path, train_node_pairs)
-        train_neg_dst = np.random.choice(1000 * 10, size=1000 * 10).reshape(
-            1000, 10
-        )
-        train_neg_dst_path = os.path.join(test_dir, "train_neg_dst.npy")
-        np.save(train_neg_dst_path, train_neg_dst)
-        validation_node_pairs = np.arange(2000, 4000).reshape(1000, 2)
-        validation_node_pairs_path = os.path.join(
-            test_dir, "validation_node_pairs.npy"
-        )
-        np.save(validation_node_pairs_path, validation_node_pairs)
-        validation_neg_dst = train_neg_dst + 1
-        validation_neg_dst_path = os.path.join(
-            test_dir, "validation_neg_dst.npy"
-        )
-        np.save(validation_neg_dst_path, validation_neg_dst)
-        test_node_pairs = np.arange(4000, 6000).reshape(1000, 2)
-        test_node_pairs_path = os.path.join(test_dir, "test_node_pairs.npy")
-        np.save(test_node_pairs_path, test_node_pairs)
-        test_neg_dst = train_neg_dst + 2
-        test_neg_dst_path = os.path.join(test_dir, "test_neg_dst.npy")
-        np.save(test_neg_dst_path, test_neg_dst)
+        train_seeds = np.arange(2000).reshape(1000, 2)
+        train_neg_dst = np.random.choice(1000 * 10, size=1000 * 10)
+        train_neg_src = train_seeds[:, 0].repeat(10)
+        train_neg_seeds = (
+            np.concatenate((train_neg_dst, train_neg_src)).reshape(2, -1).T
+        )
+        train_seeds = np.concatenate((train_seeds, train_neg_seeds))
+        train_seeds_path = os.path.join(test_dir, "train_seeds.npy")
+        np.save(train_seeds_path, train_seeds)
+
+        train_labels = torch.empty(1000 * 11)
+        train_labels[:1000] = 1
+        train_labels[1000:] = 0
+        train_labels_path = os.path.join(test_dir, "train_labels.pt")
+        torch.save(train_labels, train_labels_path)
+
+        train_indexes = torch.arange(0, 1000)
+        train_indexes = np.concatenate(
+            (train_indexes, train_indexes.repeat_interleave(10))
+        )
+        train_indexes_path = os.path.join(test_dir, "train_indexes.pt")
+        torch.save(train_indexes, train_indexes_path)
+
+        validation_seeds = np.arange(2000, 4000).reshape(1000, 2)
+        validation_neg_seeds = train_neg_seeds + 1
+        validation_seeds = np.concatenate(
+            (validation_seeds, validation_neg_seeds)
+        )
+        validation_seeds_path = os.path.join(test_dir, "validation_seeds.npy")
+        np.save(validation_seeds_path, validation_seeds)
+        validation_labels = train_labels
+        validation_labels_path = os.path.join(test_dir, "validation_labels.pt")
+        torch.save(validation_labels, validation_labels_path)
+        validation_indexes = train_indexes
+        validation_indexes_path = os.path.join(
+            test_dir, "validation_indexes.pt"
+        )
+        torch.save(validation_indexes, validation_indexes_path)
+
+        test_seeds = np.arange(4000, 6000).reshape(1000, 2)
+        test_neg_seeds = train_neg_seeds + 2
+        test_seeds = np.concatenate((test_seeds, test_neg_seeds))
+        test_seeds_path = os.path.join(test_dir, "test_seeds.npy")
+        np.save(test_seeds_path, test_seeds)
+        test_labels = train_labels
+        test_labels_path = os.path.join(test_dir, "test_labels.pt")
+        torch.save(test_labels, test_labels_path)
+        test_indexes = train_indexes
+        test_indexes_path = os.path.join(test_dir, "test_indexes.pt")
+        torch.save(test_indexes, test_indexes_path)

         yaml_content = f"""
             tasks:
@@ -487,69 +511,83 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pairs_negs():
     train_set:
       - type: null
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {train_node_pairs_path}
-          - name: negative_dsts
-            format: numpy
+            path: {train_seeds_path}
+          - name: labels
+            format: torch
+            in_memory: true
+            path: {train_labels_path}
+          - name: indexes
+            format: torch
             in_memory: true
-            path: {train_neg_dst_path}
+            path: {train_indexes_path}
     validation_set:
       - data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {validation_node_pairs_path}
-          - name: negative_dsts
-            format: numpy
+            path: {validation_seeds_path}
+          - name: labels
+            format: torch
             in_memory: true
-            path: {validation_neg_dst_path}
+            path: {validation_labels_path}
+          - name: indexes
+            format: torch
+            in_memory: true
+            path: {validation_indexes_path}
     test_set:
       - type: null
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {test_node_pairs_path}
-          - name: negative_dsts
-            format: numpy
+            path: {test_seeds_path}
+          - name: labels
+            format: torch
             in_memory: true
-            path: {test_neg_dst_path}
+            path: {test_labels_path}
+          - name: indexes
+            format: torch
+            in_memory: true
+            path: {test_indexes_path}
     """
     dataset = write_yaml_and_load_dataset(yaml_content, test_dir)
     # Verify train set.
     train_set = dataset.tasks[0].train_set
-    assert len(train_set) == 1000
+    assert len(train_set) == 1000 * 11
     assert isinstance(train_set, gb.ItemSet)
-    for i, (node_pair, negs) in enumerate(train_set):
-        assert node_pair[0] == train_node_pairs[i][0]
-        assert node_pair[1] == train_node_pairs[i][1]
-        assert torch.equal(negs, torch.from_numpy(train_neg_dst[i]))
-    assert train_set.names == ("node_pairs", "negative_dsts")
+    for i, (node_pair, label, index) in enumerate(train_set):
+        assert node_pair[0] == train_seeds[i][0]
+        assert node_pair[1] == train_seeds[i][1]
+        assert label == train_labels[i]
+        assert index == train_indexes[i]
+    assert train_set.names == ("seeds", "labels", "indexes")
     train_set = None
     # Verify validation set.
     validation_set = dataset.tasks[0].validation_set
-    assert len(validation_set) == 1000
+    assert len(validation_set) == 1000 * 11
     assert isinstance(validation_set, gb.ItemSet)
-    for i, (node_pair, negs) in enumerate(validation_set):
-        assert node_pair[0] == validation_node_pairs[i][0]
-        assert node_pair[1] == validation_node_pairs[i][1]
-        assert torch.equal(negs, torch.from_numpy(validation_neg_dst[i]))
-    assert validation_set.names == ("node_pairs", "negative_dsts")
+    for i, (node_pair, label, index) in enumerate(validation_set):
+        assert node_pair[0] == validation_seeds[i][0]
+        assert node_pair[1] == validation_seeds[i][1]
+        assert label == validation_labels[i]
+        assert index == validation_indexes[i]
+    assert validation_set.names == ("seeds", "labels", "indexes")
     validation_set = None
     # Verify test set.
     test_set = dataset.tasks[0].test_set
-    assert len(test_set) == 1000
+    assert len(test_set) == 1000 * 11
     assert isinstance(test_set, gb.ItemSet)
-    for i, (node_pair, negs) in enumerate(test_set):
-        assert node_pair[0] == test_node_pairs[i][0]
-        assert node_pair[1] == test_node_pairs[i][1]
-        assert torch.equal(negs, torch.from_numpy(test_neg_dst[i]))
-    assert test_set.names == ("node_pairs", "negative_dsts")
+    for i, (node_pair, label, index) in enumerate(test_set):
+        assert node_pair[0] == test_seeds[i][0]
+        assert label == test_labels[i]
+        assert index == test_indexes[i]
+    assert test_set.names == ("seeds", "labels", "indexes")
     test_set = None
     dataset = None
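
A note on the fixture rewritten above, since the new layout is easy to misread: the separate (1000, 10) `negative_dsts` array is gone. Instead, the 10 negatives per positive are flattened into the same 2-D seeds array after the 1000 positive pairs, `labels` marks positives with 1 and negatives with 0, and `indexes` maps every row back to the positive query it expands (hence the `1000 * 11` length asserts). A self-contained sketch of the same layout, using hypothetical small sizes in place of the test's 1000 and 10:

    import torch

    num_pos, k = 3, 2  # 3 positive pairs, 2 negatives per positive
    pos = torch.tensor([[0, 1], [2, 3], [4, 5]])
    neg = torch.tensor([[9, 0], [8, 0], [7, 2], [6, 2], [5, 4], [3, 4]])
    seeds = torch.cat([pos, neg])  # positives first, then all negatives

    labels = torch.cat([torch.ones(num_pos), torch.zeros(num_pos * k)])
    # indexes maps each row (positive or negative) back to its query row:
    indexes = torch.cat(
        [torch.arange(num_pos), torch.arange(num_pos).repeat_interleave(k)]
    )
    assert len(seeds) == num_pos * (1 + k)  # mirrors `len(train_set) == 1000 * 11`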
@@ -581,36 +619,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
     train_set:
       - type: paper
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: true
             path: {train_path}
       - type: author
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             path: {train_path}
     validation_set:
       - type: paper
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             path: {validation_path}
       - type: author
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             path: {validation_path}
     test_set:
       - type: paper
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             in_memory: false
             path: {test_path}
       - type: author
         data:
-          - name: seed_nodes
+          - name: seeds
             format: numpy
             path: {test_path}
     """
@@ -628,7 +666,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         id, label = item[key]
         assert id == train_ids[i % 1000]
         assert label == train_labels[i % 1000]
-    assert train_set.names == ("seed_nodes",)
+    assert train_set.names == ("seeds",)
     train_set = None
     # Verify validation set.
@@ -643,7 +681,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         id, label = item[key]
         assert id == validation_ids[i % 1000]
         assert label == validation_labels[i % 1000]
-    assert validation_set.names == ("seed_nodes",)
+    assert validation_set.names == ("seeds",)
     validation_set = None
     # Verify test set.
@@ -658,7 +696,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
         id, label = item[key]
         assert id == test_ids[i % 1000]
         assert label == test_labels[i % 1000]
-    assert test_set.names == ("seed_nodes",)
+    assert test_set.names == ("seeds",)
     test_set = None
     dataset = None
@@ -666,25 +704,23 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
 def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels():
     """Test TVTSet which returns ItemSetDict with node pairs and labels."""
     with tempfile.TemporaryDirectory() as test_dir:
-        train_node_pairs = np.arange(2000).reshape(1000, 2)
-        train_node_pairs_path = os.path.join(test_dir, "train_node_pairs.npy")
-        np.save(train_node_pairs_path, train_node_pairs)
+        train_seeds = np.arange(2000).reshape(1000, 2)
+        train_seeds_path = os.path.join(test_dir, "train_seeds.npy")
+        np.save(train_seeds_path, train_seeds)
         train_labels = np.random.randint(0, 10, size=1000)
         train_labels_path = os.path.join(test_dir, "train_labels.npy")
         np.save(train_labels_path, train_labels)
-        validation_node_pairs = np.arange(2000, 4000).reshape(1000, 2)
-        validation_node_pairs_path = os.path.join(
-            test_dir, "validation_node_pairs.npy"
-        )
-        np.save(validation_node_pairs_path, validation_node_pairs)
+        validation_seeds = np.arange(2000, 4000).reshape(1000, 2)
+        validation_seeds_path = os.path.join(test_dir, "validation_seeds.npy")
+        np.save(validation_seeds_path, validation_seeds)
         validation_labels = np.random.randint(0, 10, size=1000)
         validation_labels_path = os.path.join(test_dir, "validation_labels.npy")
         np.save(validation_labels_path, validation_labels)
-        test_node_pairs = np.arange(4000, 6000).reshape(1000, 2)
-        test_node_pairs_path = os.path.join(test_dir, "test_node_pairs.npy")
-        np.save(test_node_pairs_path, test_node_pairs)
+        test_seeds = np.arange(4000, 6000).reshape(1000, 2)
+        test_seeds_path = os.path.join(test_dir, "test_seeds.npy")
+        np.save(test_seeds_path, test_seeds)
         test_labels = np.random.randint(0, 10, size=1000)
         test_labels_path = os.path.join(test_dir, "test_labels.npy")
         np.save(test_labels_path, test_labels)
@@ -695,56 +731,56 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels():
     train_set:
       - type: paper:cites:paper
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {train_node_pairs_path}
+            path: {train_seeds_path}
           - name: labels
             format: numpy
             in_memory: true
             path: {train_labels_path}
       - type: author:writes:paper
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
-            path: {train_node_pairs_path}
+            path: {train_seeds_path}
           - name: labels
             format: numpy
             path: {train_labels_path}
     validation_set:
       - type: paper:cites:paper
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
-            path: {validation_node_pairs_path}
+            path: {validation_seeds_path}
           - name: labels
             format: numpy
             path: {validation_labels_path}
       - type: author:writes:paper
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
-            path: {validation_node_pairs_path}
+            path: {validation_seeds_path}
           - name: labels
             format: numpy
             path: {validation_labels_path}
     test_set:
       - type: paper:cites:paper
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {test_node_pairs_path}
+            path: {test_seeds_path}
           - name: labels
             format: numpy
             in_memory: true
             path: {test_labels_path}
       - type: author:writes:paper
         data:
-          - name: node_pairs
+          - name: seeds
             format: numpy
             in_memory: true
-            path: {test_node_pairs_path}
+            path: {test_seeds_path}
           - name: labels
             format: numpy
             in_memory: true
@@ -762,10 +798,10 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels():
         key = list(item.keys())[0]
         assert key in ["paper:cites:paper", "author:writes:paper"]
         node_pair, label = item[key]
-        assert node_pair[0] == train_node_pairs[i % 1000][0]
-        assert node_pair[1] == train_node_pairs[i % 1000][1]
+        assert node_pair[0] == train_seeds[i % 1000][0]
+        assert node_pair[1] == train_seeds[i % 1000][1]
         assert label == train_labels[i % 1000]
-    assert train_set.names == ("node_pairs", "labels")
+    assert train_set.names == ("seeds", "labels")
     train_set = None
     # Verify validation set.
@@ -778,10 +814,10 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels():
         key = list(item.keys())[0]
         assert key in ["paper:cites:paper", "author:writes:paper"]
         node_pair, label = item[key]
-        assert node_pair[0] == validation_node_pairs[i % 1000][0]
-        assert node_pair[1] == validation_node_pairs[i % 1000][1]
+        assert node_pair[0] == validation_seeds[i % 1000][0]
+        assert node_pair[1] == validation_seeds[i % 1000][1]
         assert label == validation_labels[i % 1000]
-    assert validation_set.names == ("node_pairs", "labels")
+    assert validation_set.names == ("seeds", "labels")
     validation_set = None
     # Verify test set.
@@ -794,10 +830,10 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_node_pairs_labels():
         key = list(item.keys())[0]
         assert key in ["paper:cites:paper", "author:writes:paper"]
         node_pair, label = item[key]
-        assert node_pair[0] == test_node_pairs[i % 1000][0]
-        assert node_pair[1] == test_node_pairs[i % 1000][1]
+        assert node_pair[0] == test_seeds[i % 1000][0]
+        assert node_pair[1] == test_seeds[i % 1000][1]
         assert label == test_labels[i % 1000]
-    assert test_set.names == ("node_pairs", "labels")
+    assert test_set.names == ("seeds", "labels")
     test_set = None
     dataset = None
@@ -1294,21 +1330,21 @@ def test_OnDiskDataset_preprocess_homogeneous_hardcode(
         f"  train_set:\n"
         f"    - type: null\n"
         f"      data:\n"
-        f"        - name: node_pairs\n"
+        f"        - name: seeds\n"
         f"          format: numpy\n"
         f"          in_memory: true\n"
         f"          path: {train_path}\n"
         f"  validation_set:\n"
         f"    - type: null\n"
         f"      data:\n"
-        f"        - name: node_pairs\n"
+        f"        - name: seeds\n"
         f"          format: numpy\n"
         f"          in_memory: true\n"
         f"          path: {valid_path}\n"
         f"  test_set:\n"
         f"    - type: null\n"
         f"      data:\n"
-        f"        - name: node_pairs\n"
+        f"        - name: seeds\n"
         f"          format: numpy\n"
         f"          in_memory: true\n"
         f"          path: {test_path}\n"
@@ -2856,22 +2892,22 @@ def test_OnDiskDataset_auto_force_preprocess(capsys):
 def test_OnDiskTask_repr_homogeneous():
     item_set = gb.ItemSet(
         (torch.arange(0, 5), torch.arange(5, 10)),
-        names=("seed_nodes", "labels"),
+        names=("seeds", "labels"),
     )
     metadata = {"name": "node_classification"}
     task = gb.OnDiskTask(metadata, item_set, item_set, item_set)
     expected_str = (
         "OnDiskTask(validation_set=ItemSet(\n"
         "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
-        "    names=('seed_nodes', 'labels'),\n"
+        "    names=('seeds', 'labels'),\n"
         "),\n"
         "train_set=ItemSet(\n"
         "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
-        "    names=('seed_nodes', 'labels'),\n"
+        "    names=('seeds', 'labels'),\n"
         "),\n"
         "test_set=ItemSet(\n"
         "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
-        "    names=('seed_nodes', 'labels'),\n"
+        "    names=('seeds', 'labels'),\n"
         "),\n"
         "metadata={'name': 'node_classification'},)"
     )
@@ -2908,8 +2944,8 @@ def test_OnDiskDataset_not_include_eids():
 def test_OnDiskTask_repr_heterogeneous():
     item_set = gb.ItemSetDict(
         {
-            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
-            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
+            "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
+            "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
         }
     )
     metadata = {"name": "node_classification"}
@@ -2918,32 +2954,32 @@ def test_OnDiskTask_repr_heterogeneous():
         "OnDiskTask(validation_set=ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         "),\n"
         "train_set=ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         "),\n"
         "test_set=ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
        "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         "),\n"
         "metadata={'name': 'node_classification'},)"
     )
...
@@ -25,7 +25,7 @@ def test_FeatureFetcher_invoke():
     features[keys[1]] = gb.TorchBasedFeature(b)
     feature_store = gb.BasicFeatureStore(features)
-    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
+    itemset = gb.ItemSet(torch.arange(10), names="seeds")
     item_sampler = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
@@ -58,7 +58,7 @@ def test_FeatureFetcher_homo():
     features[keys[1]] = gb.TorchBasedFeature(b)
     feature_store = gb.BasicFeatureStore(features)
-    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
+    itemset = gb.ItemSet(torch.arange(10), names="seeds")
     item_sampler = gb.ItemSampler(itemset, batch_size=2)
     num_layer = 2
     fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
@@ -104,7 +104,7 @@ def test_FeatureFetcher_with_edges_homo():
     features[keys[1]] = gb.TorchBasedFeature(b)
     feature_store = gb.BasicFeatureStore(features)
-    itemset = gb.ItemSet(torch.arange(10), names="seed_nodes")
+    itemset = gb.ItemSet(torch.arange(10), names="seeds")
     item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
     converter_dp = Mapper(item_sampler_dp, add_node_and_edge_ids)
     fetcher_dp = gb.FeatureFetcher(converter_dp, feature_store, ["a"], ["b"])
@@ -152,8 +152,8 @@ def test_FeatureFetcher_hetero():
     itemset = gb.ItemSetDict(
         {
-            "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seed_nodes"),
-            "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seed_nodes"),
+            "n1": gb.ItemSet(torch.LongTensor([0, 1]), names="seeds"),
+            "n2": gb.ItemSet(torch.LongTensor([0, 1, 2]), names="seeds"),
         }
     )
     item_sampler = gb.ItemSampler(itemset, batch_size=2)
@@ -215,7 +215,7 @@ def test_FeatureFetcher_with_edges_hetero():
     itemset = gb.ItemSetDict(
         {
-            "n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seed_nodes"),
+            "n1": gb.ItemSet(torch.randint(0, 20, (10,)), names="seeds"),
         }
     )
     item_sampler_dp = gb.ItemSampler(itemset, batch_size=2)
...
@@ -1161,7 +1161,7 @@ def test_DistributedItemSampler(
 ):
     nprocs = 4
     batch_size = 4
-    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seed_nodes")
+    item_set = gb.ItemSet(torch.arange(0, num_ids), names="seeds")
     # On Windows, if the process group initialization file already exists,
     # the program may hang. So we need to delete it if it exists.
...
@@ -8,15 +8,15 @@ from dgl import graphbolt as gb
 def test_ItemSet_names():
     # ItemSet with single name.
-    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
-    assert item_set.names == ("seed_nodes",)
+    item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
+    assert item_set.names == ("seeds",)
     # ItemSet with multiple names.
     item_set = gb.ItemSet(
         (torch.arange(0, 5), torch.arange(5, 10)),
-        names=("seed_nodes", "labels"),
+        names=("seeds", "labels"),
     )
-    assert item_set.names == ("seed_nodes", "labels")
+    assert item_set.names == ("seeds", "labels")
     # ItemSet without name.
     item_set = gb.ItemSet(torch.arange(0, 5))
@@ -27,19 +27,19 @@ def test_ItemSet_names():
         AssertionError,
         match=re.escape("Number of items (1) and names (2) must match."),
     ):
-        _ = gb.ItemSet(5, names=("seed_nodes", "labels"))
+        _ = gb.ItemSet(5, names=("seeds", "labels"))
     # ItemSet with mismatched items and names.
     with pytest.raises(
         AssertionError,
         match=re.escape("Number of items (1) and names (2) must match."),
     ):
-        _ = gb.ItemSet(torch.arange(0, 5), names=("seed_nodes", "labels"))
+        _ = gb.ItemSet(torch.arange(0, 5), names=("seeds", "labels"))

 @pytest.mark.parametrize("dtype", [torch.int32, torch.int64])
 def test_ItemSet_scalar_dtype(dtype):
-    item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seed_nodes")
+    item_set = gb.ItemSet(torch.tensor(5, dtype=dtype), names="seeds")
     for i, item in enumerate(item_set):
         assert i == item
         assert item.dtype == dtype
@@ -106,8 +106,8 @@ def test_ItemSet_length():
 def test_ItemSet_seed_nodes():
     # Node IDs with tensor.
-    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
-    assert item_set.names == ("seed_nodes",)
+    item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
+    assert item_set.names == ("seeds",)
     # Iterating over ItemSet and indexing one by one.
     for i, item in enumerate(item_set):
         assert i == item.item()
@@ -118,8 +118,8 @@ def test_ItemSet_seed_nodes():
     assert torch.equal(item_set[torch.arange(0, 5)], torch.arange(0, 5))
     # Node IDs with single integer.
-    item_set = gb.ItemSet(5, names="seed_nodes")
-    assert item_set.names == ("seed_nodes",)
+    item_set = gb.ItemSet(5, names="seeds")
+    assert item_set.names == ("seeds",)
     # Iterating over ItemSet and indexing one by one.
     for i, item in enumerate(item_set):
         assert i == item.item()
@@ -145,8 +145,8 @@ def test_ItemSet_seed_nodes_labels():
     # Node IDs and labels.
     seed_nodes = torch.arange(0, 5)
     labels = torch.randint(0, 3, (5,))
-    item_set = gb.ItemSet((seed_nodes, labels), names=("seed_nodes", "labels"))
-    assert item_set.names == ("seed_nodes", "labels")
+    item_set = gb.ItemSet((seed_nodes, labels), names=("seeds", "labels"))
+    assert item_set.names == ("seeds", "labels")
     # Iterating over ItemSet and indexing one by one.
     for i, (seed_node, label) in enumerate(item_set):
         assert seed_node == seed_nodes[i]
@@ -164,8 +164,8 @@ def test_ItemSet_seed_nodes_labels():
 def test_ItemSet_node_pairs():
     # Node pairs.
     node_pairs = torch.arange(0, 10).reshape(-1, 2)
-    item_set = gb.ItemSet(node_pairs, names="node_pairs")
-    assert item_set.names == ("node_pairs",)
+    item_set = gb.ItemSet(node_pairs, names="seeds")
+    assert item_set.names == ("seeds",)
     # Iterating over ItemSet and indexing one by one.
     for i, (src, dst) in enumerate(item_set):
         assert node_pairs[i][0] == src
@@ -182,8 +182,8 @@ def test_ItemSet_node_pairs_labels():
     # Node pairs and labels
     node_pairs = torch.arange(0, 10).reshape(-1, 2)
     labels = torch.randint(0, 3, (5,))
-    item_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
-    assert item_set.names == ("node_pairs", "labels")
+    item_set = gb.ItemSet((node_pairs, labels), names=("seeds", "labels"))
+    assert item_set.names == ("seeds", "labels")
     # Iterating over ItemSet and indexing one by one.
     for i, (node_pair, label) in enumerate(item_set):
         assert torch.equal(node_pairs[i], node_pair)
@@ -198,26 +198,31 @@ def test_ItemSet_node_pairs_labels():
     assert torch.equal(item_set[torch.arange(0, 5)][1], labels)

-def test_ItemSet_node_pairs_neg_dsts():
+def test_ItemSet_node_pairs_labels_indexes():
     # Node pairs and negative destinations.
     node_pairs = torch.arange(0, 10).reshape(-1, 2)
-    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
+    labels = torch.tensor([1, 1, 0, 0, 0])
+    indexes = torch.tensor([0, 1, 0, 0, 1])
     item_set = gb.ItemSet(
-        (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
+        (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
     )
-    assert item_set.names == ("node_pairs", "negative_dsts")
+    assert item_set.names == ("seeds", "labels", "indexes")
     # Iterating over ItemSet and indexing one by one.
-    for i, (node_pair, neg_dst) in enumerate(item_set):
+    for i, (node_pair, label, index) in enumerate(item_set):
         assert torch.equal(node_pairs[i], node_pair)
-        assert torch.equal(neg_dsts[i], neg_dst)
+        assert torch.equal(labels[i], label)
+        assert torch.equal(indexes[i], index)
         assert torch.equal(node_pairs[i], item_set[i][0])
-        assert torch.equal(neg_dsts[i], item_set[i][1])
+        assert torch.equal(labels[i], item_set[i][1])
+        assert torch.equal(indexes[i], item_set[i][2])
     # Indexing with a slice.
     assert torch.equal(item_set[:][0], node_pairs)
-    assert torch.equal(item_set[:][1], neg_dsts)
+    assert torch.equal(item_set[:][1], labels)
+    assert torch.equal(item_set[:][2], indexes)
     # Indexing with an Iterable.
     assert torch.equal(item_set[torch.arange(0, 5)][0], node_pairs)
-    assert torch.equal(item_set[torch.arange(0, 5)][1], neg_dsts)
+    assert torch.equal(item_set[torch.arange(0, 5)][1], labels)
+    assert torch.equal(item_set[torch.arange(0, 5)][2], indexes)

 def test_ItemSet_graphs():
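
The indexing contract exercised by the rewritten test above carries over unchanged to the three-field form: integer indexing returns one tuple of per-item tensors, while slice or tensor indexing returns the batched fields. A short usage sketch mirroring the test's own data:

    import torch
    from dgl import graphbolt as gb

    node_pairs = torch.arange(0, 10).reshape(-1, 2)
    labels = torch.tensor([1, 1, 0, 0, 0])
    indexes = torch.tensor([0, 1, 0, 0, 1])
    item_set = gb.ItemSet(
        (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
    )

    pair, label, index = item_set[2]              # one item: shapes (2,), (), ()
    pairs, all_labels, all_indexes = item_set[:]  # whole set: (5, 2), (5,), (5,)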
@@ -237,26 +242,26 @@ def test_ItemSetDict_names():
     # ItemSetDict with single name.
     item_set = gb.ItemSetDict(
         {
-            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
-            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
+            "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
+            "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
         }
     )
-    assert item_set.names == ("seed_nodes",)
+    assert item_set.names == ("seeds",)
     # ItemSetDict with multiple names.
     item_set = gb.ItemSetDict(
         {
             "user": gb.ItemSet(
                 (torch.arange(0, 5), torch.arange(5, 10)),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             ),
             "item": gb.ItemSet(
                 (torch.arange(5, 10), torch.arange(10, 15)),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             ),
         }
     )
-    assert item_set.names == ("seed_nodes", "labels")
+    assert item_set.names == ("seeds", "labels")
     # ItemSetDict with no name.
     item_set = gb.ItemSetDict(
@@ -276,11 +281,9 @@ def test_ItemSetDict_names():
         {
             "user": gb.ItemSet(
                 (torch.arange(0, 5), torch.arange(5, 10)),
-                names=("seed_nodes", "labels"),
-            ),
-            "item": gb.ItemSet(
-                (torch.arange(5, 10),), names=("seed_nodes",)
+                names=("seeds", "labels"),
             ),
+            "item": gb.ItemSet((torch.arange(5, 10),), names=("seeds",)),
         }
     )
@@ -354,14 +357,14 @@ def test_ItemSetDict_iteration_seed_nodes():
     user_ids = torch.arange(0, 5)
     item_ids = torch.arange(5, 10)
     ids = {
-        "user": gb.ItemSet(user_ids, names="seed_nodes"),
-        "item": gb.ItemSet(item_ids, names="seed_nodes"),
+        "user": gb.ItemSet(user_ids, names="seeds"),
+        "item": gb.ItemSet(item_ids, names="seeds"),
     }
     chained_ids = []
     for key, value in ids.items():
         chained_ids += [(key, v) for v in value]
     item_set = gb.ItemSetDict(ids)
-    assert item_set.names == ("seed_nodes",)
+    assert item_set.names == ("seeds",)
     # Iterating over ItemSetDict and indexing one by one.
     for i, item in enumerate(item_set):
         assert len(item) == 1
@@ -413,18 +416,14 @@ def test_ItemSetDict_iteration_seed_nodes_labels():
     item_ids = torch.arange(5, 10)
     item_labels = torch.randint(0, 3, (5,))
     ids_labels = {
-        "user": gb.ItemSet(
-            (user_ids, user_labels), names=("seed_nodes", "labels")
-        ),
-        "item": gb.ItemSet(
-            (item_ids, item_labels), names=("seed_nodes", "labels")
-        ),
+        "user": gb.ItemSet((user_ids, user_labels), names=("seeds", "labels")),
+        "item": gb.ItemSet((item_ids, item_labels), names=("seeds", "labels")),
     }
     chained_ids = []
     for key, value in ids_labels.items():
         chained_ids += [(key, v) for v in value]
     item_set = gb.ItemSetDict(ids_labels)
-    assert item_set.names == ("seed_nodes", "labels")
+    assert item_set.names == ("seeds", "labels")
     # Iterating over ItemSetDict and indexing one by one.
     for i, item in enumerate(item_set):
         assert len(item) == 1
@@ -443,14 +442,14 @@ def test_ItemSetDict_iteration_node_pairs():
     # Node pairs.
     node_pairs = torch.arange(0, 10).reshape(-1, 2)
     node_pairs_dict = {
-        "user:like:item": gb.ItemSet(node_pairs, names="node_pairs"),
-        "user:follow:user": gb.ItemSet(node_pairs, names="node_pairs"),
+        "user:like:item": gb.ItemSet(node_pairs, names="seeds"),
+        "user:follow:user": gb.ItemSet(node_pairs, names="seeds"),
     }
     expected_data = []
     for key, value in node_pairs_dict.items():
         expected_data += [(key, v) for v in value]
     item_set = gb.ItemSetDict(node_pairs_dict)
-    assert item_set.names == ("node_pairs",)
+    assert item_set.names == ("seeds",)
     # Iterating over ItemSetDict and indexing one by one.
     for i, item in enumerate(item_set):
         assert len(item) == 1
@@ -471,17 +470,17 @@ def test_ItemSetDict_iteration_node_pairs_labels():
     labels = torch.randint(0, 3, (5,))
     node_pairs_labels = {
         "user:like:item": gb.ItemSet(
-            (node_pairs, labels), names=("node_pairs", "labels")
+            (node_pairs, labels), names=("seeds", "labels")
         ),
         "user:follow:user": gb.ItemSet(
-            (node_pairs, labels), names=("node_pairs", "labels")
+            (node_pairs, labels), names=("seeds", "labels")
         ),
     }
     expected_data = []
     for key, value in node_pairs_labels.items():
         expected_data += [(key, v) for v in value]
     item_set = gb.ItemSetDict(node_pairs_labels)
-    assert item_set.names == ("node_pairs", "labels")
+    assert item_set.names == ("seeds", "labels")
     # Iterating over ItemSetDict and indexing one by one.
     for i, item in enumerate(item_set):
         assert len(item) == 1
@@ -501,23 +500,24 @@ def test_ItemSetDict_iteration_node_pairs_labels():
     assert torch.equal(item_set[:]["user:follow:user"][1], labels)

-def test_ItemSetDict_iteration_node_pairs_neg_dsts():
+def test_ItemSetDict_iteration_node_pairs_labels_indexes():
     # Node pairs and negative destinations.
     node_pairs = torch.arange(0, 10).reshape(-1, 2)
-    neg_dsts = torch.arange(10, 25).reshape(-1, 3)
+    labels = torch.tensor([1, 1, 0, 0, 0])
+    indexes = torch.tensor([0, 1, 0, 0, 1])
     node_pairs_neg_dsts = {
         "user:like:item": gb.ItemSet(
-            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
+            (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
         ),
         "user:follow:user": gb.ItemSet(
-            (node_pairs, neg_dsts), names=("node_pairs", "negative_dsts")
+            (node_pairs, labels, indexes), names=("seeds", "labels", "indexes")
         ),
     }
     expected_data = []
     for key, value in node_pairs_neg_dsts.items():
         expected_data += [(key, v) for v in value]
     item_set = gb.ItemSetDict(node_pairs_neg_dsts)
-    assert item_set.names == ("node_pairs", "negative_dsts")
+    assert item_set.names == ("seeds", "labels", "indexes")
     # Iterating over ItemSetDict and indexing one by one.
     for i, item in enumerate(item_set):
         assert len(item) == 1
@@ -526,24 +526,28 @@ def test_ItemSetDict_iteration_node_pairs_neg_dsts():
         assert key in item
         assert torch.equal(item[key][0], value[0])
         assert torch.equal(item[key][1], value[1])
+        assert torch.equal(item[key][2], value[2])
         assert item_set[i].keys() == item.keys()
         key = list(item.keys())[0]
         assert torch.equal(item_set[i][key][0], item[key][0])
         assert torch.equal(item_set[i][key][1], item[key][1])
+        assert torch.equal(item_set[i][key][2], item[key][2])
     # Indexing with a slice.
     assert torch.equal(item_set[:]["user:like:item"][0], node_pairs)
-    assert torch.equal(item_set[:]["user:like:item"][1], neg_dsts)
+    assert torch.equal(item_set[:]["user:like:item"][1], labels)
+    assert torch.equal(item_set[:]["user:like:item"][2], indexes)
     assert torch.equal(item_set[:]["user:follow:user"][0], node_pairs)
-    assert torch.equal(item_set[:]["user:follow:user"][1], neg_dsts)
+    assert torch.equal(item_set[:]["user:follow:user"][1], labels)
+    assert torch.equal(item_set[:]["user:follow:user"][2], indexes)
 def test_ItemSet_repr():
     # ItemSet with single name.
-    item_set = gb.ItemSet(torch.arange(0, 5), names="seed_nodes")
+    item_set = gb.ItemSet(torch.arange(0, 5), names="seeds")
     expected_str = (
         "ItemSet(\n"
         "    items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         ")"
     )
@@ -552,12 +556,12 @@ def test_ItemSet_repr():
     # ItemSet with multiple names.
     item_set = gb.ItemSet(
         (torch.arange(0, 5), torch.arange(5, 10)),
-        names=("seed_nodes", "labels"),
+        names=("seeds", "labels"),
     )
     expected_str = (
         "ItemSet(\n"
         "    items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
-        "    names=('seed_nodes', 'labels'),\n"
+        "    names=('seeds', 'labels'),\n"
         ")"
     )
     assert str(item_set) == expected_str, item_set
@@ -567,20 +571,20 @@ def test_ItemSetDict_repr():
     # ItemSetDict with single name.
     item_set = gb.ItemSetDict(
         {
-            "user": gb.ItemSet(torch.arange(0, 5), names="seed_nodes"),
-            "item": gb.ItemSet(torch.arange(5, 10), names="seed_nodes"),
+            "user": gb.ItemSet(torch.arange(0, 5), names="seeds"),
+            "item": gb.ItemSet(torch.arange(5, 10), names="seeds"),
         }
     )
     expected_str = (
         "ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]),),\n"
-        "        names=('seed_nodes',),\n"
+        "        names=('seeds',),\n"
         "    )},\n"
-        "    names=('seed_nodes',),\n"
+        "    names=('seeds',),\n"
         ")"
     )
     assert str(item_set) == expected_str, item_set
@@ -590,11 +594,11 @@ def test_ItemSetDict_repr():
         {
             "user": gb.ItemSet(
                 (torch.arange(0, 5), torch.arange(5, 10)),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             ),
             "item": gb.ItemSet(
                 (torch.arange(5, 10), torch.arange(10, 15)),
-                names=("seed_nodes", "labels"),
+                names=("seeds", "labels"),
             ),
         }
     )
@@ -602,12 +606,12 @@ def test_ItemSetDict_repr():
         "ItemSetDict(\n"
         "    itemsets={'user': ItemSet(\n"
         "        items=(tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8, 9])),\n"
-        "        names=('seed_nodes', 'labels'),\n"
+        "        names=('seeds', 'labels'),\n"
         "    ), 'item': ItemSet(\n"
         "        items=(tensor([5, 6, 7, 8, 9]), tensor([10, 11, 12, 13, 14])),\n"
-        "        names=('seed_nodes', 'labels'),\n"
+        "        names=('seeds', 'labels'),\n"
         "    )},\n"
-        "    names=('seed_nodes', 'labels'),\n"
+        "    names=('seeds', 'labels'),\n"
         ")"
     )
     assert str(item_set) == expected_str, item_set
@@ -563,11 +563,10 @@ def test_dgl_link_predication_homo():
     check_dgl_blocks_homo(minibatch, dgl_blocks)

-@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
-def test_dgl_link_predication_hetero(mode):
+def test_dgl_link_predication_hetero():
     # Arrange
     minibatch = create_hetero_minibatch()
-    minibatch.compacted_node_pairs = {
+    minibatch.compacted_seeds = {
         relation: (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),
         reverse_relation: (
             torch.tensor([[0, 1, 1, 2, 0, 2], [1, 0, 1, 1, 0, 0]]).T,
...
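
The last hunk applies the same refactor on the MiniBatch side: `compacted_node_pairs` becomes `compacted_seeds`, and the `mode` parametrization disappears because positives and negatives no longer travel in separate fields. A hedged sketch of what the field holds per edge type (the tuple-of-tensor shape is taken from the test above; the relation key here is hypothetical):

    import torch

    # One (N, 2) tensor of compacted (src, dst) IDs per edge type, with
    # positive and negative pairs stored together in the same tensor.
    compacted_seeds = {
        "n1:r1:n2": (torch.tensor([[1, 1, 2, 0, 1, 2], [1, 0, 1, 1, 0, 0]]).T,),
    }
    src, dst = compacted_seeds["n1:r1:n2"][0].unbind(dim=1)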