Unverified Commit d57564ac authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

Gb to dgl nb (#6764)

parent 9d417346
...@@ -218,29 +218,6 @@ ...@@ -218,29 +218,6 @@
"print(next(iter(datapipe)))" "print(next(iter(datapipe)))"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {
"id": "Gt059n1xrmj-"
},
"source": [
"After retrieving the required data, Graphbolt provides helper methods to convert it to the output format needed for subsequent GNN training.\n",
"\n",
"* Convert to **DGLMiniBatch** format for training with DGL."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "o8Yoi8BeqSdu"
},
"outputs": [],
"source": [
"datapipe = datapipe.to_dgl()\n",
"print(next(iter(datapipe)))"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
......
...@@ -172,27 +172,6 @@ ...@@ -172,27 +172,6 @@
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
}, },
{
"cell_type": "markdown",
"source": [
"In order to train with DGL, you need to convert `MiniBatch` to `DGLMiniBatch` like below:"
],
"metadata": {
"id": "IpAgrEp_cdEP"
}
},
{
"cell_type": "code",
"source": [
"data = data.to_dgl()\n",
"print(f\"DGLMiniBatch: {data}\")"
],
"metadata": {
"id": "KQgxFUyCcjVT"
},
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
...@@ -263,38 +242,6 @@ ...@@ -263,38 +242,6 @@
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
}, },
{
"cell_type": "markdown",
"source": [
"Define utility function to vonvert the minibatch to a training pair and a label tensor.\n",
"\n"
],
"metadata": {
"id": "J9K1GUs4ZDYw"
}
},
{
"cell_type": "code",
"source": [
"def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):\n",
" \"\"\"Convert the minibatch to a training pair and a label tensor.\"\"\"\n",
" pos_src, pos_dst = data.positive_node_pairs\n",
" neg_src, neg_dst = data.negative_node_pairs\n",
" node_pairs = (\n",
" torch.cat((pos_src, neg_src), dim=0),\n",
" torch.cat((pos_dst, neg_dst), dim=0),\n",
" )\n",
" pos_label = torch.ones_like(pos_src)\n",
" neg_label = torch.zeros_like(neg_src)\n",
" labels = torch.cat([pos_label, neg_label], dim=0)\n",
" return (node_pairs, labels.float())"
],
"metadata": {
"id": "wvIJBPb7ZNUv"
},
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
...@@ -313,11 +260,8 @@ ...@@ -313,11 +260,8 @@
" model.train()\n", " model.train()\n",
" total_loss = 0\n", " total_loss = 0\n",
" for step, data in tqdm.tqdm(enumerate(train_dataloader)):\n", " for step, data in tqdm.tqdm(enumerate(train_dataloader)):\n",
" # Convert to DGL format.\n", " # Get node pairs with labels for loss calculation.\n",
" data = data.to_dgl()\n", " compacted_pairs, labels = data.node_pairs_with_labels\n",
"\n",
" # Unpack DGLMiniBatch.\n",
" compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)\n",
" node_feature = data.node_features[\"feat\"]\n", " node_feature = data.node_features[\"feat\"]\n",
" # Convert sampled subgraphs to DGL blocks.\n", " # Convert sampled subgraphs to DGL blocks.\n",
" blocks = data.blocks\n", " blocks = data.blocks\n",
...@@ -369,11 +313,8 @@ ...@@ -369,11 +313,8 @@
"logits = []\n", "logits = []\n",
"labels = []\n", "labels = []\n",
"for step, data in enumerate(eval_dataloader):\n", "for step, data in enumerate(eval_dataloader):\n",
" # Convert to DGL format.\n", " # Get node pairs with labels for loss calculation.\n",
" data = data.to_dgl()\n", " compacted_pairs, label = data.node_pairs_with_labels\n",
"\n",
" # Unpack MiniBatch.\n",
" compacted_pairs, label = to_binary_link_dgl_computing_pack(data)\n",
"\n", "\n",
" # The features of sampled nodes.\n", " # The features of sampled nodes.\n",
" x = data.node_features[\"feat\"]\n", " x = data.node_features[\"feat\"]\n",
......
...@@ -183,27 +183,6 @@ ...@@ -183,27 +183,6 @@
"execution_count": null, "execution_count": null,
"outputs": [] "outputs": []
}, },
{
"cell_type": "markdown",
"source": [
"In order to train with DGL, you need to convert `MiniBatch` to `DGLMiniBatch` like below:"
],
"metadata": {
"id": "FwDJf1AJbNtt"
}
},
{
"cell_type": "code",
"source": [
"data = data.to_dgl()\n",
"print(data)"
],
"metadata": {
"id": "3Tzfp6A8bdWv"
},
"execution_count": null,
"outputs": []
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [ "source": [
...@@ -336,7 +315,6 @@ ...@@ -336,7 +315,6 @@
"\n", "\n",
" with tqdm.tqdm(train_dataloader) as tq:\n", " with tqdm.tqdm(train_dataloader) as tq:\n",
" for step, data in enumerate(tq):\n", " for step, data in enumerate(tq):\n",
" data = data.to_dgl()\n",
" x = data.node_features[\"feat\"]\n", " x = data.node_features[\"feat\"]\n",
" labels = data.labels\n", " labels = data.labels\n",
"\n", "\n",
...@@ -363,7 +341,6 @@ ...@@ -363,7 +341,6 @@
" labels = []\n", " labels = []\n",
" with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():\n", " with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():\n",
" for data in tq:\n", " for data in tq:\n",
" data = data.to_dgl()\n",
" x = data.node_features[\"feat\"]\n", " x = data.node_features[\"feat\"]\n",
" labels.append(data.labels.cpu().numpy())\n", " labels.append(data.labels.cpu().numpy())\n",
" predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\n", " predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment