"examples/pytorch/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "d04d59eef9744e7133f5b84938523d20a7321f54"
Unverified Commit 3f958d7c authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] modify DGLMiniBatch layout. (#6410)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent a19e0efd
...@@ -96,6 +96,9 @@ class DGLMiniBatch: ...@@ -96,6 +96,9 @@ class DGLMiniBatch:
given type. given type.
""" """
def __repr__(self) -> str:
return _dgl_minibatch_str(self)
def to(self, device: torch.device) -> None: # pylint: disable=invalid-name def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
"""Copy `DGLMiniBatch` to the specified device using reflection.""" """Copy `DGLMiniBatch` to the specified device using reflection."""
...@@ -440,13 +443,15 @@ def _minibatch_str(minibatch: MiniBatch) -> str: ...@@ -440,13 +443,15 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
] ]
return "\n".join(lines) return "\n".join(lines)
# Let the variables in the list occupy one line each, # Let the variables in the list occupy one line each, and adjust the
# and adjust the indentation on top of the original # indentation on top of the original if the original data output has
# if the original data output has line feeds. # line feeds.
if isinstance(val, list): if isinstance(val, list):
# Special handling SampledSubgraphImpl data. if len(val) == 0:
# Line feeds variables within this type. val = "[]"
if isinstance( # Special handling of SampledSubgraphImpl data. Each element of
# the data occupies one row and is further structured.
elif isinstance(
val[0], val[0],
dgl.graphbolt.impl.sampled_subgraph_impl.SampledSubgraphImpl, dgl.graphbolt.impl.sampled_subgraph_impl.SampledSubgraphImpl,
): ):
...@@ -477,3 +482,58 @@ def _minibatch_str(minibatch: MiniBatch) -> str: ...@@ -477,3 +482,58 @@ def _minibatch_str(minibatch: MiniBatch) -> str:
final_str + f"{name}={_add_indent(val, len(name)+1)},\n" + " " * 10 final_str + f"{name}={_add_indent(val, len(name)+1)},\n" + " " * 10
) )
return "MiniBatch(" + final_str[:-3] + ")" return "MiniBatch(" + final_str[:-3] + ")"
def _dgl_minibatch_str(dglminibatch: DGLMiniBatch) -> str:
final_str = ""
# Get all attributes in the class except methods.
def _get_attributes(_obj) -> list:
attributes = [
attribute
for attribute in dir(_obj)
if not attribute.startswith("__")
and not callable(getattr(_obj, attribute))
]
return attributes
attributes = _get_attributes(dglminibatch)
attributes.reverse()
# Insert key with its value into the string.
for name in attributes:
val = getattr(dglminibatch, name)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [" " * indent + line for line in lines[1:]]
return "\n".join(lines)
# Let the variables in the list occupy one line each, and adjust the
# indentation on top of the original if the original data output has
# line feeds.
if isinstance(val, list):
if len(val) == 0:
val = "[]"
# Special handling of blocks data. Each element of list occupies
# one row and is further structured.
elif name == "blocks":
blocks_strs = []
for block in val:
block_str = str(block).replace(" ", "\n")
block_str = _add_indent(block_str, len("Block") + 1)
blocks_strs.append(block_str)
val = "[" + ",\n".join(blocks_strs) + "]"
else:
val = [
_add_indent(
str(val_str), len(str(val_str).split("': ")[0]) + 3
)
for val_str in val
]
val = "[" + ",\n".join(val) + "]"
else:
val = str(val)
final_str = (
final_str + f"{name}={_add_indent(val, len(name)+15)},\n" + " " * 13
)
return "DGLMiniBatch(" + final_str[:-3] + ")"
...@@ -106,7 +106,7 @@ def create_hetero_minibatch(): ...@@ -106,7 +106,7 @@ def create_hetero_minibatch():
) )
def test_representation(): def test_minibatch_representation():
node_pairs = [ node_pairs = [
( (
torch.tensor([0, 1, 2, 2, 2, 1]), torch.tensor([0, 1, 2, 2, 2, 1]),
...@@ -220,6 +220,91 @@ def test_representation(): ...@@ -220,6 +220,91 @@ def test_representation():
assert result == expect_result, print(expect_result, result) assert result == expect_result, print(expect_result, result)
def test_dgl_minibatch_representation():
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
original_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
original_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
original_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.tensor([7, 6, 2, 2])}
edge_features = [
{"x": torch.tensor([[8], [1], [6]])},
{"x": torch.tensor([[2], [8], [8]])},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
original_column_node_ids=original_column_node_ids[i],
original_row_node_ids=original_row_node_ids[i],
original_edge_ids=original_edge_ids[i],
)
)
negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
compacted_negative_srcs = torch.tensor([0, 1, 2])
compacted_negative_dsts = torch.tensor([6, 0, 0])
labels = torch.tensor([0.0, 1.0, 2.0])
# Test dglminibatch with all attributes.
minibatch = gb.MiniBatch(
node_pairs=node_pairs,
sampled_subgraphs=subgraphs,
labels=labels,
node_features=node_features,
edge_features=edge_features,
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs,
input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
dgl_minibatch = minibatch.to_dgl()
expect_result = str(
"""DGLMiniBatch(positive_node_pairs=(tensor([0, 1, 2]), tensor([3, 4, 5])),
output_nodes=None,
node_features={'x': tensor([7, 6, 2, 2])},
negative_node_pairs=(tensor([0, 1, 2]), tensor([6, 0, 0])),
labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
edge_features=[{'x': tensor([[8],
[1],
[6]])},
{'x': tensor([[2],
[8],
[8]])}],
blocks=[Block(num_src_nodes=4,
num_dst_nodes=4,
num_edges=6),
Block(num_src_nodes=3,
num_dst_nodes=2,
num_edges=3)],
)"""
)
result = str(dgl_minibatch)
assert result == expect_result, print(result)
def check_dgl_blocks_hetero(minibatch, blocks): def check_dgl_blocks_hetero(minibatch, blocks):
etype = gb.etype_str_to_tuple(relation) etype = gb.etype_str_to_tuple(relation)
node_pairs = [ node_pairs = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment