"src/vscode:/vscode.git/clone" did not exist on "c6714fc3bfc4b8ccba08ea68cebb095f2af1d75e"
Unverified Commit 12fab559 authored by Ramon Zhou, committed by GitHub

[GraphBolt][PyG] Add more attributes in `to_pyg_data` (#7196)

parent a6505e86
@@ -88,18 +88,18 @@ class GraphSAGE(torch.nn.Module):
             x = F.dropout(x, p=0.5, training=self.training)
         return x
 
-    def inference(self, args, dataloader, x_all, device):
+    def inference(self, dataloader, x_all, device):
         """Conduct layer-wise inference to get all the node embeddings."""
         for i, layer in tqdm(enumerate(self.layers), "inference"):
             xs = []
             for minibatch in dataloader:
                 # Call `to_pyg_data` to convert GB Minibatch to PyG Data.
                 pyg_data = minibatch.to_pyg_data()
-                n_ids = minibatch.node_ids().to("cpu")
-                x = x_all[n_ids].to(device)
+                n_id = pyg_data.n_id.to("cpu")
+                x = x_all[n_id].to(device)
                 edge_index = pyg_data.edge_index
                 x = layer(x, edge_index)
-                x = x[: 4 * args.batch_size]
+                x = x[: pyg_data.batch_size]
                 if i != len(self.layers) - 1:
                     x = x.relu()
                 xs.append(x.cpu())
@@ -185,11 +185,11 @@ def evaluate(model, dataloader, num_classes):
 
 @torch.no_grad()
 def layerwise_infer(
-    model, args, infer_dataloader, test_set, feature, num_classes, device
+    model, infer_dataloader, test_set, feature, num_classes, device
 ):
     model.eval()
     features = feature.read("node", None, "feat")
-    pred = model.inference(args, infer_dataloader, features, device)
+    pred = model.inference(infer_dataloader, features, device)
     pred = pred[test_set._items[0]]
     label = test_set._items[1].to(pred.device)
@@ -271,7 +271,7 @@ def main():
         f"Valid Accuracy: {valid_accuracy:.4f}"
     )
     test_accuracy = layerwise_infer(
-        model, args, infer_dataloader, test_set, feature, num_classes, device
+        model, infer_dataloader, test_set, feature, num_classes, device
    )
     print(f"Test Accuracy: {test_accuracy:.4f}")
@@ -526,10 +526,24 @@ class MiniBatch:
             ), "`to_pyg_data` only supports single feature homogeneous graph."
             node_features = next(iter(self.node_features.values()))
+        if self.seed_nodes is not None:
+            if isinstance(self.seed_nodes, Dict):
+                batch_size = len(next(iter(self.seed_nodes.values())))
+            else:
+                batch_size = len(self.seed_nodes)
+        elif self.node_pairs is not None:
+            if isinstance(self.node_pairs, Dict):
+                batch_size = len(next(iter(self.node_pairs.values()))[0])
+            else:
+                batch_size = len(self.node_pairs[0])
+        else:
+            batch_size = None
         pyg_data = Data(
             x=node_features,
             edge_index=edge_index,
             y=self.labels,
+            batch_size=batch_size,
+            n_id=self.node_ids(),
         )
         return pyg_data
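For context, a small self-contained sketch (not part of this commit; the tensors and the stand-in "layer" below are made up) of how the two new attributes carried by the converted Data object are typically consumed, mirroring the layer-wise inference change in the example script above:

import torch
from torch_geometric.data import Data

# Stand-in for what `minibatch.to_pyg_data()` would return after this change:
# 4 sampled nodes, of which the first 2 are the seed nodes of the minibatch.
pyg_data = Data(
    x=torch.tensor([[1.0], [2.0], [3.0], [4.0]]),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),
    y=torch.tensor([7, 8]),
    batch_size=2,                         # number of seed nodes in this minibatch
    n_id=torch.tensor([10, 11, 12, 13]),  # original IDs of the sampled nodes
)

h = pyg_data.x * 2                        # placeholder for a GNN layer's output
seed_out = h[: pyg_data.batch_size]       # keep only the seed-node rows
x_all = torch.zeros(100, 1)
x_all[pyg_data.n_id] = h                  # write results back by original node ID
print(seed_out.shape, pyg_data.y)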
@@ -869,40 +869,27 @@ def test_dgl_link_predication_hetero(mode):
 def test_to_pyg_data():
-    test_subgraph_a = gb.SampledSubgraphImpl(
-        sampled_csc=gb.CSCFormatBase(
-            indptr=torch.tensor([0, 1, 3, 5, 6]),
-            indices=torch.tensor([0, 1, 2, 2, 1, 2]),
-        ),
-        original_column_node_ids=torch.tensor([10, 11, 12, 13]),
-        original_row_node_ids=torch.tensor([19, 20, 21, 22, 25, 30]),
-        original_edge_ids=torch.tensor([10, 11, 12, 13]),
-    )
-    test_subgraph_b = gb.SampledSubgraphImpl(
-        sampled_csc=gb.CSCFormatBase(
-            indptr=torch.tensor([0, 1, 3]),
-            indices=torch.tensor([1, 2, 0]),
-        ),
-        original_row_node_ids=torch.tensor([10, 11, 12]),
-        original_edge_ids=torch.tensor([10, 15, 17]),
-        original_column_node_ids=torch.tensor([10, 11]),
-    )
+    test_minibatch = create_homo_minibatch()
+    test_minibatch.seed_nodes = torch.tensor([0, 1])
+    test_minibatch.labels = torch.tensor([7, 8])
     expected_edge_index = torch.tensor(
-        [[0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 0, 1, 2, 1, 2, 3]]
+        [[0, 0, 1, 1, 1, 2, 2, 2, 2], [0, 1, 0, 1, 2, 0, 1, 2, 3]]
     )
-    expected_node_features = torch.tensor([[1], [2], [3], [4]])
-    expected_labels = torch.tensor([0, 1])
-    test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a, test_subgraph_b],
-        node_features={"feat": expected_node_features},
-        labels=expected_labels,
-    )
+    expected_node_features = next(iter(test_minibatch.node_features.values()))
+    expected_labels = torch.tensor([7, 8])
+    expected_batch_size = 2
+    expected_n_id = torch.tensor([10, 11, 12, 13])
     pyg_data = test_minibatch.to_pyg_data()
     pyg_data.validate()
     assert torch.equal(pyg_data.edge_index, expected_edge_index)
     assert torch.equal(pyg_data.x, expected_node_features)
     assert torch.equal(pyg_data.y, expected_labels)
+    assert pyg_data.batch_size == expected_batch_size
+    assert torch.equal(pyg_data.n_id, expected_n_id)
 
+    subgraph = test_minibatch.sampled_subgraphs[0]
     # Test with sampled_csc as None.
     test_minibatch = gb.MiniBatch(
         sampled_subgraphs=None,
@@ -914,7 +901,7 @@ def test_to_pyg_data():
     # Test with node_features as None.
     test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a],
+        sampled_subgraphs=[subgraph],
         node_features=None,
         labels=expected_labels,
     )
@@ -923,7 +910,7 @@ def test_to_pyg_data():
     # Test with labels as None.
     test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a],
+        sampled_subgraphs=[subgraph],
         node_features={"feat": expected_node_features},
         labels=None,
     )
@@ -932,7 +919,7 @@ def test_to_pyg_data():
     # Test with multiple features.
     test_minibatch = gb.MiniBatch(
-        sampled_subgraphs=[test_subgraph_a],
+        sampled_subgraphs=[subgraph],
         node_features={
             "feat": expected_node_features,
             "extra_feat": torch.tensor([[3], [4]]),