Unverified Commit 3742d5ff authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] rename data attributes in MiniBatch (#6274)

parent 72683697
......@@ -54,7 +54,7 @@ class NeighborSampler(SubgraphSampler):
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
......@@ -76,7 +76,7 @@ class NeighborSampler(SubgraphSampler):
>>> subgraph_sampler = gb.NeighborSampler(
...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler:
... print(data.compacted_node_pair)
... print(data.compacted_node_pairs)
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
......@@ -166,7 +166,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> import dgl.graphbolt as gb
>>> from torchdata.datapipes.iter import Mapper
>>> def minibatch_link_collator(data):
... minibatch = gb.MiniBatch(node_pair=data)
... minibatch = gb.MiniBatch(node_pairs=data)
... return minibatch
...
>>> from dgl import graphbolt as gb
......@@ -188,7 +188,7 @@ class LayerNeighborSampler(NeighborSampler):
>>> subgraph_sampler = gb.LayerNeighborSampler(
...neg_sampler, graph, fanouts)
>>> for data in subgraph_sampler:
... print(data.compacted_node_pair)
... print(data.compacted_node_pairs)
... print(len(data.sampled_subgraphs))
(tensor([0, 0, 0]), tensor([1, 0, 2]))
3
......
......@@ -17,9 +17,9 @@ __all__ = ["ItemSampler"]
class ItemSampler(IterDataPipe):
"""Item Sampler.
Creates item subset of data which could be node/edge IDs, node pairs with
or without labels, head/tail/negative_tails, DGLGraphs and heterogeneous
counterparts.
Creates item subset of data which could be node IDs, node pairs with or
without labels, node pairs with negative sources/destinations, DGLGraphs
and heterogeneous counterparts.
Note: This class `ItemSampler` is not decorated with
`torchdata.datapipes.functional_datapipe` on purpose. This indicates it
......
......@@ -55,74 +55,74 @@ class MiniBatch:
value should be corresponding heterogeneous node id.
"""
seed_node: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
seed_nodes: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of seed nodes used for sampling in the graph.
- If `seed_node` is a tensor: It indicates the graph is homogeneous.
- If `seed_node` is a dictionary: The keys should be node type and the
- If `seed_nodes` is a tensor: It indicates the graph is homogeneous.
- If `seed_nodes` is a dictionary: The keys should be node type and the
value should be corresponding heterogeneous node ids.
"""
label: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
labels: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Labels associated with seed nodes in the graph.
- If `label` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labels to given 'seed_node' or 'node_pair'.
- If `label` is a dictionary: The keys should be node or edge type and the
value should be corresponding labels to given 'seed_node' or 'node_pair'.
labelss associated with seed nodes in the graph.
- If `labels` is a tensor: It indicates the graph is homogeneous. The value
should be corresponding labelss to given 'seed_nodes' or 'node_pairs'.
- If `labels` is a dictionary: The keys should be node or edge type and the
value should be corresponding labelss to given 'seed_nodes' or 'node_pairs'.
"""
node_pair: Union[
node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of seed node pairs utilized in link prediction tasks.
- If `node_pair` is a tuple: It indicates a homogeneous graph where each
- If `node_pairs` is a tuple: It indicates a homogeneous graph where each
tuple contains two tensors representing source-destination node pairs.
- If `node_pair` is a dictionary: The keys should be edge type, and the
- If `node_pairs` is a dictionary: The keys should be edge type, and the
value should be a tuple of tensors representing node pairs of the given
type.
"""
negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the head nodes in the link
prediction task.
- If `negative_head` is a tensor: It indicates a homogeneous graph.
- If `negative_head` is a dictionary: The key should be edge type, and the
- If `negative_srcs` is a tensor: It indicates a homogeneous graph.
- If `negative_srcs` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
negative_tail: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of negative samples for the tail nodes in the link
prediction task.
- If `negative_tail` is a tensor: It indicates a homogeneous graph.
- If `negative_tail` is a dictionary: The key should be edge type, and the
- If `negative_dsts` is a tensor: It indicates a homogeneous graph.
- If `negative_dsts` is a dictionary: The key should be edge type, and the
value should correspond to the negative samples for head nodes of the
given type.
"""
compacted_node_pair: Union[
compacted_node_pairs: Union[
Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
] = None
"""
Representation of compacted node pairs corresponding to 'node_pair', where
Representation of compacted node pairs corresponding to 'node_pairs', where
all node ids inside are compacted.
"""
compacted_negative_head: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
compacted_negative_srcs: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_head', where
Representation of compacted nodes corresponding to 'negative_srcs', where
all node ids inside are compacted.
"""
compacted_negative_tail: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
compacted_negative_dsts: Union[torch.Tensor, Dict[str, torch.Tensor]] = None
"""
Representation of compacted nodes corresponding to 'negative_tail', where
Representation of compacted nodes corresponding to 'negative_dsts', where
all node ids inside are compacted.
"""
......
......@@ -44,9 +44,9 @@ class NegativeSampler(Mapper):
Parameters
----------
minibatch : MiniBatch
An instance of 'MiniBatch' class requires the 'node_pair' field.
An instance of 'MiniBatch' class requires the 'node_pairs' field.
This function is responsible for generating negative edges
corresponding to the positive edges defined by the 'node_pair'. In
corresponding to the positive edges defined by the 'node_pairs'. In
cases where negative edges already exist, this function will
overwrite them.
......@@ -56,21 +56,21 @@ class NegativeSampler(Mapper):
An instance of 'MiniBatch' encompasses both positive and negative
samples.
"""
node_pairs = minibatch.node_pair
node_pairs = minibatch.node_pairs
assert node_pairs is not None
if isinstance(node_pairs, Mapping):
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
minibatch.label = {}
minibatch.labels = {}
else:
minibatch.negative_head, minibatch.negative_tail = {}, {}
minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
minibatch, self._sample_with_etype(pos_pairs, etype), etype
)
if self.output_format == LinkPredictionEdgeFormat.HEAD_CONDITIONED:
minibatch.negative_tail = None
minibatch.negative_dsts = None
if self.output_format == LinkPredictionEdgeFormat.TAIL_CONDITIONED:
minibatch.negative_head = None
minibatch.negative_srcs = None
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
return minibatch
......@@ -111,23 +111,23 @@ class NegativeSampler(Mapper):
Canonical edge type.
"""
pos_src, pos_dst = (
minibatch.node_pair[etype]
minibatch.node_pairs[etype]
if etype is not None
else minibatch.node_pair
else minibatch.node_pairs
)
neg_src, neg_dst = neg_pairs
if self.output_format == LinkPredictionEdgeFormat.INDEPENDENT:
pos_label = torch.ones_like(pos_src)
neg_label = torch.zeros_like(neg_src)
pos_labels = torch.ones_like(pos_src)
neg_labels = torch.zeros_like(neg_src)
src = torch.cat([pos_src, neg_src])
dst = torch.cat([pos_dst, neg_dst])
label = torch.cat([pos_label, neg_label])
labels = torch.cat([pos_labels, neg_labels])
if etype is not None:
minibatch.node_pair[etype] = (src, dst)
minibatch.label[etype] = label
minibatch.node_pairs[etype] = (src, dst)
minibatch.labels[etype] = labels
else:
minibatch.node_pair = (src, dst)
minibatch.label = label
minibatch.node_pairs = (src, dst)
minibatch.labels = labels
else:
if self.output_format == LinkPredictionEdgeFormat.CONDITIONED:
neg_src = neg_src.view(-1, self.negative_ratio)
......@@ -147,8 +147,8 @@ class NegativeSampler(Mapper):
f"Unsupported output format {self.output_format}."
)
if etype is not None:
minibatch.negative_head[etype] = neg_src
minibatch.negative_tail[etype] = neg_dst
minibatch.negative_srcs[etype] = neg_src
minibatch.negative_dsts[etype] = neg_dst
else:
minibatch.negative_head = neg_src
minibatch.negative_tail = neg_dst
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
......@@ -28,19 +28,19 @@ class SubgraphSampler(Mapper):
super().__init__(datapipe, self._sample)
def _sample(self, minibatch):
if minibatch.node_pair is not None:
if minibatch.node_pairs is not None:
(
seeds,
minibatch.compacted_node_pair,
minibatch.compacted_negative_head,
minibatch.compacted_negative_tail,
) = self._node_pair_preprocess(minibatch)
elif minibatch.seed_node is not None:
seeds = minibatch.seed_node
minibatch.compacted_node_pairs,
minibatch.compacted_negative_srcs,
minibatch.compacted_negative_dsts,
) = self._node_pairs_preprocess(minibatch)
elif minibatch.seed_nodes is not None:
seeds = minibatch.seed_nodes
else:
raise ValueError(
f"Invalid minibatch {minibatch}: Either 'node_pair' or \
'seed_node' should have a value."
f"Invalid minibatch {minibatch}: Either 'node_pairs' or \
'seed_nodes' should have a value."
)
(
minibatch.input_nodes,
......@@ -48,16 +48,16 @@ class SubgraphSampler(Mapper):
) = self._sample_subgraphs(seeds)
return minibatch
def _node_pair_preprocess(self, minibatch):
node_pair = minibatch.node_pair
neg_src, neg_dst = minibatch.negative_head, minibatch.negative_tail
def _node_pairs_preprocess(self, minibatch):
node_pairs = minibatch.node_pairs
neg_src, neg_dst = minibatch.negative_srcs, minibatch.negative_dsts
has_neg_src = neg_src is not None
has_neg_dst = neg_dst is not None
is_heterogeneous = isinstance(node_pair, Dict)
is_heterogeneous = isinstance(node_pairs, Dict)
if is_heterogeneous:
# Collect nodes from all types of input.
nodes = defaultdict(list)
for etype, (src, dst) in node_pair.items():
for etype, (src, dst) in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
nodes[src_type].append(src)
nodes[dst_type].append(dst)
......@@ -72,27 +72,27 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
(
compacted_node_pair,
compacted_negative_head,
compacted_negative_tail,
compacted_node_pairs,
compacted_negative_srcs,
compacted_negative_dsts,
) = ({}, {}, {})
# Map back in same order as collect.
for etype, _ in node_pair.items():
for etype, _ in node_pairs.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_node_pair[etype] = (src, dst)
compacted_node_pairs[etype] = (src, dst)
if has_neg_src:
for etype, _ in neg_src.items():
src_type, _, _ = etype_str_to_tuple(etype)
compacted_negative_head[etype] = compacted[src_type].pop(0)
compacted_negative_srcs[etype] = compacted[src_type].pop(0)
if has_neg_dst:
for etype, _ in neg_dst.items():
_, _, dst_type = etype_str_to_tuple(etype)
compacted_negative_tail[etype] = compacted[dst_type].pop(0)
compacted_negative_dsts[etype] = compacted[dst_type].pop(0)
else:
# Collect nodes from all types of input.
nodes = list(node_pair)
nodes = list(node_pairs)
if has_neg_src:
nodes.append(neg_src.view(-1))
if has_neg_dst:
......@@ -100,17 +100,17 @@ class SubgraphSampler(Mapper):
# Unique and compact the collected nodes.
seeds, compacted = unique_and_compact(nodes)
# Map back in same order as collect.
compacted_node_pair = tuple(compacted[:2])
compacted_node_pairs = tuple(compacted[:2])
compacted = compacted[2:]
if has_neg_src:
compacted_negative_head = compacted.pop(0)
compacted_negative_srcs = compacted.pop(0)
if has_neg_dst:
compacted_negative_tail = compacted.pop(0)
compacted_negative_dsts = compacted.pop(0)
return (
seeds,
compacted_node_pair,
compacted_negative_head if has_neg_src else None,
compacted_negative_tail if has_neg_dst else None,
compacted_node_pairs,
compacted_negative_srcs if has_neg_src else None,
compacted_negative_dsts if has_neg_dst else None,
)
def _sample_subgraphs(self, seeds):
......
......@@ -9,12 +9,12 @@ import torch
def minibatch_node_collator(data):
minibatch = gb.MiniBatch(seed_node=data)
minibatch = gb.MiniBatch(seed_nodes=data)
return minibatch
def minibatch_link_collator(data):
minibatch = gb.MiniBatch(node_pair=data)
minibatch = gb.MiniBatch(node_pairs=data)
return minibatch
......
......@@ -30,14 +30,14 @@ def test_NegativeSampler_Independent_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
src, dst = data.node_pair
label = data.label
src, dst = data.node_pairs
labels = data.labels
# Assertation
assert len(src) == batch_size * (negative_ratio + 1)
assert len(dst) == batch_size * (negative_ratio + 1)
assert len(label) == batch_size * (negative_ratio + 1)
assert torch.all(torch.eq(label[:batch_size], 1))
assert torch.all(torch.eq(label[batch_size:], 0))
assert len(labels) == batch_size * (negative_ratio + 1)
assert torch.all(torch.eq(labels[:batch_size], 1))
assert torch.all(torch.eq(labels[batch_size:], 0))
@pytest.mark.parametrize("negative_ratio", [1, 5, 10, 20])
......@@ -65,8 +65,8 @@ def test_NegativeSampler_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst = data.node_pair
neg_src, neg_dst = data.negative_head, data.negative_tail
pos_src, pos_dst = data.node_pairs
neg_src, neg_dst = data.negative_srcs, data.negative_dsts
# Assertation
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
......@@ -103,8 +103,8 @@ def test_NegativeSampler_Head_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst = data.node_pair
neg_src = data.negative_head
pos_src, pos_dst = data.node_pairs
neg_src = data.negative_srcs
# Assertation
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
......@@ -139,8 +139,8 @@ def test_NegativeSampler_Tail_Conditioned_Format(negative_ratio):
)
# Perform Negative sampling.
for data in negative_sampler:
pos_src, pos_dst = data.node_pair
neg_dst = data.negative_tail
pos_src, pos_dst = data.node_pairs
neg_dst = data.negative_dsts
# Assertation
assert len(pos_src) == batch_size
assert len(pos_dst) == batch_size
......
......@@ -78,11 +78,11 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
train_set:
- type: null
data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: true
path: {train_ids_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {train_labels_path}
......@@ -104,7 +104,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_names():
for i, (id, label, _) in enumerate(train_set):
assert id == train_ids[i]
assert label == train_labels[i]
assert train_set.names == ("seed_node", "label", None)
assert train_set.names == ("seed_nodes", "labels", None)
train_set = None
......@@ -125,11 +125,11 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
train_set:
- type: "author:writes:paper"
data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: true
path: {train_ids_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {train_labels_path}
......@@ -154,7 +154,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_names():
id, label, _ = item["author:writes:paper"]
assert id == train_ids[i]
assert label == train_labels[i]
assert train_set.names == ("seed_node", "label", None)
assert train_set.names == ("seed_nodes", "labels", None)
train_set = None
......@@ -193,32 +193,32 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
train_set:
- type: null
data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: true
path: {train_ids_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {train_labels_path}
validation_set:
- data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: true
path: {validation_ids_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {validation_labels_path}
test_set:
- type: null
data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: true
path: {test_ids_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {test_labels_path}
......@@ -242,7 +242,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(train_set):
assert id == train_ids[i]
assert label == train_labels[i]
assert train_set.names == ("seed_node", "label")
assert train_set.names == ("seed_nodes", "labels")
train_set = None
# Verify validation set.
......@@ -252,7 +252,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(validation_set):
assert id == validation_ids[i]
assert label == validation_labels[i]
assert validation_set.names == ("seed_node", "label")
assert validation_set.names == ("seed_nodes", "labels")
validation_set = None
# Verify test set.
......@@ -262,7 +262,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_id_label():
for i, (id, label) in enumerate(test_set):
assert id == test_ids[i]
assert label == test_labels[i]
assert test_set.names == ("seed_node", "label")
assert test_set.names == ("seed_nodes", "labels")
test_set = None
dataset = None
......@@ -334,7 +334,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
in_memory: true
path: {train_dst_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {train_labels_path}
......@@ -348,7 +348,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
in_memory: true
path: {validation_dst_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {validation_labels_path}
......@@ -363,7 +363,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
format: numpy
in_memory: true
path: {test_dst_path}
- name: label
- name: labels
format: numpy
in_memory: true
path: {test_labels_path}
......@@ -383,7 +383,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == train_src[i]
assert dst == train_dst[i]
assert label == train_labels[i]
assert train_set.names == ("src", "dst", "label")
assert train_set.names == ("src", "dst", "labels")
train_set = None
# Verify validation set.
......@@ -394,7 +394,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == validation_src[i]
assert dst == validation_dst[i]
assert label == validation_labels[i]
assert validation_set.names == ("src", "dst", "label")
assert validation_set.names == ("src", "dst", "labels")
validation_set = None
# Verify test set.
......@@ -405,7 +405,7 @@ def test_OnDiskDataset_TVTSet_ItemSet_node_pair_label():
assert src == test_src[i]
assert dst == test_dst[i]
assert label == test_labels[i]
assert test_set.names == ("src", "dst", "label")
assert test_set.names == ("src", "dst", "labels")
test_set = None
dataset = None
......@@ -564,36 +564,36 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
train_set:
- type: paper
data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: true
path: {train_path}
- type: author
data:
- name: seed_node
- name: seed_nodes
format: numpy
path: {train_path}
validation_set:
- type: paper
data:
- name: seed_node
- name: seed_nodes
format: numpy
path: {validation_path}
- type: author
data:
- name: seed_node
- name: seed_nodes
format: numpy
path: {validation_path}
test_set:
- type: paper
data:
- name: seed_node
- name: seed_nodes
format: numpy
in_memory: false
path: {test_path}
- type: author
data:
- name: seed_node
- name: seed_nodes
format: numpy
path: {test_path}
"""
......@@ -616,7 +616,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key]
assert id == train_ids[i % 1000]
assert label == train_labels[i % 1000]
assert train_set.names == ("seed_node",)
assert train_set.names == ("seed_nodes",)
train_set = None
# Verify validation set.
......@@ -631,7 +631,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key]
assert id == validation_ids[i % 1000]
assert label == validation_labels[i % 1000]
assert validation_set.names == ("seed_node",)
assert validation_set.names == ("seed_nodes",)
validation_set = None
# Verify test set.
......@@ -646,7 +646,7 @@ def test_OnDiskDataset_TVTSet_ItemSetDict_id_label():
id, label = item[key]
assert id == test_ids[i % 1000]
assert label == test_labels[i % 1000]
assert test_set.names == ("seed_node",)
assert test_set.names == ("seed_nodes",)
test_set = None
dataset = None
......@@ -798,7 +798,7 @@ def test_OnDiskDataset_Feature_heterograph():
path: {node_data_paper_path}
- domain: node
type: paper
name: label
name: labels
format: numpy
in_memory: true
path: {node_data_label_path}
......@@ -810,7 +810,7 @@ def test_OnDiskDataset_Feature_heterograph():
path: {edge_data_writes_path}
- domain: edge
type: "author:writes:paper"
name: label
name: labels
format: numpy
in_memory: true
path: {edge_data_label_path}
......@@ -832,7 +832,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch.tensor(node_data_paper),
)
assert torch.equal(
feature_data.read("node", "paper", "label"),
feature_data.read("node", "paper", "labels"),
torch.tensor(node_data_label),
)
......@@ -842,7 +842,7 @@ def test_OnDiskDataset_Feature_heterograph():
torch.tensor(edge_data_writes),
)
assert torch.equal(
feature_data.read("edge", "author:writes:paper", "label"),
feature_data.read("edge", "author:writes:paper", "labels"),
torch.tensor(edge_data_label),
)
......@@ -879,7 +879,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false
path: {node_data_feat_path}
- domain: node
name: label
name: labels
format: numpy
in_memory: true
path: {node_data_label_path}
......@@ -889,7 +889,7 @@ def test_OnDiskDataset_Feature_homograph():
in_memory: false
path: {edge_data_feat_path}
- domain: edge
name: label
name: labels
format: numpy
in_memory: true
path: {edge_data_label_path}
......@@ -911,7 +911,7 @@ def test_OnDiskDataset_Feature_homograph():
torch.tensor(node_data_feat),
)
assert torch.equal(
feature_data.read("node", None, "label"),
feature_data.read("node", None, "labels"),
torch.tensor(node_data_label),
)
......@@ -921,7 +921,7 @@ def test_OnDiskDataset_Feature_homograph():
torch.tensor(edge_data_feat),
)
assert torch.equal(
feature_data.read("edge", None, "label"),
feature_data.read("edge", None, "labels"),
torch.tensor(edge_data_label),
)
......
......@@ -22,7 +22,7 @@ def test_SubgraphSampler_Node(labor):
def to_link_batch(data):
block = gb.MiniBatch(node_pair=data)
block = gb.MiniBatch(node_pairs=data)
return block
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment