Unverified Commit 12fab559 authored by Ramon Zhou's avatar Ramon Zhou Committed by GitHub
Browse files

[GraphBolt][PyG] Add more attributes in `to_pyg_data` (#7196)

parent a6505e86
......@@ -88,18 +88,18 @@ class GraphSAGE(torch.nn.Module):
x = F.dropout(x, p=0.5, training=self.training)
return x
def inference(self, args, dataloader, x_all, device):
def inference(self, dataloader, x_all, device):
"""Conduct layer-wise inference to get all the node embeddings."""
for i, layer in tqdm(enumerate(self.layers), "inference"):
xs = []
for minibatch in dataloader:
# Call `to_pyg_data` to convert GB Minibatch to PyG Data.
pyg_data = minibatch.to_pyg_data()
n_ids = minibatch.node_ids().to("cpu")
x = x_all[n_ids].to(device)
n_id = pyg_data.n_id.to("cpu")
x = x_all[n_id].to(device)
edge_index = pyg_data.edge_index
x = layer(x, edge_index)
x = x[: 4 * args.batch_size]
x = x[: pyg_data.batch_size]
if i != len(self.layers) - 1:
x = x.relu()
xs.append(x.cpu())
......@@ -185,11 +185,11 @@ def evaluate(model, dataloader, num_classes):
@torch.no_grad()
def layerwise_infer(
model, args, infer_dataloader, test_set, feature, num_classes, device
model, infer_dataloader, test_set, feature, num_classes, device
):
model.eval()
features = feature.read("node", None, "feat")
pred = model.inference(args, infer_dataloader, features, device)
pred = model.inference(infer_dataloader, features, device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)
......@@ -271,7 +271,7 @@ def main():
f"Valid Accuracy: {valid_accuracy:.4f}"
)
test_accuracy = layerwise_infer(
model, args, infer_dataloader, test_set, feature, num_classes, device
model, infer_dataloader, test_set, feature, num_classes, device
)
print(f"Test Accuracy: {test_accuracy:.4f}")
......
......@@ -526,10 +526,24 @@ class MiniBatch:
), "`to_pyg_data` only supports single feature homogeneous graph."
node_features = next(iter(self.node_features.values()))
if self.seed_nodes is not None:
if isinstance(self.seed_nodes, Dict):
batch_size = len(next(iter(self.seed_nodes.values())))
else:
batch_size = len(self.seed_nodes)
elif self.node_pairs is not None:
if isinstance(self.node_pairs, Dict):
batch_size = len(next(iter(self.node_pairs.values()))[0])
else:
batch_size = len(self.node_pairs[0])
else:
batch_size = None
pyg_data = Data(
x=node_features,
edge_index=edge_index,
y=self.labels,
batch_size=batch_size,
n_id=self.node_ids(),
)
return pyg_data
......
......@@ -869,40 +869,27 @@ def test_dgl_link_predication_hetero(mode):
def test_to_pyg_data():
test_subgraph_a = gb.SampledSubgraphImpl(
sampled_csc=gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3, 5, 6]),
indices=torch.tensor([0, 1, 2, 2, 1, 2]),
),
original_column_node_ids=torch.tensor([10, 11, 12, 13]),
original_row_node_ids=torch.tensor([19, 20, 21, 22, 25, 30]),
original_edge_ids=torch.tensor([10, 11, 12, 13]),
)
test_subgraph_b = gb.SampledSubgraphImpl(
sampled_csc=gb.CSCFormatBase(
indptr=torch.tensor([0, 1, 3]),
indices=torch.tensor([1, 2, 0]),
),
original_row_node_ids=torch.tensor([10, 11, 12]),
original_edge_ids=torch.tensor([10, 15, 17]),
original_column_node_ids=torch.tensor([10, 11]),
)
test_minibatch = create_homo_minibatch()
test_minibatch.seed_nodes = torch.tensor([0, 1])
test_minibatch.labels = torch.tensor([7, 8])
expected_edge_index = torch.tensor(
[[0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 0, 1, 2, 1, 2, 3]]
)
expected_node_features = torch.tensor([[1], [2], [3], [4]])
expected_labels = torch.tensor([0, 1])
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a, test_subgraph_b],
node_features={"feat": expected_node_features},
labels=expected_labels,
[[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
)
expected_node_features = next(iter(test_minibatch.node_features.values()))
expected_labels = torch.tensor([7, 8])
expected_batch_size = 2
expected_n_id = torch.tensor([10, 11, 12, 13])
pyg_data = test_minibatch.to_pyg_data()
pyg_data.validate()
assert torch.equal(pyg_data.edge_index, expected_edge_index)
assert torch.equal(pyg_data.x, expected_node_features)
assert torch.equal(pyg_data.y, expected_labels)
assert pyg_data.batch_size == expected_batch_size
assert torch.equal(pyg_data.n_id, expected_n_id)
subgraph = test_minibatch.sampled_subgraphs[0]
# Test with sampled_csc as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=None,
......@@ -914,7 +901,7 @@ def test_to_pyg_data():
# Test with node_features as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a],
sampled_subgraphs=[subgraph],
node_features=None,
labels=expected_labels,
)
......@@ -923,7 +910,7 @@ def test_to_pyg_data():
# Test with labels as None.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a],
sampled_subgraphs=[subgraph],
node_features={"feat": expected_node_features},
labels=None,
)
......@@ -932,7 +919,7 @@ def test_to_pyg_data():
# Test with multiple features.
test_minibatch = gb.MiniBatch(
sampled_subgraphs=[test_subgraph_a],
sampled_subgraphs=[subgraph],
node_features={
"feat": expected_node_features,
"extra_feat": torch.tensor([[3], [4]]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment