Unverified Commit c2ffe3bf authored by yxy235's avatar yxy235 Committed by GitHub
Browse files

[GraphBolt] Update notebook examples. (#7266)


Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
parent d3483fe1
...@@ -77,7 +77,7 @@ ...@@ -77,7 +77,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"dataset = gb.BuiltinDataset(\"cora\").load()" "dataset = gb.BuiltinDataset(\"cora-seeds\").load()"
] ]
}, },
{ {
...@@ -255,7 +255,8 @@ ...@@ -255,7 +255,8 @@
" total_loss = 0\n", " total_loss = 0\n",
" for step, data in tqdm(enumerate(create_train_dataloader())):\n", " for step, data in tqdm(enumerate(create_train_dataloader())):\n",
" # Get node pairs with labels for loss calculation.\n", " # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, labels = data.node_pairs_with_labels\n", " compacted_seeds = data.compacted_seeds.T\n",
" labels = data.labels\n",
" node_feature = data.node_features[\"feat\"]\n", " node_feature = data.node_features[\"feat\"]\n",
" # Convert sampled subgraphs to DGL blocks.\n", " # Convert sampled subgraphs to DGL blocks.\n",
" blocks = data.blocks\n", " blocks = data.blocks\n",
...@@ -263,7 +264,7 @@ ...@@ -263,7 +264,7 @@
" # Get the embeddings of the input nodes.\n", " # Get the embeddings of the input nodes.\n",
" y = model(blocks, node_feature)\n", " y = model(blocks, node_feature)\n",
" logits = model.predictor(\n", " logits = model.predictor(\n",
" y[compacted_pairs[0]] * y[compacted_pairs[1]]\n", " y[compacted_seeds[0]] * y[compacted_seeds[1]]\n",
" ).squeeze()\n", " ).squeeze()\n",
"\n", "\n",
" # Compute loss.\n", " # Compute loss.\n",
...@@ -308,7 +309,8 @@ ...@@ -308,7 +309,8 @@
"labels = []\n", "labels = []\n",
"for step, data in tqdm(enumerate(eval_dataloader)):\n", "for step, data in tqdm(enumerate(eval_dataloader)):\n",
" # Get node pairs with labels for loss calculation.\n", " # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, label = data.node_pairs_with_labels\n", " compacted_seeds = data.compacted_seeds.T\n",
" label = data.labels\n",
"\n", "\n",
" # The features of sampled nodes.\n", " # The features of sampled nodes.\n",
" x = data.node_features[\"feat\"]\n", " x = data.node_features[\"feat\"]\n",
...@@ -316,7 +318,7 @@ ...@@ -316,7 +318,7 @@
" # Forward.\n", " # Forward.\n",
" y = model(data.blocks, x)\n", " y = model(data.blocks, x)\n",
" logit = (\n", " logit = (\n",
" model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])\n", " model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])\n",
" .squeeze()\n", " .squeeze()\n",
" .detach()\n", " .detach()\n",
" )\n", " )\n",
......
...@@ -78,7 +78,7 @@ ...@@ -78,7 +78,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()" "dataset = gb.BuiltinDataset(\"ogbn-arxiv-seeds\").load()"
] ]
}, },
{ {
...@@ -143,7 +143,7 @@ ...@@ -143,7 +143,7 @@
"source": [ "source": [
"def create_dataloader(itemset, shuffle):\n", "def create_dataloader(itemset, shuffle):\n",
" datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n", " datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n",
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seed_nodes\"])\n", " datapipe = datapipe.copy_to(device, extra_attrs=[\"seeds\"])\n",
" datapipe = datapipe.sample_neighbor(graph, [4, 4])\n", " datapipe = datapipe.sample_neighbor(graph, [4, 4])\n",
" datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n", " datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
" return gb.DataLoader(datapipe)" " return gb.DataLoader(datapipe)"
...@@ -375,4 +375,4 @@ ...@@ -375,4 +375,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 0
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment