Unverified Commit 3742d5ff authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] rename data attributes in MiniBatch (#6274)

parent 72683697
...@@ -54,7 +54,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -54,7 +54,7 @@ class NeighborSampler(SubgraphSampler):
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper >>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data): >>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data) ... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch ... return minibatch
... ...
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
...@@ -76,7 +76,7 @@ class NeighborSampler(SubgraphSampler): ...@@ -76,7 +76,7 @@ class NeighborSampler(SubgraphSampler):
>>> subgraph_sampler = gb.NeighborSampler( >>> subgraph_sampler = gb.NeighborSampler(
...neg_sampler, graph, fanouts) ...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler: >>> for data in subgraph_sampler:
... print(data.compacted_node_pair) ... print(data.compacted_node_pairs)
... print(len(data.sampled_subgraphs)) ... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2])) (tensor([0, 0, 0]), tensor([1, 0, 2]))
3 3
...@@ -166,7 +166,7 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -166,7 +166,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> import dgl.graphbolt as gb >>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper >>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data): >>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data) ... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch ... return minibatch
... ...
>>> from dgl import graphbolt as gb >>> from dgl import graphbolt as gb
...@@ -188,7 +188,7 @@ class LayerNeighborSampler(NeighborSampler): ...@@ -188,7 +188,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> subgraph_sampler = gb.LayerNeighborSampler( >>> subgraph_sampler = gb.LayerNeighborSampler(
...neg_sampler, graph, fanouts) ...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler: >>> for data in subgraph_sampler:
... print(data.compacted_node_pair) ... print(data.compacted_node_pairs)
... print(len(data.sampled_subgraphs)) ... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2])) (tensor([0, 0, 0]), tensor([1, 0, 2]))
3 3
......
...@@ -17,9 +17,9 @@ __all__ = ["ItemSampler"] ...@@ -17,9 +17,9 @@ __all__ = ["ItemSampler"]
class ItemSampler(IterDataPipe): class ItemSampler(IterDataPipe):
"""Item Sampler. """Item Sampler.
Creates item subset of data which could be node/edge IDs, node pairs with Creates item subset of data which could be node IDs, node pairs with or
or without labels, head/tail/negative_tails, DGLGraphs and heterogeneous without labels, node pairs with negative sources/destinations, DGLGraphs
counterparts. and heterogeneous counterparts.
Note: This class `ItemSampler` is not decorated with Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it `torchdata.datapipes.functional_datapipe` on purpose. This indicates it
......
...@@ -55,74 +55,74 @@ class MiniBatch: ...@@ -55,74 +55,74 @@ class MiniBatch:
value should be corresponding heterogeneous node id. value should be corresponding heterogeneous node id.
""" """
seed_node: Union[torch.Tensor, Dict[str, torch.Tensor]] = None seed_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of seed nodes used for sampling in the graph. Representation of seed nodes used for sampling in the graph.
- If `seed_node` is a tensor: It indicates the graph is homogeneous. - If `seed_nodes` is a tensor: It indicates the graph is homogeneous.
- If `seed_node` is a dictionary: The keys should be node type and the - If `seed_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids. value should be corresponding heterogeneous node ids.
""" """
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Labels associated with seed nodes in the graph. Labels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous. The value - If `labels` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels to given 'seed_node' or 'node_pair'. should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
- If `label` is a dictionary: The keys should be node or edge type and the - If `labels` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_node' or 'node_pair'. value should be corresponding labels to given 'seed_nodes' or 'node_pairs'.
""" """
node_pair: Union[ node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None ] = None
""" """
Representation of seed node pairs utilized in link prediction tasks. Representation of seed node pairs utilized in link prediction tasks.
- If `node_pair` is a tuple: It indicates a homogeneous graph where each - If `node_pairs` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs. tuple contains two tensors representing source-destination node pairs.
- If `node_pair` is a dictionary: The keys should be edge type, and the - If `node_pairs` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given value should be a tuple of tensors representing node pairs of the given
type. type.
""" """
negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of negative samples for the head nodes in the link Representation of negative samples for the head nodes in the link
prediction task. prediction task.
- If `negative_head` is a tensor: It indicates a homogeneous graph. - If `negative_srcs` is a tensor: It indicates a homogeneous graph.
- If `negative_head` is a dictionary: The key should be edge type, and the - If `negative_srcs` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the value should correspond to the negative samples for head nodes of the
given type. given type.
""" """
negative_tail: Union[torch.Tensor, Dict[str, torch.Tensor]] = None negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of negative samples for the tail nodes in the link Representation of negative samples for the tail nodes in the link
prediction task. prediction task.
- If `negative_tail` is a tensor: It indicates a homogeneous graph. - If `negative_dsts` is a tensor: It indicates a homogeneous graph.
- If `negative_tail` is a dictionary: The key should be edge type, and the - If `negative_dsts` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for tail nodes of the value should correspond to the negative samples for tail nodes of the
given type. given type.
""" """
compacted_node_pair: Union[ compacted_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]], Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None ] = None
""" """
Representation of compacted node pairs corresponding to 'node_pair', where Representation of compacted node pairs corresponding to 'node_pairs', where
all node ids inside are compacted. all node ids inside are compacted.
""" """
compacted_negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of compacted nodes corresponding to 'negative_head', where Representation of compacted nodes corresponding to 'negative_srcs', where
all node ids inside are compacted. all node ids inside are compacted.
""" """
compacted_negative_tail: Union[torch.Tensor, Dict[str, torch.Tensor]] = None compacted_negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
""" """
Representation of compacted nodes corresponding to 'negative_tail', where Representation of compacted nodes corresponding to 'negative_dsts', where
all node ids inside are compacted. all node ids inside are compacted.
""" """
......
...@@ -44,9 +44,9 @@ class NegativeSampler(Mapper): ...@@ -44,9 +44,9 @@ class NegativeSampler(Mapper):
Parameters Parameters
---------- ----------
minibatch : MiniBatch minibatch : MiniBatch
An instance of 'MiniBatch' class requires the 'node_pair' field. An instance of 'MiniBatch' class requires the 'node_pairs' field.
This function is responsible for generating negative edges This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In corresponding to the positive edges defined by the 'node_pairs'. In
cases where negative edges already exist, this function will cases where negative edges already exist, this function will
overwrite them. overwrite them.
...@@ -56,21 +56,21 @@ class NegativeSampler(Mapper): ...@@ -56,21 +56,21 @@ class NegativeSampler(Mapper):
An instance of 'MiniBatch' encompasses both positive and negative An instance of 'MiniBatch' encompasses both positive and negative
samples. samples.
""" """
node_pairs = minibatch.node_pair node_pairs = minibatch.node_pairs
assert node_pairs is not None assert node_pairs is not None
if isinstance(node_pairs, Mapping): if isinstance(node_pairs, Mapping):
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
minibatch.label = {} minibatch.labels = {}
else: else:
minibatch.negative_head, minibatch.negative_tail = {}, {} minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items(): for etype, pos_pairs in node_pairs.items():
self._collate( self._collate(
minibatch, self._sample_with_etype(pos_pairs, etype), etype minibatch, self._sample_with_etype(pos_pairs, etype), etype
) )
if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED: if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
minibatch.negative_tail = None minibatch.negative_dsts = None
if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED: if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
minibatch.negative_head = None minibatch.negative_srcs = None
else: else:
self._collate(minibatch, self._sample_with_etype(node_pairs)) self._collate(minibatch, self._sample_with_etype(node_pairs))
return minibatch return minibatch
...@@ -111,23 +111,23 @@ class NegativeSampler(Mapper): ...@@ -111,23 +111,23 @@ class NegativeSampler(Mapper):
Canonical edge type. Canonical edge type.
""" """
pos_src, pos_dst = ( pos_src, pos_dst = (
minibatch.node_pair[etype] minibatch.node_pairs[etype]
if etype is not None if etype is not None
else minibatch.node_pair else minibatch.node_pairs
) )
neg_src, neg_dst = neg_pairs neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT: if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_label = torch.ones_like(pos_src) pos_labels = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src) neg_labels = torch.zeros_like(neg_src)
src = torch.cat([pos_src, neg_src]) src = torch.cat([pos_src, neg_src])
dst = torch.cat([pos_dst, neg_dst]) dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label]) labels = torch.cat([pos_labels, neg_labels])
if etype is not None: if etype is not None:
minibatch.node_pair[etype] = (src, dst) minibatch.node_pairs[etype] = (src, dst)
minibatch.label[etype] = label minibatch.labels[etype] = labels
else: else:
minibatch.node_pair = (src, dst) minibatch.node_pairs = (src, dst)
minibatch.label = label minibatch.labels = labels
else: else:
if self.output_format == LinkPredictionEdgeFormat.CONDITIONED: if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
neg_src = neg_src.view(-1, self.negative_ratio) neg_src = neg_src.view(-1, self.negative_ratio)
...@@ -147,8 +147,8 @@ class NegativeSampler(Mapper): ...@@ -147,8 +147,8 @@ class NegativeSampler(Mapper):
f"Unsupported output format {self.output_format}." f"Unsupported output format {self.output_format}."
) )
if etype is not None: if etype is not None:
minibatch.negative_head[etype] = neg_src minibatch.negative_srcs[etype] = neg_src
minibatch.negative_tail[etype] = neg_dst minibatch.negative_dsts[etype] = neg_dst
else: else:
minibatch.negative_head = neg_src minibatch.negative_srcs = neg_src
minibatch.negative_tail = neg_dst minibatch.negative_dsts = neg_dst
...@@ -28,19 +28,19 @@ class SubgraphSampler(Mapper): ...@@ -28,19 +28,19 @@ class SubgraphSampler(Mapper):
super().__init__(datapipe, self._sample) super().__init__(datapipe, self._sample)
def _sample(self, minibatch): def _sample(self, minibatch):
if minibatch.node_pair is not None: if minibatch.node_pairs is not None:
( (
seeds, seeds,
minibatch.compacted_node_pair, minibatch.compacted_node_pairs,
minibatch.compacted_negative_head, minibatch.compacted_negative_srcs,
minibatch.compacted_negative_tail, minibatch.compacted_negative_dsts,
) = self._node_pair_preprocess(minibatch) ) = self._node_pairs_preprocess(minibatch)
elif minibatch.seed_node is not None: elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_node seeds = minibatch.seed_nodes
else: else:
raise ValueError( raise ValueError(
f"Invalid minibatch {minibatch}: Either 'node_pair' or \ f"Invalid minibatch {minibatch}: Either 'node_pairs' or \
'seed_node' should have a value." 'seed_nodes' should have a value."
) )
( (
minibatch.input_nodes, minibatch.input_nodes,
...@@ -48,16 +48,16 @@ class SubgraphSampler(Mapper): ...@@ -48,16 +48,16 @@ class SubgraphSampler(Mapper):
) = self._sample_subgraphs(seeds) ) = self._sample_subgraphs(seeds)
return minibatch return minibatch
def _node_pair_preprocess(self, minibatch): def _node_pairs_preprocess(self, minibatch):
node_pair = minibatch.node_pair node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_head, minibatch.negative_tail neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
has_neg_src = neg_src is not None has_neg_src = neg_src is not None
has_neg_dst = neg_dst is not None has_neg_dst = neg_dst is not None
is_heterogeneous = isinstance(node_pair, Dict) is_heterogeneous = isinstance(node_pairs, Dict)
if is_heterogeneous: if is_heterogeneous:
# Collect nodes from all types of input. # Collect nodes from all types of input.
nodes = defaultdict(list) nodes = defaultdict(list)
for etype, (src, dst) in node_pair.items(): for etype, (src, dst) in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
nodes[src_type].append(src) nodes[src_type].append(src)
nodes[dst_type].append(dst) nodes[dst_type].append(dst)
...@@ -72,27 +72,27 @@ class SubgraphSampler(Mapper): ...@@ -72,27 +72,27 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes. # Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes) seeds, compacted = unique_and_compact(nodes)
( (
compacted_node_pair, compacted_node_pairs,
compacted_negative_head, compacted_negative_srcs,
compacted_negative_tail, compacted_negative_dsts,
) = ({}, {}, {}) ) = ({}, {}, {})
# Map back in same order as collect. # Map back in same order as collect.
for etype, _ in node_pair.items(): for etype, _ in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype) src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0) src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0) dst = compacted[dst_type].pop(0)
compacted_node_pair[etype] = (src, dst) compacted_node_pairs[etype] = (src, dst)
if has_neg_src: if has_neg_src:
for etype, _ in neg_src.items(): for etype, _ in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype) src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_head[etype] = compacted[src_type].pop(0) compacted_negative_srcs[etype] = compacted[src_type].pop(0)
if has_neg_dst: if has_neg_dst:
for etype, _ in neg_dst.items(): for etype, _ in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype) _, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_tail[etype] = compacted[dst_type].pop(0) compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
else: else:
# Collect nodes from all types of input. # Collect nodes from all types of input.
nodes = list(node_pair) nodes = list(node_pairs)
if has_neg_src: if has_neg_src:
nodes.append(neg_src.view(-1)) nodes.append(neg_src.view(-1))
if has_neg_dst: if has_neg_dst:
...@@ -100,17 +100,17 @@ class SubgraphSampler(Mapper): ...@@ -100,17 +100,17 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes. # Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes) seeds, compacted = unique_and_compact(nodes)
# Map back in same order as collect. # Map back in same order as collect.
compacted_node_pair = tuple(compacted[:2]) compacted_node_pairs = tuple(compacted[:2])
compacted = compacted[2:] compacted = compacted[2:]
if has_neg_src: if has_neg_src:
compacted_negative_head = compacted.pop(0) compacted_negative_srcs = compacted.pop(0)
if has_neg_dst: if has_neg_dst:
compacted_negative_tail = compacted.pop(0) compacted_negative_dsts = compacted.pop(0)
return ( return (
seeds, seeds,
compacted_node_pair, compacted_node_pairs,
compacted_negative_head if has_neg_src else None, compacted_negative_srcs if has_neg_src else None,
compacted_negative_tail if has_neg_dst else None, compacted_negative_dsts if has_neg_dst else None,
) )
def _sample_subgraphs(self, seeds): def _sample_subgraphs(self, seeds):
......
...@@ -9,12 +9,12 @@ import torch ...@@ -9,12 +9,12 @@ import torch
def minibatch_node_collator(data): def minibatch_node_collator(data):
minibatch = gb.MiniBatch(seed_node=data) minibatch = gb.MiniBatch(seed_nodes=data)
return minibatch return minibatch
def minibatch_link_collator(data): def minibatch_link_collator(data):
minibatch = gb.MiniBatch(node_pair=data) minibatch = gb.MiniBatch(node_pairs=data)
return minibatch return minibatch
......
...@@ -30,14 +30,14 @@ def test_NegativeSampler_Independent_Format(negative_ratio): ...@@ -30,14 +30,14 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
src, dst = data.node_pair src, dst = data.node_pairs
label = data.label labels = data.labels
# Assertation # Assertation
assert len(src) == batch_size * (negative_ratio + 1) assert len(src) == batch_size * (negative_ratio + 1)
assert len(dst) == batch_size * (negative_ratio + 1) assert len(dst) == batch_size * (negative_ratio + 1)
assert len(label) == batch_size * (negative_ratio + 1) assert len(labels) == batch_size * (negative_ratio + 1)
assert torch.all(torch.eq(label[:batch_size], 1)) assert torch.all(torch.eq(labels[:batch_size], 1))
assert torch.all(torch.eq(label[batch_size:], 0)) assert torch.all(torch.eq(labels[batch_size:], 0))
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20]) @pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
...@@ -65,8 +65,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio): ...@@ -65,8 +65,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst = data.node_pair pos_src, pos_dst = data.node_pairs
neg_src, neg_dst = data.negative_head, data.negative_tail neg_src, neg_dst = data.negative_srcs, data.negative_dsts
# Assertation # Assertation
assert len(pos_src) == batch_size assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size assert len(pos_dst) == batch_size
...@@ -103,8 +103,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio): ...@@ -103,8 +103,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst = data.node_pair pos_src, pos_dst = data.node_pairs
neg_src = data.negative_head neg_src = data.negative_srcs
# Assertation # Assertation
assert len(pos_src) == batch_size assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size assert len(pos_dst) == batch_size
...@@ -139,8 +139,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio): ...@@ -139,8 +139,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
) )
# Perform Negative sampling. # Perform Negative sampling.
for data in negative_sampler: for data in negative_sampler:
pos_src, pos_dst = data.node_pair pos_src, pos_dst = data.node_pairs
neg_dst = data.negative_tail neg_dst = data.negative_dsts
# Assertation # Assertation
assert len(pos_src) == batch_size assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size assert len(pos_dst) == batch_size
......
...@@ -78,11 +78,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_names(): ...@@ -78,11 +78,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
train_set: train_set:
- type: null - type: null
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_ids_path} path: {train_ids_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_labels_path} path: {train_labels_path}
...@@ -104,7 +104,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names(): ...@@ -104,7 +104,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
for i, (id, label, _) in enumerate(train_set): for i, (id, label, _) in enumerate(train_set):
assert id == train_ids[i] assert id == train_ids[i]
assert label == train_labels[i] assert label == train_labels[i]
assert train_set.names == ("seed_node", "label", None) assert train_set.names == ("seed_nodes", "labels", None)
train_set = None train_set = None
...@@ -125,11 +125,11 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names(): ...@@ -125,11 +125,11 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
train_set: train_set:
- type: "author:writes:paper" - type: "author:writes:paper"
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_ids_path} path: {train_ids_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_labels_path} path: {train_labels_path}
...@@ -154,7 +154,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names(): ...@@ -154,7 +154,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
id, label, _ = item["author:writes:paper"] id, label, _ = item["author:writes:paper"]
assert id == train_ids[i] assert id == train_ids[i]
assert label == train_labels[i] assert label == train_labels[i]
assert train_set.names == ("seed_node", "label", None) assert train_set.names == ("seed_nodes", "labels", None)
train_set = None train_set = None
...@@ -193,32 +193,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -193,32 +193,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_set: train_set:
- type: null - type: null
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_ids_path} path: {train_ids_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_labels_path} path: {train_labels_path}
validation_set: validation_set:
- data: - data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: true in_memory: true
path: {validation_ids_path} path: {validation_ids_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {validation_labels_path} path: {validation_labels_path}
test_set: test_set:
- type: null - type: null
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: true in_memory: true
path: {test_ids_path} path: {test_ids_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {test_labels_path} path: {test_labels_path}
...@@ -242,7 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -242,7 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(train_set): for i, (id, label) in enumerate(train_set):
assert id == train_ids[i] assert id == train_ids[i]
assert label == train_labels[i] assert label == train_labels[i]
assert train_set.names == ("seed_node", "label") assert train_set.names == ("seed_nodes", "labels")
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -252,7 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -252,7 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(validation_set): for i, (id, label) in enumerate(validation_set):
assert id == validation_ids[i] assert id == validation_ids[i]
assert label == validation_labels[i] assert label == validation_labels[i]
assert validation_set.names == ("seed_node", "label") assert validation_set.names == ("seed_nodes", "labels")
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -262,7 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label(): ...@@ -262,7 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(test_set): for i, (id, label) in enumerate(test_set):
assert id == test_ids[i] assert id == test_ids[i]
assert label == test_labels[i] assert label == test_labels[i]
assert test_set.names == ("seed_node", "label") assert test_set.names == ("seed_nodes", "labels")
test_set = None test_set = None
dataset = None dataset = None
...@@ -334,7 +334,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -334,7 +334,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_dst_path} path: {train_dst_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_labels_path} path: {train_labels_path}
...@@ -348,7 +348,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -348,7 +348,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy format: numpy
in_memory: true in_memory: true
path: {validation_dst_path} path: {validation_dst_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {validation_labels_path} path: {validation_labels_path}
...@@ -363,7 +363,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -363,7 +363,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy format: numpy
in_memory: true in_memory: true
path: {test_dst_path} path: {test_dst_path}
- name: label - name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {test_labels_path} path: {test_labels_path}
...@@ -383,7 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -383,7 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == train_src[i] assert src == train_src[i]
assert dst == train_dst[i] assert dst == train_dst[i]
assert label == train_labels[i] assert label == train_labels[i]
assert train_set.names == ("src", "dst", "label") assert train_set.names == ("src", "dst", "labels")
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -394,7 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -394,7 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == validation_src[i] assert src == validation_src[i]
assert dst == validation_dst[i] assert dst == validation_dst[i]
assert label == validation_labels[i] assert label == validation_labels[i]
assert validation_set.names == ("src", "dst", "label") assert validation_set.names == ("src", "dst", "labels")
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -405,7 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label(): ...@@ -405,7 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == test_src[i] assert src == test_src[i]
assert dst == test_dst[i] assert dst == test_dst[i]
assert label == test_labels[i] assert label == test_labels[i]
assert test_set.names == ("src", "dst", "label") assert test_set.names == ("src", "dst", "labels")
test_set = None test_set = None
dataset = None dataset = None
...@@ -564,36 +564,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -564,36 +564,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_set: train_set:
- type: paper - type: paper
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: true in_memory: true
path: {train_path} path: {train_path}
- type: author - type: author
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
path: {train_path} path: {train_path}
validation_set: validation_set:
- type: paper - type: paper
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
path: {validation_path} path: {validation_path}
- type: author - type: author
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
path: {validation_path} path: {validation_path}
test_set: test_set:
- type: paper - type: paper
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
in_memory: false in_memory: false
path: {test_path} path: {test_path}
- type: author - type: author
data: data:
- name: seed_node - name: seed_nodes
format: numpy format: numpy
path: {test_path} path: {test_path}
""" """
...@@ -616,7 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -616,7 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key] id, label = item[key]
assert id == train_ids[i % 1000] assert id == train_ids[i % 1000]
assert label == train_labels[i % 1000] assert label == train_labels[i % 1000]
assert train_set.names == ("seed_node",) assert train_set.names == ("seed_nodes",)
train_set = None train_set = None
# Verify validation set. # Verify validation set.
...@@ -631,7 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -631,7 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key] id, label = item[key]
assert id == validation_ids[i % 1000] assert id == validation_ids[i % 1000]
assert label == validation_labels[i % 1000] assert label == validation_labels[i % 1000]
assert validation_set.names == ("seed_node",) assert validation_set.names == ("seed_nodes",)
validation_set = None validation_set = None
# Verify test set. # Verify test set.
...@@ -646,7 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label(): ...@@ -646,7 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key] id, label = item[key]
assert id == test_ids[i % 1000] assert id == test_ids[i % 1000]
assert label == test_labels[i % 1000] assert label == test_labels[i % 1000]
assert test_set.names == ("seed_node",) assert test_set.names == ("seed_nodes",)
test_set = None test_set = None
dataset = None dataset = None
...@@ -798,7 +798,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -798,7 +798,7 @@ def test_OnDiskDataset_Feature_heterograph():
path: {node_data_paper_path} path: {node_data_paper_path}
- domain: node - domain: node
type: paper type: paper
name: label name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {node_data_label_path} path: {node_data_label_path}
...@@ -810,7 +810,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -810,7 +810,7 @@ def test_OnDiskDataset_Feature_heterograph():
path: {edge_data_writes_path} path: {edge_data_writes_path}
- domain: edge - domain: edge
type: "author:writes:paper" type: "author:writes:paper"
name: label name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {edge_data_label_path} path: {edge_data_label_path}
...@@ -832,7 +832,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -832,7 +832,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch.tensor(node_data_paper), torch.tensor(node_data_paper),
) )
assert torch.equal( assert torch.equal(
feature_data.read("node", "paper", "label"), feature_data.read("node", "paper", "labels"),
torch.tensor(node_data_label), torch.tensor(node_data_label),
) )
...@@ -842,7 +842,7 @@ def test_OnDiskDataset_Feature_heterograph(): ...@@ -842,7 +842,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch.tensor(edge_data_writes), torch.tensor(edge_data_writes),
) )
assert torch.equal( assert torch.equal(
feature_data.read("edge", "author:writes:paper", "label"), feature_data.read("edge", "author:writes:paper", "labels"),
torch.tensor(edge_data_label), torch.tensor(edge_data_label),
) )
...@@ -879,7 +879,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -879,7 +879,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false in_memory: false
path: {node_data_feat_path} path: {node_data_feat_path}
- domain: node - domain: node
name: label name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {node_data_label_path} path: {node_data_label_path}
...@@ -889,7 +889,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -889,7 +889,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false in_memory: false
path: {edge_data_feat_path} path: {edge_data_feat_path}
- domain: edge - domain: edge
name: label name: labels
format: numpy format: numpy
in_memory: true in_memory: true
path: {edge_data_label_path} path: {edge_data_label_path}
...@@ -911,7 +911,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -911,7 +911,7 @@ def test_OnDiskDataset_Feature_homograph():
torch.tensor(node_data_feat), torch.tensor(node_data_feat),
) )
assert torch.equal( assert torch.equal(
feature_data.read("node", None, "label"), feature_data.read("node", None, "labels"),
torch.tensor(node_data_label), torch.tensor(node_data_label),
) )
...@@ -921,7 +921,7 @@ def test_OnDiskDataset_Feature_homograph(): ...@@ -921,7 +921,7 @@ def test_OnDiskDataset_Feature_homograph():
torch.tensor(edge_data_feat), torch.tensor(edge_data_feat),
) )
assert torch.equal( assert torch.equal(
feature_data.read("edge", None, "label"), feature_data.read("edge", None, "labels"),
torch.tensor(edge_data_label), torch.tensor(edge_data_label),
) )
......
...@@ -22,7 +22,7 @@ def test_SubgraphSampler_Node(labor): ...@@ -22,7 +22,7 @@ def test_SubgraphSampler_Node(labor):
def to_link_batch(data): def to_link_batch(data):
block = gb.MiniBatch(node_pair=data) block = gb.MiniBatch(node_pairs=data)
return block return block
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment