Unverified Commit 74684bbe authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Fix minibatch layout. (#6616)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent 683e1f45
...@@ -146,3 +146,23 @@ class CSCFormatBase: ...@@ -146,3 +146,23 @@ class CSCFormatBase:
""" """
indptr: torch.Tensor = None indptr: torch.Tensor = None
indices: torch.Tensor = None indices: torch.Tensor = None
def __repr__(self) -> str:
return _csc_format_base_str(self)
def _csc_format_base_str(csc_format_base: CSCFormatBase) -> str:
final_str = "CSCFormatBase("
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
final_str += (
f"indptr={_add_indent(str(csc_format_base.indptr), 21)},\n" + " " * 14
)
final_str += (
f"indices={_add_indent(str(csc_format_base.indices), 22)},\n" + ")"
)
return final_str
...@@ -68,6 +68,9 @@ class FusedSampledSubgraphImpl(SampledSubgraph): ...@@ -68,6 +68,9 @@ class FusedSampledSubgraphImpl(SampledSubgraph):
isinstance(item, torch.Tensor) for item in self.node_pairs isinstance(item, torch.Tensor) for item in self.node_pairs
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str:
return _sampled_subgraph_str(self, "FusedSampledSubgraphImpl")
@dataclass @dataclass
class SampledSubgraphImpl(SampledSubgraph): class SampledSubgraphImpl(SampledSubgraph):
...@@ -129,3 +132,38 @@ class SampledSubgraphImpl(SampledSubgraph): ...@@ -129,3 +132,38 @@ class SampledSubgraphImpl(SampledSubgraph):
) and isinstance( ) and isinstance(
self.node_pairs.indices, torch.Tensor self.node_pairs.indices, torch.Tensor
), "Nodes in pairs should be of type torch.Tensor." ), "Nodes in pairs should be of type torch.Tensor."
def __repr__(self) -> str:
return _sampled_subgraph_str(self, "SampledSubgraphImpl")
def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:
final_str = classname + "("
def _get_attributes(_obj) -> list:
attributes = [
attribute
for attribute in dir(_obj)
if not attribute.startswith("__")
and not callable(getattr(_obj, attribute))
]
return attributes
attributes = _get_attributes(sampled_subgraph)
attributes.reverse()
for name in attributes:
val = getattr(sampled_subgraph, name)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
val = str(val)
final_str = (
final_str
+ f"{name}={_add_indent(val, len(name) + len(classname) + 1)},\n"
+ " " * len(classname)
)
return final_str[: -len(classname)] + ")"
...@@ -577,35 +577,11 @@ def _minibatch_str(minibatch: MiniBatch) -> str: ...@@ -577,35 +577,11 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
# indentation on top of the original if the original data output has # indentation on top of the original if the original data output has
# line feeds. # line feeds.
if isinstance(val, list): if isinstance(val, list):
if len(val) == 0: val = [str(val_str) for val_str in val]
val = "[]" val = "[" + ",\n".join(val) + "]"
# Special handling of FusedSampledSubgraphImpl data. Each element of elif isinstance(val, tuple):
# the data occupies one row and is further structured. val = [str(val_str) for val_str in val]
elif isinstance( val = "(" + ",\n".join(val) + ")"
val[0],
dgl.graphbolt.impl.sampled_subgraph_impl.FusedSampledSubgraphImpl,
):
sampledsubgraph_strs = []
for sampledsubgraph in val:
ss_attributes = _get_attributes(sampledsubgraph)
sampledsubgraph_str = "FusedSampledSubgraphImpl("
for ss_name in ss_attributes:
ss_val = str(getattr(sampledsubgraph, ss_name))
sampledsubgraph_str = (
sampledsubgraph_str
+ f"{ss_name}={_add_indent(ss_val, len(ss_name)+1)},\n"
+ " " * 20
)
sampledsubgraph_strs.append(sampledsubgraph_str[:-21] + ")")
val = "[" + ",\n".join(sampledsubgraph_strs) + "]"
else:
val = [
_add_indent(
str(val_str), len(str(val_str).split("': ")[0]) - 6
)
for val_str in val
]
val = "[" + ",\n".join(val) + "]"
else: else:
val = str(val) val = str(val)
final_str = ( final_str = (
......
...@@ -111,15 +111,15 @@ def create_hetero_minibatch(): ...@@ -111,15 +111,15 @@ def create_hetero_minibatch():
) )
def test_minibatch_representation(): def test_minibatch_representation_homo():
node_pairs = [ csc_formats = [
( gb.CSCFormatBase(
torch.tensor([0, 1, 2, 2, 2, 1]), indptr=torch.tensor([0, 1, 3, 5, 6]),
torch.tensor([0, 1, 1, 2, 3, 2]), indices=torch.tensor([0, 1, 2, 2, 1, 2]),
), ),
( gb.CSCFormatBase(
torch.tensor([0, 1, 2]), indptr=torch.tensor([0, 2, 3]),
torch.tensor([1, 0, 0]), indices=torch.tensor([1, 2, 0]),
), ),
] ]
original_column_node_ids = [ original_column_node_ids = [
...@@ -134,16 +134,16 @@ def test_minibatch_representation(): ...@@ -134,16 +134,16 @@ def test_minibatch_representation():
torch.tensor([19, 20, 21, 22, 25, 30]), torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]), torch.tensor([10, 15, 17]),
] ]
node_features = {"x": torch.tensor([7, 6, 2, 2])} node_features = {"x": torch.tensor([5, 0, 2, 1])}
edge_features = [ edge_features = [
{"x": torch.tensor([[8], [1], [6]])}, {"x": torch.tensor([9, 0, 1, 1, 7, 4])},
{"x": torch.tensor([[2], [8], [8]])}, {"x": torch.tensor([0, 2, 2])},
] ]
subgraphs = [] subgraphs = []
for i in range(2): for i in range(2):
subgraphs.append( subgraphs.append(
gb.FusedSampledSubgraphImpl( gb.SampledSubgraphImpl(
node_pairs=node_pairs[i], node_pairs=csc_formats[i],
original_column_node_ids=original_column_node_ids[i], original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i], original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i], original_edge_ids=original_edge_ids[i],
...@@ -152,7 +152,9 @@ def test_minibatch_representation(): ...@@ -152,7 +152,9 @@ def test_minibatch_representation():
negative_srcs = torch.tensor([[8], [1], [6]]) negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]]) negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4]) input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5])) compacted_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 3]), indices=torch.tensor([3, 4, 5])
)
compacted_negative_srcs = torch.tensor([[0], [1], [2]]) compacted_negative_srcs = torch.tensor([[0], [1], [2]])
compacted_negative_dsts = torch.tensor([[6], [0], [0]]) compacted_negative_dsts = torch.tensor([[6], [0], [0]])
labels = torch.tensor([0.0, 1.0, 2.0]) labels = torch.tensor([0.0, 1.0, 2.0])
...@@ -177,31 +179,41 @@ def test_minibatch_representation(): ...@@ -177,31 +179,41 @@ def test_minibatch_representation():
assert result == expect_result, print(len(expect_result), len(result)) assert result == expect_result, print(len(expect_result), len(result))
# Test minibatch with all attributes. # Test minibatch with all attributes.
minibatch = gb.MiniBatch( minibatch = gb.MiniBatch(
node_pairs=node_pairs, node_pairs=csc_formats,
sampled_subgraphs=subgraphs, sampled_subgraphs=subgraphs,
labels=labels, labels=labels,
node_features=node_features, node_features=node_features,
edge_features=edge_features, edge_features=edge_features,
negative_srcs=negative_srcs, negative_srcs=negative_srcs,
negative_dsts=negative_dsts, negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs, compacted_node_pairs=compacted_csc_formats,
input_nodes=input_nodes, input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs, compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts, compacted_negative_dsts=compacted_negative_dsts,
) )
expect_result = str( expect_result = str(
"""MiniBatch(seed_nodes=None, """MiniBatch(seed_nodes=None,
sampled_subgraphs=[FusedSampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])), sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids=tensor([10, 11, 12, 13]),
original_column_node_ids=tensor([10, 11, 12, 13]), original_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
original_edge_ids=tensor([19, 20, 21, 22, 25, 30]), original_column_node_ids=tensor([10, 11, 12, 13]),
original_row_node_ids=tensor([10, 11, 12, 13]),), node_pairs=CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
FusedSampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])), indices=tensor([0, 1, 2, 2, 1, 2]),
original_column_node_ids=tensor([10, 11]), ),
original_edge_ids=tensor([10, 15, 17]), ),
original_row_node_ids=tensor([10, 11, 12]),)], SampledSubgraphImpl(original_row_node_ids=tensor([10, 11, 12]),
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])), original_edge_ids=tensor([10, 15, 17]),
(tensor([0, 1, 2]), tensor([1, 0, 0]))], original_column_node_ids=tensor([10, 11]),
node_features={'x': tensor([7, 6, 2, 2])}, node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
),
)],
node_pairs=[CSCFormatBase(indptr=tensor([0, 1, 3, 5, 6]),
indices=tensor([0, 1, 2, 2, 1, 2]),
),
CSCFormatBase(indptr=tensor([0, 2, 3]),
indices=tensor([1, 2, 0]),
)],
node_features={'x': tensor([5, 0, 2, 1])},
negative_srcs=tensor([[8], negative_srcs=tensor([[8],
[1], [1],
[6]]), [6]]),
...@@ -210,13 +222,11 @@ def test_minibatch_representation(): ...@@ -210,13 +222,11 @@ def test_minibatch_representation():
[8]]), [8]]),
labels=tensor([0., 1., 2.]), labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]), input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
edge_features=[{'x': tensor([[8], edge_features=[{'x': tensor([9, 0, 1, 1, 7, 4])},
[1], {'x': tensor([0, 2, 2])}],
[6]])}, compacted_node_pairs=CSCFormatBase(indptr=tensor([0, 2, 3]),
{'x': tensor([[2], indices=tensor([3, 4, 5]),
[8], ),
[8]])}],
compacted_node_pairs=(tensor([0, 1, 2]), tensor([3, 4, 5])),
compacted_negative_srcs=tensor([[0], compacted_negative_srcs=tensor([[0],
[1], [1],
[2]]), [2]]),
...@@ -229,6 +239,146 @@ def test_minibatch_representation(): ...@@ -229,6 +239,146 @@ def test_minibatch_representation():
assert result == expect_result, print(expect_result, result) assert result == expect_result, print(expect_result, result)
def test_minibatch_representation_hetero():
csc_formats = [
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]),
indices=torch.tensor([0, 1, 1]),
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]),
indices=torch.tensor([1, 0]),
),
},
{
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2]), indices=torch.tensor([1, 0])
)
},
]
original_column_node_ids = [
{"B": torch.tensor([10, 11, 12]), "A": torch.tensor([5, 7, 9, 11])},
{"B": torch.tensor([10, 11])},
]
original_row_node_ids = [
{
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
{
"A": torch.tensor([5, 7]),
"B": torch.tensor([10, 11]),
},
]
original_edge_ids = [
{
relation: torch.tensor([19, 20, 21]),
reverse_relation: torch.tensor([23, 26]),
},
{relation: torch.tensor([10, 12])},
]
node_features = {
("A", "x"): torch.tensor([6, 4, 0, 1]),
}
edge_features = [
{(relation, "x"): torch.tensor([4, 2, 4])},
{(relation, "x"): torch.tensor([0, 6])},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=csc_formats[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = {"B": torch.tensor([[8], [1], [6]])}
negative_dsts = {"B": torch.tensor([[2], [8], [8]])}
compacted_csc_formats = {
relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 2, 3]), indices=torch.tensor([3, 4, 5])
),
reverse_relation: gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 0, 1, 2]), indices=torch.tensor([0, 1])
),
}
compacted_negative_srcs = {relation: torch.tensor([[0], [1], [2]])}
compacted_negative_dsts = {relation: torch.tensor([[6], [0], [0]])}
# Test dglminibatch with all attributes.
minibatch = gb.MiniBatch(
seed_nodes={"B": torch.tensor([10, 15])},
node_pairs=csc_formats,
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
labels={"B": torch.tensor([2, 5])},
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_csc_formats,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
expect_result = str(
"""MiniBatch(seed_nodes={'B': tensor([10, 15])},
sampled_subgraphs=[SampledSubgraphImpl(original_row_node_ids={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
original_edge_ids={'A:r:B': tensor([19, 20, 21]), 'B:rr:A': tensor([23, 26])},
original_column_node_ids={'B': tensor([10, 11, 12]), 'A': tensor([ 5, 7, 9, 11])},
node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
)},
),
SampledSubgraphImpl(original_row_node_ids={'A': tensor([5, 7]), 'B': tensor([10, 11])},
original_edge_ids={'A:r:B': tensor([10, 12])},
original_column_node_ids={'B': tensor([10, 11])},
node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
)},
)],
node_pairs=[{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 1]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([1, 0]),
)},
{'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2]),
indices=tensor([1, 0]),
)}],
node_features={('A', 'x'): tensor([6, 4, 0, 1])},
negative_srcs={'B': tensor([[8],
[1],
[6]])},
negative_dsts={'B': tensor([[2],
[8],
[8]])},
labels={'B': tensor([2, 5])},
input_nodes={'A': tensor([ 5, 7, 9, 11]), 'B': tensor([10, 11, 12])},
edge_features=[{('A:r:B', 'x'): tensor([4, 2, 4])},
{('A:r:B', 'x'): tensor([0, 6])}],
compacted_node_pairs={'A:r:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([3, 4, 5]),
), 'B:rr:A': CSCFormatBase(indptr=tensor([0, 0, 0, 1, 2]),
indices=tensor([0, 1]),
)},
compacted_negative_srcs={'A:r:B': tensor([[0],
[1],
[2]])},
compacted_negative_dsts={'A:r:B': tensor([[6],
[0],
[0]])},
)"""
)
result = str(minibatch)
assert result == expect_result, print(result)
def test_dgl_minibatch_representation_homo(): def test_dgl_minibatch_representation_homo():
node_pairs = [ node_pairs = [
( (
......
import unittest import unittest
import backend as F import backend as F
import dgl.graphbolt as gb
import pytest import pytest
import torch import torch
from dgl.graphbolt.impl.sampled_subgraph_impl import (
from dgl.graphbolt.impl.sampled_subgraph_impl import FusedSampledSubgraphImpl FusedSampledSubgraphImpl,
SampledSubgraphImpl,
)
def _assert_container_equal(lhs, rhs): def _assert_container_equal(lhs, rhs):
...@@ -185,3 +189,77 @@ def test_sampled_subgraph_to_device(): ...@@ -185,3 +189,77 @@ def test_sampled_subgraph_to_device():
assert graph.original_row_node_ids[key].device.type == "cuda" assert graph.original_row_node_ids[key].device.type == "cuda"
for key in graph.original_edge_ids: for key in graph.original_edge_ids:
assert graph.original_edge_ids[key].device.type == "cuda" assert graph.original_edge_ids[key].device.type == "cuda"
def test_sampled_subgraph_impl_representation_homo():
sampled_subgraph_impl = SampledSubgraphImpl(
node_pairs=gb.CSCFormatBase(
indptr=torch.arange(0, 101, 10),
indices=torch.arange(10, 110),
),
original_column_node_ids=torch.arange(0, 10),
original_row_node_ids=torch.arange(0, 110),
original_edge_ids=None,
)
expected_result = str(
"""SampledSubgraphImpl(original_row_node_ids=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,
42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,
56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,
70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,
98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109]),
original_edge_ids=None,
original_column_node_ids=tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
node_pairs=CSCFormatBase(indptr=tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]),
indices=tensor([ 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37,
38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,
52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65,
66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93,
94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
108, 109]),
),
)"""
)
assert str(sampled_subgraph_impl) == expected_result, print(
sampled_subgraph_impl
)
def test_sampled_subgraph_impl_representation_hetero():
sampled_subgraph_impl = SampledSubgraphImpl(
node_pairs={
"n1:e1:n2": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([4, 5, 6, 7]),
),
"n2:e2:n1": gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4, 6, 8]),
indices=torch.tensor([2, 3, 4, 5, 6, 7, 8, 9]),
),
},
original_column_node_ids={
"n1": torch.tensor([1, 0, 0, 1]),
"n2": torch.tensor([1, 2]),
},
original_row_node_ids={
"n1": torch.tensor([1, 0, 0, 1, 1, 0, 0, 1]),
"n2": torch.tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1]),
},
original_edge_ids=None,
)
expected_result = str(
"""SampledSubgraphImpl(original_row_node_ids={'n1': tensor([1, 0, 0, 1, 1, 0, 0, 1]), 'n2': tensor([1, 2, 0, 1, 0, 2, 0, 2, 0, 1])},
original_edge_ids=None,
original_column_node_ids={'n1': tensor([1, 0, 0, 1]), 'n2': tensor([1, 2])},
node_pairs={'n1:e1:n2': CSCFormatBase(indptr=tensor([0, 2, 4]),
indices=tensor([4, 5, 6, 7]),
), 'n2:e2:n1': CSCFormatBase(indptr=tensor([0, 2, 4, 6, 8]),
indices=tensor([2, 3, 4, 5, 6, 7, 8, 9]),
)},
)"""
)
assert str(sampled_subgraph_impl) == expected_result, print(expected_result)
...@@ -150,3 +150,16 @@ def test_isin_non_1D_dim(): ...@@ -150,3 +150,16 @@ def test_isin_non_1D_dim():
test_elements = torch.tensor([[2, 5]]) test_elements = torch.tensor([[2, 5]])
with pytest.raises(Exception): with pytest.raises(Exception):
gb.isin(elements, test_elements) gb.isin(elements, test_elements)
def test_csc_format_base_representation():
csc_format_base = gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
indices=torch.tensor([4, 5, 6, 7]),
)
expected_result = str(
"""CSCFormatBase(indptr=tensor([0, 2, 4]),
indices=tensor([4, 5, 6, 7]),
)"""
)
assert str(csc_format_base) == expected_result, print(csc_format_base)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment