Unverified Commit 46a566af authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Remove coo from blocks in MiniBatch. (#6855)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 59bb5406
...@@ -8,7 +8,7 @@ import torch ...@@ -8,7 +8,7 @@ import torch
import dgl import dgl
from dgl.utils import recursive_apply from dgl.utils import recursive_apply
from .base import CSCFormatBase, etype_str_to_tuple from .base import etype_str_to_tuple
from .internal import get_attributes from .internal import get_attributes
from .sampled_subgraph import SampledSubgraph from .sampled_subgraph import SampledSubgraph
...@@ -194,30 +194,22 @@ class MiniBatch: ...@@ -194,30 +194,22 @@ class MiniBatch:
original_column_node_ids is not None original_column_node_ids is not None
), "Missing `original_column_node_ids` in sampled subgraph." ), "Missing `original_column_node_ids` in sampled subgraph."
if is_heterogeneous: if is_heterogeneous:
if isinstance( sampled_csc = {
list(subgraph.sampled_csc.values())[0], CSCFormatBase etype_str_to_tuple(etype): (
): "csc",
sampled_csc = { (
etype_str_to_tuple(etype): ( v.indptr,
"csc", v.indices,
( torch.arange(
v.indptr, 0,
v.indices, v.indptr[-1],
torch.arange( device=v.indptr.device,
0, dtype=v.indptr.dtype,
v.indptr[-1],
device=v.indptr.device,
dtype=v.indptr.dtype,
),
), ),
) ),
for etype, v in subgraph.sampled_csc.items() )
} for etype, v in subgraph.sampled_csc.items()
else: }
sampled_csc = {
etype_str_to_tuple(etype): v
for etype, v in subgraph.sampled_csc.items()
}
num_src_nodes = { num_src_nodes = {
ntype: nodes.size(0) ntype: nodes.size(0)
for ntype, nodes in original_row_node_ids.items() for ntype, nodes in original_row_node_ids.items()
...@@ -228,20 +220,19 @@ class MiniBatch: ...@@ -228,20 +220,19 @@ class MiniBatch:
} }
else: else:
sampled_csc = subgraph.sampled_csc sampled_csc = subgraph.sampled_csc
if isinstance(subgraph.sampled_csc, CSCFormatBase): sampled_csc = (
sampled_csc = ( "csc",
"csc", (
( sampled_csc.indptr,
sampled_csc.indptr, sampled_csc.indices,
sampled_csc.indices, torch.arange(
torch.arange( 0,
0, sampled_csc.indptr[-1],
sampled_csc.indptr[-1], device=sampled_csc.indptr.device,
device=sampled_csc.indptr.device, dtype=sampled_csc.indptr.dtype,
dtype=sampled_csc.indptr.dtype,
),
), ),
) ),
)
num_src_nodes = original_row_node_ids.size(0) num_src_nodes = original_row_node_ids.size(0)
num_dst_nodes = original_column_node_ids.size(0) num_dst_nodes = original_column_node_ids.size(0)
blocks.append( blocks.append(
......
...@@ -8,109 +8,6 @@ relation = "A:r:B" ...@@ -8,109 +8,6 @@ relation = "A:r:B"
reverse_relation = "B:rr:A" reverse_relation = "B:rr:A"
def create_homo_minibatch():
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
original_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.randint(0, 10, (4,))}
edge_features = [
{"x": torch.randint(0, 10, (6,))},
{"x": torch.randint(0, 10, (3,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
return gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes=torch.tensor([10, 11, 12, 13]),
)
def create_hetero_minibatch():
node_pairs = [
{
relation: (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 2])),
reverse_relation: (torch.tensor([1, 0]), torch.tensor([2, 3])),
},
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
]
original_row_node_ids = [
{
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
]
original_edge_ids = [
{
relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
]
node_features = {
("A", "x"): torch.randint(0, 10, (4,)),
}
edge_features = [
{(relation, "x"): torch.randint(0, 10, (3,))},
{(relation, "x"): torch.randint(0, 10, (2,))},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.FusedSampledSubgraphImpl(
sampled_csc=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
return gb.MiniBatch(
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
)
def test_minibatch_representation_homo(): def test_minibatch_representation_homo():
csc_formats = [ csc_formats = [
gb.CSCFormatBase( gb.CSCFormatBase(
...@@ -425,6 +322,16 @@ def test_get_dgl_blocks_homo(): ...@@ -425,6 +322,16 @@ def test_get_dgl_blocks_homo():
torch.tensor([1, 0, 0]), torch.tensor([1, 0, 0]),
), ),
] ]
csc_formats = [
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
),
gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3]),
indices=torch.tensor([0, 1, 2]),
),
]
original_column_node_ids = [ original_column_node_ids = [
torch.tensor([10, 11, 12, 13]), torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]), torch.tensor([10, 11]),
...@@ -445,8 +352,8 @@ def test_get_dgl_blocks_homo(): ...@@ -445,8 +352,8 @@ def test_get_dgl_blocks_homo():
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.SampledSubgraphImpl(
sampled_csc=node_pairs[i], sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -489,6 +396,23 @@ def test_get_dgl_blocks_hetero(): ...@@ -489,6 +396,23 @@ def test_get_dgl_blocks_hetero():
}, },
{relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))}, {relation: (torch.tensor([0, 1]), torch.tensor([1, 0]))},
] ]
csc_formats = [
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([0, 1, 1]),
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]),
indices=torch.tensor([1, 0]),
),
},
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
)
},
]
original_column_node_ids = [ original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])}, {"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])}, {"B": torch.tensor([10, 11])},
...@@ -520,8 +444,8 @@ def test_get_dgl_blocks_hetero(): ...@@ -520,8 +444,8 @@ def test_get_dgl_blocks_hetero():
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.SampledSubgraphImpl(
sampled_csc=node_pairs[i], sampled_csc=csc_formats[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -610,227 +534,7 @@ def test_minibatch_node_pairs_with_labels(mode): ...@@ -610,227 +534,7 @@ def test_minibatch_node_pairs_with_labels(mode):
assert torch.equal(labels, expect_labels) assert torch.equal(labels, expect_labels)
def check_dgl_blocks_hetero(minibatch, blocks): def create_homo_minibatch():
etype = gb.etype_str_to_tuple(relation)
node_pairs = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
]
original_row_node_ids = [
subgraph.original_row_node_ids
for subgraph in minibatch.sampled_subgraphs
]
for i, block in enumerate(blocks):
edges = block.edges(etype=etype)
assert torch.equal(edges[0], node_pairs[i][relation][0])
assert torch.equal(edges[1], node_pairs[i][relation][1])
assert torch.equal(
block.edges[etype].data[dgl.EID], original_edge_ids[i][relation]
)
edges = blocks[0].edges(etype=gb.etype_str_to_tuple(reverse_relation))
assert torch.equal(edges[0], node_pairs[0][reverse_relation][0])
assert torch.equal(edges[1], node_pairs[0][reverse_relation][1])
assert torch.equal(
blocks[0].srcdata[dgl.NID]["A"], original_row_node_ids[0]["A"]
)
assert torch.equal(
blocks[0].srcdata[dgl.NID]["B"], original_row_node_ids[0]["B"]
)
def check_dgl_blocks_homo(minibatch, blocks):
node_pairs = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
]
original_edge_ids = [
subgraph.original_edge_ids for subgraph in minibatch.sampled_subgraphs
]
original_row_node_ids = [
subgraph.original_row_node_ids
for subgraph in minibatch.sampled_subgraphs
]
for i, block in enumerate(blocks):
assert torch.equal(block.edges()[0], node_pairs[i][0])
assert torch.equal(block.edges()[1], node_pairs[i][1])
assert torch.equal(block.edata[dgl.EID], original_edge_ids[i])
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])
def test_get_dgl_blocks_node_classification_without_feature():
# Arrange
minibatch = create_homo_minibatch()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
assert minibatch.node_features is None
assert minibatch.labels is None
check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_get_dgl_blocks_node_classification_homo():
# Arrange
minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_to_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_homo(mode):
# Arrange
minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
)
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = torch.tensor([[0, 0], [1, 1]])
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = torch.tensor([[1, 0], [0, 1]])
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_homo(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
assert torch.equal(
minibatch.negative_node_pairs[0],
minibatch.compacted_negative_srcs.view(-1),
)
if mode == "neg_graph" or mode == "neg_dst":
assert torch.equal(
minibatch.negative_node_pairs[1],
minibatch.compacted_negative_dsts.view(-1),
)
node_pairs, labels = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 1, 0, 0]),
)
else:
expect_node_pairs = (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 1, 0, 0, 1]),
)
expect_labels = torch.tensor([1, 1, 0, 0, 0, 0]).float()
assert torch.equal(node_pairs[0], expect_node_pairs[0])
assert torch.equal(node_pairs[1], expect_node_pairs[1])
assert torch.equal(labels, expect_labels)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_hetero(mode):
# Arrange
minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = {
relation: (
torch.tensor([1, 1]),
torch.tensor([1, 0]),
),
reverse_relation: (
torch.tensor([0, 1]),
torch.tensor([1, 0]),
),
}
if mode == "neg_graph" or mode == "neg_src":
minibatch.compacted_negative_srcs = {
relation: torch.tensor([[2, 0], [1, 2]]),
reverse_relation: torch.tensor([[1, 2], [0, 2]]),
}
if mode == "neg_graph" or mode == "neg_dst":
minibatch.compacted_negative_dsts = {
relation: torch.tensor([[1, 3], [2, 1]]),
reverse_relation: torch.tensor([[2, 1], [3, 1]]),
}
# Act
dgl_blocks = minibatch.blocks
# Assert
assert len(dgl_blocks) == 2
check_dgl_blocks_hetero(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][0],
src.view(-1),
)
if mode == "neg_graph" or mode == "neg_dst":
for etype, dst in minibatch.compacted_negative_dsts.items():
assert torch.equal(
minibatch.negative_node_pairs[etype][1],
minibatch.compacted_negative_dsts[etype].view(-1),
)
node_pairs, labels = minibatch.node_pairs_with_labels
if mode == "neg_src":
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 2, 0, 1, 2]),
torch.tensor([1, 0, 1, 1, 0, 0]),
),
"B:rr:A": (
torch.tensor([0, 1, 1, 2, 0, 2]),
torch.tensor([1, 0, 1, 1, 0, 0]),
),
}
elif mode == "neg_dst":
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 1, 1, 1, 1]),
torch.tensor([1, 0, 1, 3, 2, 1]),
),
"B:rr:A": (
torch.tensor([0, 1, 0, 0, 1, 1]),
torch.tensor([1, 0, 2, 1, 3, 1]),
),
}
else:
expect_node_pairs = {
"A:r:B": (
torch.tensor([1, 1, 2, 0, 1, 2]),
torch.tensor([1, 0, 1, 3, 2, 1]),
),
"B:rr:A": (
torch.tensor([0, 1, 1, 2, 0, 2]),
torch.tensor([1, 0, 2, 1, 3, 1]),
),
}
expect_labels = {
"A:r:B": torch.tensor([1, 1, 0, 0, 0, 0]),
"B:rr:A": torch.tensor([1, 1, 0, 0, 0, 0]),
}
for etype in node_pairs:
assert torch.equal(node_pairs[etype][0], expect_node_pairs[etype][0])
assert torch.equal(node_pairs[etype][1], expect_node_pairs[etype][1])
assert torch.equal(labels[etype], expect_labels[etype])
def create_homo_minibatch_csc_format():
csc_formats = [ csc_formats = [
gb.CSCFormatBase( gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]), indptr=torch.tensor([0, 1, 3, 5, 6]),
...@@ -876,7 +580,7 @@ def create_homo_minibatch_csc_format(): ...@@ -876,7 +580,7 @@ def create_homo_minibatch_csc_format():
) )
def create_hetero_minibatch_csc_format(): def create_hetero_minibatch():
sampled_csc = [ sampled_csc = [
{ {
relation: gb.CSCFormatBase( relation: gb.CSCFormatBase(
...@@ -943,7 +647,7 @@ def create_hetero_minibatch_csc_format(): ...@@ -943,7 +647,7 @@ def create_hetero_minibatch_csc_format():
) )
def check_dgl_blocks_hetero_csc_format(minibatch, blocks): def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation) etype = gb.etype_str_to_tuple(relation)
sampled_csc = [ sampled_csc = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
...@@ -986,7 +690,7 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks): ...@@ -986,7 +690,7 @@ def check_dgl_blocks_hetero_csc_format(minibatch, blocks):
) )
def check_dgl_blocks_homo_csc_format(minibatch, blocks): def check_dgl_blocks_homo(minibatch, blocks):
sampled_csc = [ sampled_csc = [
subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs subgraph.sampled_csc for subgraph in minibatch.sampled_subgraphs
] ]
...@@ -1015,9 +719,9 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks): ...@@ -1015,9 +719,9 @@ def check_dgl_blocks_homo_csc_format(minibatch, blocks):
), print(blocks[0].srcdata[dgl.NID]) ), print(blocks[0].srcdata[dgl.NID])
def test_dgl_node_classification_without_feature_csc_format(): def test_dgl_node_classification_without_feature():
# Arrange # Arrange
minibatch = create_homo_minibatch_csc_format() minibatch = create_homo_minibatch()
minibatch.node_features = None minibatch.node_features = None
minibatch.labels = None minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seed_nodes = torch.tensor([10, 15])
...@@ -1028,12 +732,12 @@ def test_dgl_node_classification_without_feature_csc_format(): ...@@ -1028,12 +732,12 @@ def test_dgl_node_classification_without_feature_csc_format():
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
assert minibatch.node_features is None assert minibatch.node_features is None
assert minibatch.labels is None assert minibatch.labels is None
check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks) check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_dgl_node_classification_homo_csc_format(): def test_dgl_node_classification_homo():
# Arrange # Arrange
minibatch = create_homo_minibatch_csc_format() minibatch = create_homo_minibatch()
minibatch.seed_nodes = torch.tensor([10, 15]) minibatch.seed_nodes = torch.tensor([10, 15])
minibatch.labels = torch.tensor([2, 5]) minibatch.labels = torch.tensor([2, 5])
# Act # Act
...@@ -1041,11 +745,11 @@ def test_dgl_node_classification_homo_csc_format(): ...@@ -1041,11 +745,11 @@ def test_dgl_node_classification_homo_csc_format():
# Assert # Assert
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks) check_dgl_blocks_homo(minibatch, dgl_blocks)
def test_dgl_node_classification_hetero_csc_format(): def test_dgl_node_classification_hetero():
minibatch = create_hetero_minibatch_csc_format() minibatch = create_hetero_minibatch()
minibatch.labels = {"B": torch.tensor([2, 5])} minibatch.labels = {"B": torch.tensor([2, 5])}
minibatch.seed_nodes = {"B": torch.tensor([10, 15])} minibatch.seed_nodes = {"B": torch.tensor([10, 15])}
# Act # Act
...@@ -1053,13 +757,13 @@ def test_dgl_node_classification_hetero_csc_format(): ...@@ -1053,13 +757,13 @@ def test_dgl_node_classification_hetero_csc_format():
# Assert # Assert
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
check_dgl_blocks_hetero_csc_format(minibatch, dgl_blocks) check_dgl_blocks_hetero(minibatch, dgl_blocks)
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_homo_csc_format(mode): def test_dgl_link_predication_homo(mode):
# Arrange # Arrange
minibatch = create_homo_minibatch_csc_format() minibatch = create_homo_minibatch()
minibatch.compacted_node_pairs = ( minibatch.compacted_node_pairs = (
torch.tensor([0, 1]), torch.tensor([0, 1]),
torch.tensor([1, 0]), torch.tensor([1, 0]),
...@@ -1073,7 +777,7 @@ def test_dgl_link_predication_homo_csc_format(mode): ...@@ -1073,7 +777,7 @@ def test_dgl_link_predication_homo_csc_format(mode):
# Assert # Assert
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
check_dgl_blocks_homo_csc_format(minibatch, dgl_blocks) check_dgl_blocks_homo(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
assert torch.equal( assert torch.equal(
minibatch.negative_node_pairs[0], minibatch.negative_node_pairs[0],
...@@ -1105,9 +809,9 @@ def test_dgl_link_predication_homo_csc_format(mode): ...@@ -1105,9 +809,9 @@ def test_dgl_link_predication_homo_csc_format(mode):
@pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"]) @pytest.mark.parametrize("mode", ["neg_graph", "neg_src", "neg_dst"])
def test_dgl_link_predication_hetero_csc_format(mode): def test_dgl_link_predication_hetero(mode):
# Arrange # Arrange
minibatch = create_hetero_minibatch_csc_format() minibatch = create_hetero_minibatch()
minibatch.compacted_node_pairs = { minibatch.compacted_node_pairs = {
relation: ( relation: (
torch.tensor([1, 1]), torch.tensor([1, 1]),
...@@ -1133,7 +837,7 @@ def test_dgl_link_predication_hetero_csc_format(mode): ...@@ -1133,7 +837,7 @@ def test_dgl_link_predication_hetero_csc_format(mode):
# Assert # Assert
assert len(dgl_blocks) == 2 assert len(dgl_blocks) == 2
check_dgl_blocks_hetero_csc_format(minibatch, dgl_blocks) check_dgl_blocks_hetero(minibatch, dgl_blocks)
if mode == "neg_graph" or mode == "neg_src": if mode == "neg_graph" or mode == "neg_src":
for etype, src in minibatch.compacted_negative_srcs.items(): for etype, src in minibatch.compacted_negative_srcs.items():
assert torch.equal( assert torch.equal(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment