"tests/python/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "34af2d6889a8bdf36b8520d6e46409d9281093f6"
Unverified Commit b6f774cd authored by Rhett Ying's avatar Rhett Ying Committed by GitHub
Browse files

[Examples] fix lint (#5838)

parent 5ada3fc9
...@@ -3,14 +3,16 @@ import socket ...@@ -3,14 +3,16 @@ import socket
import time import time
from contextlib import contextmanager from contextlib import contextmanager
import dgl
import dgl.nn.pytorch as dglnn
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
import dgl
import dgl.nn.pytorch as dglnn
def load_subtensor(g, seeds, input_nodes, device, load_feat=True): def load_subtensor(g, seeds, input_nodes, device, load_feat=True):
""" """
...@@ -83,9 +85,7 @@ class DistSAGE(nn.Module): ...@@ -83,9 +85,7 @@ class DistSAGE(nn.Module):
"h_last", "h_last",
persistent=True, persistent=True,
) )
print( print(f"|V|={g.num_nodes()}, eval batch size: {batch_size}")
f"|V|={g.num_nodes()}, eval batch size: {batch_size}"
)
sampler = dgl.dataloading.NeighborSampler([-1]) sampler = dgl.dataloading.NeighborSampler([-1])
dataloader = dgl.dataloading.DistNodeDataLoader( dataloader = dgl.dataloading.DistNodeDataLoader(
...@@ -249,7 +249,7 @@ def run(args, device, data): ...@@ -249,7 +249,7 @@ def run(args, device, data):
acc.item(), acc.item(),
np.mean(iter_tput[3:]), np.mean(iter_tput[3:]),
gpu_mem_alloc, gpu_mem_alloc,
np.sum(step_time[-args.log_every:]), np.sum(step_time[-args.log_every :]),
) )
) )
start = time.time() start = time.time()
...@@ -284,8 +284,7 @@ def run(args, device, data): ...@@ -284,8 +284,7 @@ def run(args, device, data):
device, device,
) )
print( print(
"Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format "Part {}, Val Acc {:.4f}, Test Acc {:.4f}, time: {:.4f}".format(
(
g.rank(), val_acc, test_acc, time.time() - start g.rank(), val_acc, test_acc, time.time() - start
) )
) )
...@@ -298,10 +297,7 @@ def main(args): ...@@ -298,10 +297,7 @@ def main(args):
print(socket.gethostname(), "Initializing DGL process group") print(socket.gethostname(), "Initializing DGL process group")
th.distributed.init_process_group(backend=args.backend) th.distributed.init_process_group(backend=args.backend)
print(socket.gethostname(), "Initializing DistGraph") print(socket.gethostname(), "Initializing DistGraph")
g = dgl.distributed.DistGraph( g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
args.graph_name,
part_config=args.part_config
)
print(socket.gethostname(), "rank:", g.rank()) print(socket.gethostname(), "rank:", g.rank())
pb = g.get_partition_book() pb = g.get_partition_book()
...@@ -413,7 +409,7 @@ if __name__ == "__main__": ...@@ -413,7 +409,7 @@ if __name__ == "__main__":
default=False, default=False,
action="store_true", action="store_true",
help="Pad train nid to the same length across machine, to ensure num " help="Pad train nid to the same length across machine, to ensure num "
"of batches to be the same.", "of batches to be the same.",
) )
parser.add_argument( parser.add_argument(
"--net_type", "--net_type",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment