Unverified Commit c2ffe3bf authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Update notebook examples. (#7266)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d3483fe1
......@@ -77,7 +77,7 @@
},
"outputs": [],
"source": [
"dataset = gb.BuiltinDataset(\"cora\").load()"
"dataset = gb.BuiltinDataset(\"cora-seeds\").load()"
]
},
{
......@@ -255,7 +255,8 @@
" total_loss = 0\n",
" for step, data in tqdm(enumerate(create_train_dataloader())):\n",
" # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, labels = data.node_pairs_with_labels\n",
" compacted_seeds = data.compacted_seeds.T\n",
" labels = data.labels\n",
" node_feature = data.node_features[\"feat\"]\n",
" # Convert sampled subgraphs to DGL blocks.\n",
" blocks = data.blocks\n",
......@@ -263,7 +264,7 @@
" # Get the embeddings of the input nodes.\n",
" y = model(blocks, node_feature)\n",
" logits = model.predictor(\n",
" y[compacted_pairs[0]] * y[compacted_pairs[1]]\n",
" y[compacted_seeds[0]] * y[compacted_seeds[1]]\n",
" ).squeeze()\n",
"\n",
" # Compute loss.\n",
......@@ -308,7 +309,8 @@
"labels = []\n",
"for step, data in tqdm(enumerate(eval_dataloader)):\n",
" # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, label = data.node_pairs_with_labels\n",
" compacted_seeds = data.compacted_seeds.T\n",
" label = data.labels\n",
"\n",
" # The features of sampled nodes.\n",
" x = data.node_features[\"feat\"]\n",
......@@ -316,7 +318,7 @@
" # Forward.\n",
" y = model(data.blocks, x)\n",
" logit = (\n",
" model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])\n",
" model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])\n",
" .squeeze()\n",
" .detach()\n",
" )\n",
......
......@@ -78,7 +78,7 @@
},
"outputs": [],
"source": [
"dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()"
"dataset = gb.BuiltinDataset(\"ogbn-arxiv-seeds\").load()"
]
},
{
......@@ -143,7 +143,7 @@
"source": [
"def create_dataloader(itemset, shuffle):\n",
" datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n",
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seed_nodes\"])\n",
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seeds\"])\n",
" datapipe = datapipe.sample_neighbor(graph, [4, 4])\n",
" datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
" return gb.DataLoader(datapipe)"
......@@ -375,4 +375,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
\ No newline at end of file
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment