Unverified Commit de14619c authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[GraphBolt] fix several issues in doc and notebook (#6775)

parent 1db71cfe
......@@ -5,12 +5,12 @@
Implementing custom samplers involves subclassing the
:class:`dgl.graphbolt.SubgraphSampler` base class and implementing its abstract
:attr:`_sample_subgraphs` method. The :attr:`_sample_subgraphs` method should
:attr:`sample_subgraphs` method. The :attr:`sample_subgraphs` method should
take in seed nodes which are the nodes to sample neighbors from:
.. code:: python
def _sample_subgraphs(self, seed_nodes):
def sample_subgraphs(self, seed_nodes):
return input_nodes, sampled_subgraphs
The method should return the input node IDs list and a list of subgraphs. Each
......@@ -31,7 +31,7 @@ The code below implements a classical neighbor sampler:
self.graph = graph
self.fanouts = fanouts
def _sample_subgraphs(self, seed_nodes):
def sample_subgraphs(self, seed_nodes):
subgs = []
for fanout in reversed(self.fanouts):
# Sample a fixed number of neighbors of the current seed nodes.
......
......@@ -29,7 +29,7 @@ will customize another sampler with DGL sparse library as shown below.
fanout = torch.LongTensor([int(fanout)])
self.fanouts.insert(0, fanout)
def _sample_subgraphs(self, seeds):
def sample_subgraphs(self, seeds):
sampled_matrices = []
src = seeds
......
......@@ -54,7 +54,8 @@
"os.environ['TORCH'] = torch.__version__\n",
"os.environ['DGLBACKEND'] = \"pytorch\"\n",
"\n",
"# Install the CPU version.\n",
"# Install the CPU version. If you want to install CUDA version, please\n",
"# refer to https://www.dgl.ai/pages/start.html.\n",
"device = torch.device(\"cpu\")\n",
"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
"\n",
......@@ -339,7 +340,7 @@
"# Compute the AUROC score.\n",
"from sklearn.metrics import roc_auc_score\n",
"\n",
"auc = roc_auc_score(labels, logits)\n",
"auc = roc_auc_score(labels.cpu(), logits.cpu())\n",
"print(\"Link Prediction AUC:\", auc)"
],
"metadata": {
......
......@@ -54,7 +54,8 @@
"os.environ['TORCH'] = torch.__version__\n",
"os.environ['DGLBACKEND'] = \"pytorch\"\n",
"\n",
"# Install the CPU version.\n",
"# Install the CPU version. If you want to install CUDA version, please\n",
"# refer to https://www.dgl.ai/pages/start.html.\n",
"device = torch.device(\"cpu\")\n",
"!pip install --pre dgl -f https://data.dgl.ai/wheels-test/repo.html\n",
"\n",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment