Unverified Commit d63130a7 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Modify `to_dgl` to support Sampled Subgraph Impl (csc formats) (#6606)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 9aca3092
......@@ -9,7 +9,7 @@ import dgl
from dgl.heterograph import DGLBlock
from dgl.utils import recursive_apply
from .base import etype_str_to_tuple
from .base import CSCFormatBase, etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph
__all__ = ["DGLMiniBatch", "MiniBatch"]
......@@ -384,6 +384,21 @@ class MiniBatch:
original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous:
if isinstance(
list(subgraph.node_pairs.values())[0], CSCFormatBase
):
node_pairs = {
etype_str_to_tuple(etype): (
"csc",
(
v.indptr,
v.indices,
torch.tensor([]),
),
)
for etype, v in subgraph.node_pairs.items()
}
else:
node_pairs = {
etype_str_to_tuple(etype): v
for etype, v in subgraph.node_pairs.items()
......@@ -398,6 +413,15 @@ class MiniBatch:
}
else:
node_pairs = subgraph.node_pairs
if isinstance(subgraph.node_pairs, CSCFormatBase):
node_pairs = (
"csc",
(
node_pairs.indptr,
node_pairs.indices,
torch.tensor([]),
),
)
num_src_nodes = original_row_node_ids.size(0)
num_dst_nodes = original_column_node_ids.size(0)
blocks.append(
......
......@@ -738,3 +738,321 @@ def test_to_dgl_link_predication_hetero(mode):
dgl_minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1),
)
def create_homo_minibatch_csc_format():
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
),
gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 3]),
indices=torch.tensor([1, 2, 0]),
),
]
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
original_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.randint(0, 10, (4,))}
edge_features = [
{"x": torch.randint(0, 10, (6,))},
{"x": torch.randint(0, 10, (3,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
return gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes=torch.tensor([10, 11, 12, 13]),
)
def create_hetero_minibatch_csc_format():
node_pairs = [
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([0, 1, 1]),
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]),
indices=torch.tensor([1, 0]),
),
},
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
)
},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
]
original_row_node_ids = [
{
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
]
original_edge_ids = [
{
relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
]
node_features = {
("A", "x"): torch.randint(0, 10, (4,)),
}
edge_features = [
{(relation, "x"): torch.randint(0, 10, (3,))},
{(relation, "x"): torch.randint(0, 10, (2,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
return gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
)
def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation)
node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
]
original_row_node_ids = [
subgraph.original_row_node_ids
for subgraph in minibatch.sampled_subgraphs
]
for i, block in enumerate(blocks):
edges = block.edges(etype=etype)
dst_ndoes = torch.arange(
0, len(node_pairs[i][relation].indptr) - 1
).repeat_interleave(
node_pairs[i][relation].indptr[1:]
- node_pairs[i][relation].indptr[:-1]
)
assert torch.equal(edges[0], node_pairs[i][relation].indices)
assert torch.equal(edges[1], dst_ndoes)
assert torch.equal(
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
)
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
dst_ndoes = torch.arange(
0, len(node_pairs[0][reverse_relation].indptr) - 1
).repeat_interleave(
node_pairs[0][reverse_relation].indptr[1:]
- node_pairs[0][reverse_relation].indptr[:-1]
)
assert torch.equal(edges[0], node_pairs[0][reverse_relation].indices)
assert torch.equal(edges[1], dst_ndoes)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], original_row_node_ids[0]["B"]
)
def check_dgl_blocks_homo_csc_format(minibatch, blocks):
node_pairs = [
subgraph.node_pairs for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
]
original_row_node_ids = [
subgraph.original_row_node_ids
for subgraph in minibatch.sampled_subgraphs
]
for i, block in enumerate(blocks):
dst_ndoes = torch.arange(
0, len(node_pairs[i].indptr) - 1
).repeat_interleave(
node_pairs[i].indptr[1:] - node_pairs[i].indptr[:-1]
)
assert torch.equal(block.edges()[0], node_pairs[i].indices), print(
block.edges()
)
assert torch.equal(block.edges()[1], dst_ndoes), print(block.edges())
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i]), print(
block.edata[dgl.EID]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID], original_row_node_ids[0]
), print(blocks[0].srcdata[dgl.NID])
def test_to_dgl_node_classification_without_feature_csc_format():
# Arrange
minibatch = create_homo_minibatch_csc_format()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
# Act
dgl_minibatch = minibatch.to_dgl()
# Assert
assert len(dgl_minibatch.blocks) == 2
assert dgl_minibatch.node_features is None
assert minibatch.edge_features is dgl_minibatch.edge_features
assert dgl_minibatch.labels is None
assert minibatch.input_nodes is dgl_minibatch.input_nodes
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
check_dgl_blocks_homo_csc_format(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_homo_csc_format():
# Arrange
minibatch = create_homo_minibatch_csc_format()
minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5])
# Act
dgl_minibatch = minibatch.to_dgl()
# Assert
assert len(dgl_minibatch.blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_homo_csc_format(minibatch, dgl_minibatch.blocks)
def test_to_dgl_node_classification_hetero_csc_format():
minibatch = create_hetero_minibatch_csc_format()
minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
dgl_minibatch = minibatch.to_dgl()
# Assert
assert len(dgl_minibatch.blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_hetero_csc_format(minibatch, dgl_minibatch.blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_homo_csc_format(mode):
# Arrange
minibatch = create_homo_minibatch_csc_format()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
# Act
dgl_minibatch = minibatch.to_dgl()
# Assert
assert len(dgl_minibatch.blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
check_dgl_blocks_homo_csc_format(minibatch, dgl_minibatch.blocks)
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
dgl_minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1),
)
if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal(
dgl_minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1),
)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_to_dgl_link_predication_hetero_csc_format(mode):
# Arrange
minibatch = create_hetero_minibatch_csc_format()
minibatch.compacted_node_pairs = {
relation: (
torch.tensor([1, 1]),
torch.tensor([1, 0]),
),
reverse_relation: (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
),
}
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = {
relation: torch.tensor([[2, 0], [1, 2]]),
reverse_relation: torch.tensor([[1, 2], [0, 2]]),
}
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = {
relation: torch.tensor([[1, 3], [2, 1]]),
reverse_relation: torch.tensor([[2, 1], [3, 1]]),
}
# Act
dgl_minibatch = minibatch.to_dgl()
# Assert
assert len(dgl_minibatch.blocks) == 2
assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.compacted_node_pairs is dgl_minibatch.positive_node_pairs
check_dgl_blocks_hetero_csc_format(minibatch, dgl_minibatch.blocks)
if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
dgl_minibatch.negative_node_pairs[etype][0],
src.view(-1),
)
if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal(
dgl_minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment