Unverified Commit 0ec43924 authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Examples] re-locate load_graph for share (#3784)

parent e0f0fa2a
......@@ -10,7 +10,8 @@ import argparse
import tqdm
import glob
import os
import sys
sys.path.append('../')
from load_graph import load_reddit, inductive_split, load_ogb
from torchmetrics import Accuracy
......
......@@ -14,10 +14,11 @@ import os
from negative_sampler import NegativeSampler
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from model import SAGE, compute_acc_unsupervised as compute_acc
import sys
sys.path.append('../')
from load_graph import load_reddit, inductive_split, load_ogb
class CrossEntropyLoss(nn.Module):
......
......@@ -14,6 +14,8 @@ from torch.nn.parallel import DistributedDataParallel
from model import SAGE, compute_acc_unsupervised as compute_acc
from negative_sampler import NegativeSampler
import sys
sys.path.append('../')
from load_graph import load_reddit, load_ogb
class CrossEntropyLoss(nn.Module):
......
......@@ -116,7 +116,7 @@ The command below launches one process per machine for both sampling and trainin
```bash
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 1 \
--num_samplers 0 \
--num_servers 1 \
......@@ -129,7 +129,7 @@ To run unsupervised training:
```bash
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 1 \
--num_samplers 0 \
--num_servers 1 \
......@@ -143,7 +143,7 @@ By default, this code will run on CPU. If you have GPU support, you can just add
```bash
python3 ~/workspace/dgl/tools/launch.py \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
--workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
......@@ -154,7 +154,7 @@ python3 ~/workspace/dgl/tools/launch.py \
To run supervised with transductive setting (nodes are initialized with node embedding)
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 4 \
--num_samplers 4 \
--num_servers 1 \
......@@ -166,7 +166,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
To run supervised with transductive setting using dgl distributed DistEmbedding
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 4 \
--num_samplers 4 \
--num_servers 1 \
......@@ -178,7 +178,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
To run unsupervised with transductive setting (nodes are initialized with node embedding)
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
......@@ -190,7 +190,7 @@ python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pyt
To run unsupervised with transductive setting using dgl distributed DistEmbedding
```bash
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/experimental/ \
python3 ~/workspace/dgl/tools/launch.py --workspace ~/workspace/dgl/examples/pytorch/graphsage/dist/ \
--num_trainers 4 \
--num_samplers 0 \
--num_servers 1 \
......
......@@ -3,7 +3,8 @@ import numpy as np
import torch as th
import argparse
import time
import sys
sys.path.append('../')
from load_graph import load_reddit, load_ogb
if __name__ == '__main__':
......
......@@ -19,6 +19,7 @@ from statistics import mean
import random
import time
import argparse
sys.path.append('../')
from load_graph import load_ogb
import dgl
from dgl.data import load_data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment