Unverified Commit 46a566af authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove coo from blocks in MiniBatch. (#6855)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 59bb5406
......@@ -8,7 +8,7 @@ import torch
import dgl
from dgl.utils import recursive_apply
from .base import CSCFormatBase, etype_str_to_tuple
from .base import etype_str_to_tuple
from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph
......@@ -194,9 +194,6 @@ class MiniBatch:
original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous:
if isinstance(
list(subgraph.sampled_csc.values())[0], CSCFormatBase
):
sampled_csc = {
etype_str_to_tuple(etype): (
"csc",
......@@ -213,11 +210,6 @@ class MiniBatch:
)
for etype, v in subgraph.sampled_csc.items()
}
else:
sampled_csc = {
etype_str_to_tuple(etype): v
for etype, v in subgraph.sampled_csc.items()
}
num_src_nodes = {
ntype: nodes.size(0)
for ntype, nodes in original_row_node_ids.items()
......@@ -228,7 +220,6 @@ class MiniBatch:
}
else:
sampled_csc = subgraph.sampled_csc
if isinstance(subgraph.sampled_csc, CSCFormatBase):
sampled_csc = (
"csc",
(
......
......@@ -8,109 +8,6 @@ relation = "A:r:B"
reverse_relation = "B:rr:A"
def create_homo_minibatch():
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
original_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.randint(0, 10, (4,))}
edge_features = [
{"x": torch.randint(0, 10, (6,))},
{"x": torch.randint(0, 10, (3,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
return gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes=torch.tensor([10, 11, 12, 13]),
)
def create_hetero_minibatch():
node_pairs = [
{
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
]
original_row_node_ids = [
{
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
]
original_edge_ids = [
{
relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
]
node_features = {
("A", "x"): torch.randint(0, 10, (4,)),
}
edge_features = [
{(relation, "x"): torch.randint(0, 10, (3,))},
{(relation, "x"): torch.randint(0, 10, (2,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
return gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
)
def test_minibatch_representation_homo():
csc_formats = [
gb.CSCFormatBase(
......@@ -425,6 +322,16 @@ def test_get_dgl_blocks_homo():
torch.tensor([1, 0, 0]),
),
]
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
),
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3]),
indices=torch.tensor([0, 1, 2]),
),
]
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
......@@ -445,8 +352,8 @@ def test_get_dgl_blocks_homo():
subgraphs = []
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=node_pairs[i],
gb.SampledSubgraphImpl(
sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -489,6 +396,23 @@ def test_get_dgl_blocks_hetero():
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
csc_formats = [
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([0, 1, 1]),
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]),
indices=torch.tensor([1, 0]),
),
},
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
)
},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
......@@ -520,8 +444,8 @@ def test_get_dgl_blocks_hetero():
subgraphs = []
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=node_pairs[i],
gb.SampledSubgraphImpl(
sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
......@@ -610,227 +534,7 @@ def test_minibatch_node_pairs_with_labels(mode):
assert torch.equal(labels, expect_labels)
def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation)
node_pairs = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
]
original_row_node_ids = [
subgraph.original_row_node_ids
for subgraph in minibatch.sampled_subgraphs
]
for i, block in enumerate(blocks):
edges = block.edges(etype=etype)
assert torch.equal(edges[0], node_pairs[i][relation][0])
assert torch.equal(edges[1], node_pairs[i][relation][1])
assert torch.equal(
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
)
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
assert torch.equal(edges[0], node_pairs[0][reverse_relation][0])
assert torch.equal(edges[1], node_pairs[0][reverse_relation][1])
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], original_row_node_ids[0]["B"]
)
def check_dgl_blocks_homo(minibatch, blocks):
node_pairs = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
]
original_row_node_ids = [
subgraph.original_row_node_ids
for subgraph in minibatch.sampled_subgraphs
]
for i, block in enumerate(blocks):
assert torch.equal(block.edges()[0], node_pairs[i][0])
assert torch.equal(block.edges()[1], node_pairs[i][1])
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
def test_get_dgl_blocks_node_classification_without_feature():
# Arrange
minibatch = create_homo_minibatch()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
assert minibatch.node_features is None
assert minibatch.labels is None
check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_get_dgl_blocks_node_classification_homo():
# Arrange
minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_to_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_homo(mode):
# Arrange
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1),
)
if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal(
minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1),
)
node_pairs, labels = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
else:
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_hetero(mode):
# Arrange
minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = {
relation: (
torch.tensor([1, 1]),
torch.tensor([1, 0]),
),
reverse_relation: (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
),
}
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = {
relation: torch.tensor([[2, 0], [1, 2]]),
reverse_relation: torch.tensor([[1, 2], [0, 2]]),
}
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = {
relation: torch.tensor([[1, 3], [2, 1]]),
reverse_relation: torch.tensor([[2, 1], [3, 1]]),
}
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][0],
src.view(-1),
)
if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1),
)
node_pairs, labels = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 2, 0, 1, 2]),
torch.tensor([1, 0, 1, 1, 0, 0]),
),
"B:rr:A": (
torch.tensor([0, 1, 1, 2, 0, 2]),
torch.tensor([1, 0, 1, 1, 0, 0]),
),
}
elif mode == "neg_dst":
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 1, 1, 1, 1]),
torch.tensor([1, 0, 1, 3, 2, 1]),
),
"B:rr:A": (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 2, 1, 3, 1]),
),
}
else:
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 2, 0, 1, 2]),
torch.tensor([1, 0, 1, 3, 2, 1]),
),
"B:rr:A": (
torch.tensor([0, 1, 1, 2, 0, 2]),
torch.tensor([1, 0, 2, 1, 3, 1]),
),
}
expect_labels = {
"A:r:B": torch.tensor([1, 1, 0, 0, 0, 0]),
"B:rr:A": torch.tensor([1, 1, 0, 0, 0, 0]),
}
for etype in node_pairs:
assert torch.equal(node_pairs[etype][0], expect_node_pairs[etype][0])
assert torch.equal(node_pairs[etype][1], expect_node_pairs[etype][1])
assert torch.equal(labels[etype], expect_labels[etype])
def create_homo_minibatch_csc_format():
def create_homo_minibatch():
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
......@@ -876,7 +580,7 @@ def create_homo_minibatch_csc_format():
)
def create_hetero_minibatch_csc_format():
def create_hetero_minibatch():
sampled_csc = [
{
relation: gb.CSCFormatBase(
......@@ -943,7 +647,7 @@ def create_hetero_minibatch_csc_format():
)
def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation)
sampled_csc = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
......@@ -986,7 +690,7 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
)
def check_dgl_blocks_homo_csc_format(minibatch, blocks):
def check_dgl_blocks_homo(minibatch, blocks):
sampled_csc = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
......@@ -1015,9 +719,9 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks):
), print(blocks[0].srcdata[dgl.NID])
def test_dgl_node_classification_without_feature_csc_format():
def test_dgl_node_classification_without_feature():
# Arrange
minibatch = create_homo_minibatch_csc_format()
minibatch = create_homo_minibatch()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
......@@ -1028,12 +732,12 @@ def test_dgl_node_classification_without_feature_csc_format():
assert len(dgl_blocks) == 2
assert minibatch.node_features is None
assert minibatch.labels is None
check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks)
check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_dgl_node_classification_homo_csc_format():
def test_dgl_node_classification_homo():
# Arrange
minibatch = create_homo_minibatch_csc_format()
minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5])
# Act
......@@ -1041,11 +745,11 @@ def test_dgl_node_classification_homo_csc_format():
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks)
check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_dgl_node_classification_hetero_csc_format():
minibatch = create_hetero_minibatch_csc_format()
def test_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
# Act
......@@ -1053,13 +757,13 @@ def test_dgl_node_classification_hetero_csc_format():
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero_csc_format(minibatch, dgl_blocks)
check_dgl_blocks_hetero(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_homo_csc_format(mode):
def test_dgl_link_predication_homo(mode):
# Arrange
minibatch = create_homo_minibatch_csc_format()
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
......@@ -1073,7 +777,7 @@ def test_dgl_link_predication_homo_csc_format(mode):
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks)
check_dgl_blocks_homo(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
minibatch.negative_node_pairs[0],
......@@ -1105,9 +809,9 @@ def test_dgl_link_predication_homo_csc_format(mode):
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_hetero_csc_format(mode):
def test_dgl_link_predication_hetero(mode):
# Arrange
minibatch = create_hetero_minibatch_csc_format()
minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = {
relation: (
torch.tensor([1, 1]),
......@@ -1133,7 +837,7 @@ def test_dgl_link_predication_hetero_csc_format(mode):
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero_csc_format(minibatch, dgl_blocks)
check_dgl_blocks_hetero(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment