"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "fb7f9a16628cb0813ac958da4525247e325cc3d2"
Unverified Commit 25209120 authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Improve MiniBatch repr layout (#6356)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-8-56.us-west-2.compute.internal>
parent c0ac2f60
...@@ -225,3 +225,71 @@ class MiniBatch: ...@@ -225,3 +225,71 @@ class MiniBatch:
block.edata[dgl.EID] = subgraph.reverse_edge_ids block.edata[dgl.EID] = subgraph.reverse_edge_ids
return blocks return blocks
def __repr__(self) -> str:
return _minibatch_str(self)
def _minibatch_str(minibatch: MiniBatch) -> str:
final_str = ""
# Get all attributes in the class except methods.
def _get_attributes(_obj) -> list:
attributes = [
attribute
for attribute in dir(_obj)
if not attribute.startswith("__")
and not callable(getattr(_obj, attribute))
]
return attributes
attributes = _get_attributes(minibatch)
attributes.reverse()
# Insert key with its value into the string.
for name in attributes:
val = getattr(minibatch, name)
def _add_indent(_str, indent):
lines = _str.split("\n")
lines = [lines[0]] + [
" " * (indent + 10) + line for line in lines[1:]
]
return "\n".join(lines)
# Let the variables in the list occupy one line each,
# and adjust the indentation on top of the original
# if the original data output has line feeds.
if isinstance(val, list):
# Special handling SampledSubgraphImpl data.
# Line feeds variables within this type.
if isinstance(
val[0],
dgl.graphbolt.impl.sampled_subgraph_impl.SampledSubgraphImpl,
):
sampledsubgraph_strs = []
for sampledsubgraph in val:
ss_attributes = _get_attributes(sampledsubgraph)
sampledsubgraph_str = "SampledSubgraphImpl("
for ss_name in ss_attributes:
ss_val = str(getattr(sampledsubgraph, ss_name))
sampledsubgraph_str = (
sampledsubgraph_str
+ f"{ss_name}={_add_indent(ss_val, len(ss_name)+1)},\n"
+ " " * 20
)
sampledsubgraph_strs.append(sampledsubgraph_str[:-21] + ")")
val = "[" + ",\n".join(sampledsubgraph_strs) + "]"
else:
val = [
_add_indent(
str(val_str), len(str(val_str).split("': ")[0]) - 6
)
for val_str in val
]
val = "[" + ",\n".join(val) + "]"
else:
val = str(val)
final_str = (
final_str + f"{name}={_add_indent(val, len(name)+1)},\n" + " " * 10
)
return "MiniBatch(" + final_str[:-3] + ")"
...@@ -137,3 +137,117 @@ def test_to_dgl_blocks_homo(): ...@@ -137,3 +137,117 @@ def test_to_dgl_blocks_homo():
assert torch.equal(block.edata["x"], edge_features[i]["x"]) assert torch.equal(block.edata["x"], edge_features[i]["x"])
assert torch.equal(blocks[0].srcdata[dgl.NID], reverse_row_node_ids[0]) assert torch.equal(blocks[0].srcdata[dgl.NID], reverse_row_node_ids[0])
assert torch.equal(blocks[0].srcdata["x"], node_features["x"]) assert torch.equal(blocks[0].srcdata["x"], node_features["x"])
def test_representation():
node_pairs = [
(
torch.tensor([0, 1, 2, 2, 2, 1]),
torch.tensor([0, 1, 1, 2, 3, 2]),
),
(
torch.tensor([0, 1, 2]),
torch.tensor([1, 0, 0]),
),
]
reverse_column_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11]),
]
reverse_row_node_ids = [
torch.tensor([10, 11, 12, 13]),
torch.tensor([10, 11, 12]),
]
reverse_edge_ids = [
torch.tensor([19, 20, 21, 22, 25, 30]),
torch.tensor([10, 15, 17]),
]
node_features = {"x": torch.tensor([7, 6, 2, 2])}
edge_features = [
{"x": torch.tensor([[8], [1], [6]])},
{"x": torch.tensor([[2], [8], [8]])},
]
subgraphs = []
for i in range(2):
subgraphs.append(
gb.SampledSubgraphImpl(
node_pairs=node_pairs[i],
reverse_column_node_ids=reverse_column_node_ids[i],
reverse_row_node_ids=reverse_row_node_ids[i],
reverse_edge_ids=reverse_edge_ids[i],
)
)
negative_srcs = torch.tensor([[8], [1], [6]])
negative_dsts = torch.tensor([[2], [8], [8]])
input_nodes = torch.tensor([8, 1, 6, 5, 9, 0, 2, 4])
compacted_node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([3, 4, 5]))
compacted_negative_srcs = torch.tensor([0, 1, 2])
compacted_negative_dsts = torch.tensor([6, 0, 0])
labels = torch.tensor([0.0, 1.0, 2.0])
# Test minibatch without data.
minibatch = gb.MiniBatch()
expect_result = str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=None,
node_pairs=None,
node_features=None,
negative_srcs=None,
negative_dsts=None,
labels=None,
input_nodes=None,
edge_features=None,
compacted_node_pairs=None,
compacted_negative_srcs=None,
compacted_negative_dsts=None,
)"""
)
result = str(minibatch)
assert result == expect_result, print(len(expect_result), len(result))
# Test minibatch with all attributes.
minibatch = gb.MiniBatch(
node_pairs=node_pairs,
sampled_subgraphs=subgraphs,
labels=labels,
node_features=node_features,
edge_features=edge_features,
negative_srcs=negative_srcs,
negative_dsts=negative_dsts,
compacted_node_pairs=compacted_node_pairs,
input_nodes=input_nodes,
compacted_negative_srcs=compacted_negative_srcs,
compacted_negative_dsts=compacted_negative_dsts,
)
expect_result = str(
"""MiniBatch(seed_nodes=None,
sampled_subgraphs=[SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
reverse_column_node_ids=tensor([10, 11, 12, 13]),
reverse_edge_ids=tensor([19, 20, 21, 22, 25, 30]),
reverse_row_node_ids=tensor([10, 11, 12, 13]),),
SampledSubgraphImpl(node_pairs=(tensor([0, 1, 2]), tensor([1, 0, 0])),
reverse_column_node_ids=tensor([10, 11]),
reverse_edge_ids=tensor([10, 15, 17]),
reverse_row_node_ids=tensor([10, 11, 12]),)],
node_pairs=[(tensor([0, 1, 2, 2, 2, 1]), tensor([0, 1, 1, 2, 3, 2])),
(tensor([0, 1, 2]), tensor([1, 0, 0]))],
node_features={'x': tensor([7, 6, 2, 2])},
negative_srcs=tensor([[8],
[1],
[6]]),
negative_dsts=tensor([[2],
[8],
[8]]),
labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
edge_features=[{'x': tensor([[8],
[1],
[6]])},
{'x': tensor([[2],
[8],
[8]])}],
compacted_node_pairs=(tensor([0, 1, 2]), tensor([3, 4, 5])),
compacted_negative_srcs=tensor([0, 1, 2]),
compacted_negative_dsts=tensor([6, 0, 0]),
)"""
)
result = str(minibatch)
assert result == expect_result, print(expect_result, result)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment