Unverified Commit 9090a879 authored by Xinyu Yao's avatar Xinyu Yao Committed by GitHub
Browse files

[GraphBolt] Hyperlink support in `subgraph_sampler`. (#7354)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent ce37a934
......@@ -25,6 +25,7 @@ __all__ = [
"expand_indptr",
"CSCFormatBase",
"seed",
"seed_type_str_to_ntypes",
]
CANONICAL_ETYPE_DELIMITER = ":"
......@@ -185,6 +186,37 @@ def etype_str_to_tuple(c_etype):
return ret
def seed_type_str_to_ntypes(seed_type, seed_size):
    """Extract the list of node types encoded in a seed type string.

    A seed type string joins type names with ``CANONICAL_ETYPE_DELIMITER``.
    When the number of tokens equals ``seed_size``, every token is a node
    type (hyperlink form) and all tokens are returned. Otherwise the string
    is treated as a canonical edge type and the edge-type tokens (every
    second one) are discarded.

    Parameters
    ----------
    seed_type : str
        The seed type string to convert.
    seed_size : int
        The number of node types a single seed holds.

    Returns
    -------
    list of str
        The node types, in order of appearance.

    Examples
    --------
    1. node pairs
    >>> seed_type = "user:like:item"
    >>> seed_size = 2
    >>> node_type = seed_type_str_to_ntypes(seed_type, seed_size)
    >>> print(node_type)
    ["user", "item"]
    2. hyperlink
    >>> seed_type = "query:user:item"
    >>> seed_size = 3
    >>> node_type = seed_type_str_to_ntypes(seed_type, seed_size)
    >>> print(node_type)
    ["query", "user", "item"]
    """
    assert isinstance(
        seed_type, str
    ), f"Passed-in seed type should be string, but got {type(seed_type)}"
    tokens = seed_type.split(CANONICAL_ETYPE_DELIMITER)
    if len(tokens) == seed_size:
        # Hyperlink form: every token is already a node type.
        return tokens
    # Canonical edge type "src:etype:dst[...]": keep only node-type tokens.
    return tokens[::2]
def apply_to(x, device):
"""Apply `to` function to object x only if it has `to`."""
......
......@@ -110,6 +110,21 @@ class ItemSet:
tensor([1, 1, 0, 0, 0]))
>>> item_set.names
('seeds', 'labels')
6. Tuple of iterables with different shape: hyperlink and labels.
>>> seeds = torch.arange(0, 10).reshape(-1, 5)
>>> labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSet(
    ... (seeds, labels), names=("seeds", "labels"))
>>> list(item_set)
[(tensor([0, 1, 2, 3, 4]), tensor([1])),
(tensor([5, 6, 7, 8, 9]), tensor([0]))]
>>> item_set[:]
(tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]),
tensor([1, 0]))
>>> item_set.names
('seeds', 'labels')
"""
def __init__(
......@@ -315,6 +330,31 @@ class ItemSetDict:
tensor([1, 1, 0]))}
>>> item_set.names
('seeds', 'labels')
4. Tuple of iterables with different shape: hyperlink and labels.
>>> first_seeds = torch.arange(0, 6).reshape(-1, 3)
>>> first_labels = torch.tensor([1, 0])
>>> second_seeds = torch.arange(0, 2).reshape(-1, 1)
>>> second_labels = torch.tensor([1, 0])
>>> item_set = gb.ItemSetDict({
... "query:user:item": gb.ItemSet(
... (first_seeds, first_labels),
... names=("seeds", "labels")),
... "user": gb.ItemSet(
... (second_seeds, second_labels),
... names=("seeds", "labels"))})
>>> list(item_set)
[{'query:user:item': (tensor([0, 1, 2]), tensor(1))},
{'query:user:item': (tensor([3, 4, 5]), tensor(0))},
{'user': (tensor([0]), tensor(1))},
{'user': (tensor([1]), tensor(0))}]
>>> item_set[:]
{'query:user:item': (tensor([[0, 1, 2], [3, 4, 5]]),
tensor([1, 0])),
'user': (tensor([[0], [1]]),tensor([1, 0]))}
>>> item_set.names
('seeds', 'labels')
"""
def __init__(self, itemsets: Dict[str, ItemSet]) -> None:
......
......@@ -6,7 +6,7 @@ from typing import Dict
import torch
from torch.utils.data import functional_datapipe
from .base import etype_str_to_tuple
from .base import seed_type_str_to_ntypes
from .internal import compact_temporal_nodes, unique_and_compact
from .minibatch_transformer import MiniBatchTransformer
......@@ -93,7 +93,8 @@ class SubgraphSampler(MiniBatchTransformer):
"""Preprocess `seeds` in a minibatch to construct `unique_seeds`,
`node_timestamp` and `compacted_seeds` for further sampling. It
optionally incorporates timestamps for temporal graphs, organizing and
compacting seeds based on their types and timestamps.
compacting seeds based on their types and timestamps. In heterogeneous
graphs, `seeds` with the same node type will be uniqued together.
Parameters
----------
......@@ -121,7 +122,7 @@ class SubgraphSampler(MiniBatchTransformer):
nodes_timestamp = None
if use_timestamp:
nodes_timestamp = defaultdict(list)
for etype, typed_seeds in seeds.items():
for seed_type, typed_seeds in seeds.items():
# When typed_seeds is a one-dimensional tensor, it represents
# seed nodes, which does not need to do unique and compact.
if typed_seeds.ndim == 1:
......@@ -131,25 +132,27 @@ class SubgraphSampler(MiniBatchTransformer):
else None
)
return seeds, nodes_timestamp, None
assert typed_seeds.ndim == 2 and typed_seeds.shape[1] == 2, (
"Only tensor with shape 1*N and N*2 is "
assert typed_seeds.ndim == 2, (
"Only tensor with shape 1*N and N*M is "
+ f"supported now, but got {typed_seeds.shape}."
)
ntypes = etype[:].split(":")[::2]
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
if use_timestamp:
negative_ratio = (
typed_seeds.shape[0]
// minibatch.timestamp[etype].shape[0]
// minibatch.timestamp[seed_type].shape[0]
- 1
)
neg_timestamp = minibatch.timestamp[
etype
seed_type
].repeat_interleave(negative_ratio)
for i, ntype in enumerate(ntypes):
nodes[ntype].append(typed_seeds[:, i])
if use_timestamp:
nodes_timestamp[ntype].append(
minibatch.timestamp[etype]
minibatch.timestamp[seed_type]
)
nodes_timestamp[ntype].append(neg_timestamp)
# Unique and compact the collected nodes.
......@@ -164,11 +167,16 @@ class SubgraphSampler(MiniBatchTransformer):
nodes_timestamp = None
compacted_seeds = {}
# Map back in same order as collect.
for etype, typed_seeds in seeds.items():
src_type, _, dst_type = etype_str_to_tuple(etype)
src = compacted[src_type].pop(0)
dst = compacted[dst_type].pop(0)
compacted_seeds[etype] = torch.cat((src, dst)).view(2, -1).T
for seed_type, typed_seeds in seeds.items():
ntypes = seed_type_str_to_ntypes(
seed_type, typed_seeds.shape[1]
)
compacted_seed = []
for ntype in ntypes:
compacted_seed.append(compacted[ntype].pop(0))
compacted_seeds[seed_type] = (
torch.cat(compacted_seed).view(len(ntypes), -1).T
)
else:
# When seeds is a one-dimensional tensor, it represents seed nodes,
# which does not need to do unique and compact.
......@@ -193,7 +201,9 @@ class SubgraphSampler(MiniBatchTransformer):
seeds_timestamp = torch.cat(
(minibatch.timestamp, neg_timestamp)
)
nodes_timestamp = [seeds_timestamp for _ in range(seeds.ndim)]
nodes_timestamp = [
seeds_timestamp for _ in range(seeds.shape[1])
]
# Unique and compact the collected nodes.
if use_timestamp:
(
......
......@@ -169,6 +169,31 @@ def test_etype_str_to_tuple():
_ = gb.etype_str_to_tuple(c_etype_str)
def test_seed_type_str_to_ntypes():
    """Convert a seed type string to its list of node types."""
    # Test for node pairs (canonical edge type, edge-type token dropped).
    seed_type_str = "user:like:item"
    seed_size = 2
    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
    assert node_type == ["user", "item"]
    # Test for hyperlink (token count equals seed_size, all tokens kept).
    seed_type_str = "user:item:user"
    seed_size = 3
    node_type = gb.seed_type_str_to_ntypes(seed_type_str, seed_size)
    assert node_type == ["user", "item", "user"]
    # Test for unexpected input: list.
    seed_type_str = ["user", "item"]
    with pytest.raises(
        AssertionError,
        match=re.escape(
            "Passed-in seed type should be string, but got <class 'list'>"
        ),
    ):
        _ = gb.seed_type_str_to_ntypes(seed_type_str, 2)
def test_isin():
elements = torch.tensor([2, 3, 5, 5, 20, 13, 11], device=F.ctx())
test_elements = torch.tensor([2, 5], device=F.ctx())
......
......@@ -265,6 +265,38 @@ def test_SubgraphSampler_Link_With_Negative(sampler_type):
_check_sampler_len(datapipe, 5)
@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_HyperLink(sampler_type):
    """Sample from 4 width-5 hyperlink seeds on a random homogeneous graph
    and check the compacted seeds of each minibatch.
    """
    _check_sampler_type(sampler_type)
    device = F.ctx()
    graph = gb_test_utils.rand_csc_graph(20, 0.15, bidirection_edge=True).to(
        device
    )
    hyperlinks = torch.arange(20).reshape(-1, 5)
    if sampler_type == SamplerType.Temporal:
        # Temporal sampling needs timestamp attributes on the graph plus a
        # per-seed timestamp in the item set.
        graph.node_attributes = {"timestamp": torch.arange(20).to(device)}
        graph.edge_attributes = {
            "timestamp": torch.arange(len(graph.indices)).to(device)
        }
        itemset = gb.ItemSet(
            (hyperlinks, torch.arange(4)), names=("seeds", "timestamp")
        )
    else:
        itemset = gb.ItemSet(hyperlinks, names="seeds")
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(device)
    fanouts = [torch.LongTensor([2]) for _ in range(2)]
    datapipe = _get_sampler(sampler_type)(datapipe, graph, fanouts)
    _check_sampler_len(datapipe, 2)
    expected = torch.tensor([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]).to(device)
    for minibatch in datapipe:
        assert torch.equal(minibatch.compacted_seeds, expected)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
......@@ -487,6 +519,57 @@ def test_SubgraphSampler_Link_Hetero_With_Negative_Unknown_Etype(sampler_type):
_check_sampler_len(datapipe, 5)
@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_HyperLink_Hetero(sampler_type):
    """Sample from width-5 hyperlink seeds typed "n2:n1:n2:n1:n2" on a
    heterogeneous graph and check the compacted seeds per seed type.
    """
    _check_sampler_type(sampler_type)
    device = F.ctx()
    graph = get_hetero_graph().to(device)
    hyperlinks = torch.LongTensor([[2, 0, 1, 1, 2], [0, 1, 1, 0, 0]])
    if sampler_type == SamplerType.Temporal:
        # Temporal sampling needs timestamp attributes on the graph plus a
        # per-seed timestamp in the item set.
        graph.node_attributes = {
            "timestamp": torch.arange(graph.csc_indptr.numel() - 1).to(device)
        }
        graph.edge_attributes = {
            "timestamp": torch.arange(graph.indices.numel()).to(device)
        }
        inner_itemset = gb.ItemSet(
            (hyperlinks, torch.randint(0, 10, (2,))),
            names=("seeds", "timestamp"),
        )
    else:
        inner_itemset = gb.ItemSet(hyperlinks, names="seeds")
    itemset = gb.ItemSetDict({"n2:n1:n2:n1:n2": inner_itemset})
    datapipe = gb.ItemSampler(itemset, batch_size=2).copy_to(device)
    fanouts = [torch.LongTensor([2]) for _ in range(2)]
    datapipe = _get_sampler(sampler_type)(datapipe, graph, fanouts)
    _check_sampler_len(datapipe, 1)
    # Compaction differs between the temporal and non-temporal code paths.
    if sampler_type == SamplerType.Temporal:
        expected = torch.tensor([[0, 0, 2, 2, 4], [1, 1, 3, 3, 5]]).to(device)
    else:
        expected = torch.tensor([[0, 0, 2, 1, 0], [1, 1, 2, 0, 1]]).to(device)
    for minibatch in datapipe:
        for compacted in minibatch.compacted_seeds.values():
            assert torch.equal(compacted, expected)
@pytest.mark.parametrize(
"sampler_type",
[SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
......@@ -1353,3 +1436,374 @@ def test_SubgraphSampler_unique_csc_format_Hetero_Link(labor):
sampled_subgraph.sampled_csc[etype].indptr,
csc_formats[step][etype].indptr.to(F.ctx()),
)
@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_deduplication_Homo_HyperLink(sampler_type):
    """Sample from width-3 hyperlink seeds on a homogeneous graph and check
    the per-layer sampled CSC structure when deduplication is disabled.
    """
    _check_sampler_type(sampler_type)
    graph = dgl.graph(
        ([5, 0, 1, 5, 6, 7, 2, 2, 4], [0, 1, 2, 2, 2, 2, 3, 4, 4])
    )
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    # Two hyperlink seeds, each spanning 3 nodes.
    items = torch.LongTensor([[0, 1, 4], [3, 5, 6]])
    names = "seeds"
    if sampler_type == SamplerType.Temporal:
        # Temporal sampling needs timestamp attributes on the graph plus a
        # per-seed timestamp in the item set.
        graph.node_attributes = {
            "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
        }
        items = (items, torch.randint(1, 10, (2,)))
        names = (names, "timestamp")
    itemset = gb.ItemSet(items, names=names)
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    # NOTE(review): `deduplicate=False` is not passed in the temporal branch —
    # presumably the temporal sampler does not expose that flag; confirm.
    if sampler_type == SamplerType.Temporal:
        datapipe = sampler(item_sampler, graph, fanouts)
    else:
        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
    # Expected values for each of the 2 sampled layers (outermost first).
    length = [23, 11]
    compacted_indices = [
        (torch.arange(0, 12) + 11).to(F.ctx()),
        (torch.arange(0, 5) + 6).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 4, 5, 5, 5, 5, 6, 8, 10, 12]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4, 5, 5, 5]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 0, 1, 2, 2, 3, 4, 4, 5, 5, 6]).to(F.ctx()),
        torch.tensor([0, 1, 3, 4, 5, 6]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert len(sampled_subgraph.original_row_node_ids) == length[step]
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            # Column node ids are compared order-insensitively via sort.
            assert torch.equal(
                torch.sort(sampled_subgraph.original_column_node_ids)[0],
                seeds[step],
            )
@pytest.mark.parametrize(
    "sampler_type",
    [SamplerType.Normal, SamplerType.Layer, SamplerType.Temporal],
)
def test_SubgraphSampler_without_deduplication_Hetero_HyperLink(sampler_type):
    """Sample from a width-3 hyperlink seed typed "n2:n1:n2" on a
    heterogeneous graph and check the per-layer sampled CSC structures and
    original node ids when deduplication is disabled.
    """
    _check_sampler_type(sampler_type)
    graph = get_hetero_graph().to(F.ctx())
    # One hyperlink seed spanning 3 nodes.
    items = torch.arange(3).view(1, 3)
    names = "seeds"
    if sampler_type == SamplerType.Temporal:
        # Temporal sampling needs timestamp attributes on the graph plus a
        # per-seed timestamp in the item set.
        graph.node_attributes = {
            "timestamp": torch.zeros(graph.csc_indptr.numel() - 1).to(F.ctx())
        }
        graph.edge_attributes = {
            "timestamp": torch.zeros(graph.indices.numel()).to(F.ctx())
        }
        items = (items, torch.randint(1, 10, (1,)))
        names = (names, "timestamp")
    itemset = gb.ItemSetDict({"n2:n1:n2": gb.ItemSet(items, names=names)})
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    sampler = _get_sampler(sampler_type)
    # NOTE(review): `deduplicate=False` is not passed in the temporal branch —
    # presumably the temporal sampler does not expose that flag; confirm.
    if sampler_type == SamplerType.Temporal:
        datapipe = sampler(item_sampler, graph, fanouts)
    else:
        datapipe = sampler(item_sampler, graph, fanouts, deduplicate=False)
    # Expected per-layer CSC formats, keyed by canonical edge type.
    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4, 6, 8]),
                indices=torch.tensor([5, 6, 7, 8, 9, 10, 11, 12]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4, 6, 8, 10]),
                indices=torch.tensor([4, 5, 6, 7, 8, 9, 10, 11, 12, 13]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([1, 2, 3, 4]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2]),
                indices=torch.tensor([2, 3], dtype=torch.int64),
            ),
        },
    ]
    # Expected per-layer original ids; duplicates remain since
    # deduplication is disabled.
    original_column_node_ids = [
        {
            "n1": torch.tensor([1, 0, 1, 0, 1]),
            "n2": torch.tensor([0, 2, 0, 1]),
        },
        {
            "n1": torch.tensor([1]),
            "n2": torch.tensor([0, 2]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0]),
            "n2": torch.tensor([0, 2, 0, 1, 0, 1, 0, 2, 0, 1, 0, 2, 0, 1]),
        },
        {
            "n1": torch.tensor([1, 0, 1, 0, 1]),
            "n2": torch.tensor([0, 2, 0, 1]),
        },
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ["n1", "n2"]:
                assert torch.equal(
                    sampled_subgraph.original_row_node_ids[ntype],
                    original_row_node_ids[step][ntype].to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.original_column_node_ids[ntype],
                    original_column_node_ids[step][ntype].to(F.ctx()),
                )
            for etype in ["n1:e1:n2", "n2:e2:n1"]:
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr.to(F.ctx()),
                )
@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Fails due to different result on the GPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_HyperLink_cpu(labor):
    """Sample from width-3 hyperlink seeds on a homogeneous graph with
    deduplication enabled (CPU-only expectations) and check the per-layer
    sampled CSC structure.
    """
    # Fixed seed so the sampled neighborhoods match the expected tensors.
    torch.manual_seed(1205)
    graph = dgl.graph(([5, 0, 6, 7, 2, 2, 4], [0, 1, 2, 2, 3, 4, 4]))
    graph = gb.from_dglgraph(graph, True).to(F.ctx())
    # Two hyperlink seeds, each spanning 3 nodes (duplicates included).
    seed_nodes = torch.LongTensor([[0, 3, 3], [4, 4, 4]])
    itemset = gb.ItemSet(seed_nodes, names="seeds")
    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )
    # Expected values for each of the 2 sampled layers (outermost first).
    original_row_node_ids = [
        torch.tensor([0, 3, 4, 5, 2, 6, 7]).to(F.ctx()),
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
    ]
    compacted_indices = [
        torch.tensor([3, 4, 4, 2, 5, 6]).to(F.ctx()),
        torch.tensor([3, 4, 4, 2]).to(F.ctx()),
    ]
    indptr = [
        torch.tensor([0, 1, 2, 4, 4, 6]).to(F.ctx()),
        torch.tensor([0, 1, 2, 4]).to(F.ctx()),
    ]
    seeds = [
        torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
        torch.tensor([0, 3, 4]).to(F.ctx()),
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert torch.equal(
                sampled_subgraph.original_row_node_ids,
                original_row_node_ids[step],
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids, seeds[step]
            )
@unittest.skipIf(
    F._default_context_str == "cpu",
    reason="Fails due to different result on the CPU.",
)
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Homo_HyperLink_gpu(labor):
    """Sample from width-3 hyperlink seeds on a homogeneous graph with
    deduplication enabled (GPU-only expectations) and check the per-layer
    sampled CSC structure.
    """
    # Fixed seed so the sampled neighborhoods match the expected tensors.
    torch.manual_seed(1205)
    graph = dgl.graph(([5, 0, 7, 7, 2, 4], [0, 1, 2, 2, 3, 4]))
    graph = gb.from_dglgraph(graph, is_homogeneous=True).to(F.ctx())
    # Two hyperlink seeds, each spanning 3 nodes (duplicates included).
    seed_nodes = torch.LongTensor([[0, 3, 4], [4, 4, 3]])
    itemset = gb.ItemSet(seed_nodes, names="seeds")
    item_sampler = gb.ItemSampler(itemset, batch_size=4).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([-1]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )
    # The expected results differ by GPU compute capability (the branch
    # below keys on pre-/post-Volta, i.e. capability < 7).
    if torch.cuda.get_device_capability()[0] < 7:
        original_row_node_ids = [
            torch.tensor([0, 3, 4, 2, 5, 7]).to(F.ctx()),
            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
        ]
        compacted_indices = [
            torch.tensor([4, 3, 2, 5, 5]).to(F.ctx()),
            torch.tensor([4, 3, 2]).to(F.ctx()),
        ]
        indptr = [
            torch.tensor([0, 1, 2, 3, 5, 5]).to(F.ctx()),
            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
        ]
        seeds = [
            torch.tensor([0, 3, 4, 2, 5]).to(F.ctx()),
            torch.tensor([0, 3, 4]).to(F.ctx()),
        ]
    else:
        original_row_node_ids = [
            torch.tensor([0, 3, 4, 5, 2, 7]).to(F.ctx()),
            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
        ]
        compacted_indices = [
            torch.tensor([3, 4, 2, 5, 5]).to(F.ctx()),
            torch.tensor([3, 4, 2]).to(F.ctx()),
        ]
        indptr = [
            torch.tensor([0, 1, 2, 3, 3, 5]).to(F.ctx()),
            torch.tensor([0, 1, 2, 3]).to(F.ctx()),
        ]
        seeds = [
            torch.tensor([0, 3, 4, 5, 2]).to(F.ctx()),
            torch.tensor([0, 3, 4]).to(F.ctx()),
        ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            assert torch.equal(
                sampled_subgraph.original_row_node_ids,
                original_row_node_ids[step],
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indices, compacted_indices[step]
            )
            assert torch.equal(
                sampled_subgraph.sampled_csc.indptr, indptr[step]
            )
            assert torch.equal(
                sampled_subgraph.original_column_node_ids, seeds[step]
            )
@pytest.mark.parametrize("labor", [False, True])
def test_SubgraphSampler_unique_csc_format_Hetero_HyperLink(labor):
    """Sample from a width-3 hyperlink seed typed "n1:n2:n1" on a
    heterogeneous graph with deduplication enabled and check the per-layer
    sampled CSC structures and original node ids.
    """
    graph = get_hetero_graph().to(F.ctx())
    # One hyperlink seed spanning 3 nodes with types n1, n2, n1.
    itemset = gb.ItemSetDict(
        {"n1:n2:n1": gb.ItemSet(torch.tensor([[0, 1, 0]]), names="seeds")}
    )
    item_sampler = gb.ItemSampler(itemset, batch_size=2).copy_to(F.ctx())
    num_layer = 2
    fanouts = [torch.LongTensor([2]) for _ in range(num_layer)]
    Sampler = gb.LayerNeighborSampler if labor else gb.NeighborSampler
    datapipe = Sampler(
        item_sampler,
        graph,
        fanouts,
        deduplicate=True,
    )
    # Expected per-layer CSC formats, keyed by canonical edge type.
    csc_formats = [
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4, 6]),
                indices=torch.tensor([1, 0, 0, 1, 0, 1]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2, 4]),
                indices=torch.tensor([1, 2, 1, 0]),
            ),
        },
        {
            "n1:e1:n2": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2]),
                indices=torch.tensor([1, 0]),
            ),
            "n2:e2:n1": gb.CSCFormatBase(
                indptr=torch.tensor([0, 2]),
                indices=torch.tensor([1, 2], dtype=torch.int64),
            ),
        },
    ]
    # Expected per-layer node ids; compared after sorting, so order-free.
    original_column_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0]),
            "n2": torch.tensor([1]),
        },
    ]
    original_row_node_ids = [
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
        {
            "n1": torch.tensor([0, 1]),
            "n2": torch.tensor([0, 1, 2]),
        },
    ]
    for data in datapipe:
        for step, sampled_subgraph in enumerate(data.sampled_subgraphs):
            for ntype in ["n1", "n2"]:
                assert torch.equal(
                    torch.sort(sampled_subgraph.original_row_node_ids[ntype])[
                        0
                    ],
                    original_row_node_ids[step][ntype].to(F.ctx()),
                )
                assert torch.equal(
                    torch.sort(
                        sampled_subgraph.original_column_node_ids[ntype]
                    )[0],
                    original_column_node_ids[step][ntype].to(F.ctx()),
                )
            for etype in ["n1:e1:n2", "n2:e2:n1"]:
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indices,
                    csc_formats[step][etype].indices.to(F.ctx()),
                )
                assert torch.equal(
                    sampled_subgraph.sampled_csc[etype].indptr,
                    csc_formats[step][etype].indptr.to(F.ctx()),
                )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment