Unverified Commit de14619c authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix several issues in doc and notebook (#6775)

parent 1db71cfe
...@@ -5,12 +5,12 @@ ...@@ -5,12 +5,12 @@
Implementing custom samplers involves subclassing the Implementing custom samplers involves subclassing the
:class:`dgl.graphbolt.SubgraphSampler` base class and implementing its abstract :class:`dgl.graphbolt.SubgraphSampler` base class and implementing its abstract
:attr:`_sample_subgraphs` method. The :attr:`_sample_subgraphs` method should :attr:`sample_subgraphs` method. The :attr:`sample_subgraphs` method should
take in seed nodes which are the nodes to sample neighbors from: take in seed nodes which are the nodes to sample neighbors from:
.. code:: python .. code:: python
def _sample_subgraphs(self, seed_nodes): def sample_subgraphs(self, seed_nodes):
return input_nodes, sampled_subgraphs return input_nodes, sampled_subgraphs
The method should return the input node IDs list and a list of subgraphs. Each The method should return the input node IDs list and a list of subgraphs. Each
...@@ -31,7 +31,7 @@ The code below implements a classical neighbor sampler: ...@@ -31,7 +31,7 @@ The code below implements a classical neighbor sampler:
self.graph = graph self.graph = graph
self.fanouts = fanouts self.fanouts = fanouts
def _sample_subgraphs(self, seed_nodes): def sample_subgraphs(self, seed_nodes):
subgs = [] subgs = []
for fanout in reversed(self.fanouts): for fanout in reversed(self.fanouts):
# Sample a fixed number of neighbors of the current seed nodes. # Sample a fixed number of neighbors of the current seed nodes.
......
...@@ -29,7 +29,7 @@ will customize another sampler with DGL sparse library as shown below. ...@@ -29,7 +29,7 @@ will customize another sampler with DGL sparse library as shown below.
fanout = torch.LongTensor([int(fanout)]) fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout) self.fanouts.insert(0, fanout)
def _sample_subgraphs(self, seeds): def sample_subgraphs(self, seeds):
sampled_matrices = [] sampled_matrices = []
src = seeds src = seeds
......
...@@ -54,7 +54,8 @@ ...@@ -54,7 +54,8 @@
"os.environ['TORCH'] = torch.__version__\n", "os.environ['TORCH'] = torch.__version__\n",
"os.environ['DGLBACKEND'] = \"pytorch\"\n", "os.environ['DGLBACKEND'] = \"pytorch\"\n",
"\n", "\n",
"# Install the CPU version.\n", "# Install the CPU version. If you want to install CUDA version, please\n",
"# refer to https://www.dgl.ai/pages/start.html.\n",
"device = torch.device(\"cpu\")\n", "device = torch.device(\"cpu\")\n",
"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n", "!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
"\n", "\n",
...@@ -339,7 +340,7 @@ ...@@ -339,7 +340,7 @@
"# Compute the AUROC score.\n", "# Compute the AUROC score.\n",
"from sklearn.metrics import roc_auc_score\n", "from sklearn.metrics import roc_auc_score\n",
"\n", "\n",
"auc = roc_auc_score(labels, logits)\n", "auc = roc_auc_score(labels.cpu(), logits.cpu())\n",
"print(\"Link Prediction AUC:\", auc)" "print(\"Link Prediction AUC:\", auc)"
], ],
"metadata": { "metadata": {
......
...@@ -54,7 +54,8 @@ ...@@ -54,7 +54,8 @@
"os.environ['TORCH'] = torch.__version__\n", "os.environ['TORCH'] = torch.__version__\n",
"os.environ['DGLBACKEND'] = \"pytorch\"\n", "os.environ['DGLBACKEND'] = \"pytorch\"\n",
"\n", "\n",
"# Install the CPU version.\n", "# Install the CPU version. If you want to install CUDA version, please\n",
"# refer to https://www.dgl.ai/pages/start.html.\n",
"device = torch.device(\"cpu\")\n", "device = torch.device(\"cpu\")\n",
"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n", "!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
"\n", "\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment