Unverified Commit 9962b7bd authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

[Misc] Isort tutorials/large/L1_large_node_classification.py (#4721)



* iosrt

* sort
Co-authored-by: default avatarSteve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 98792a8a
...@@ -25,9 +25,6 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`. ...@@ -25,9 +25,6 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`.
# OGB already prepared the data as DGL graph. # OGB already prepared the data as DGL graph.
# #
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset("ogbn-arxiv") dataset = DglNodePropPredDataset("ogbn-arxiv")
...@@ -40,6 +37,8 @@ device = "cpu" # change to 'cuda' for GPU ...@@ -40,6 +37,8 @@ device = "cpu" # change to 'cuda' for GPU
# simply get the graph and its node labels like this: # simply get the graph and its node labels like this:
# #
import dgl
graph, node_labels = dataset[0] graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional. # Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph) graph = dgl.add_reverse_edges(graph)
...@@ -166,6 +165,8 @@ print( ...@@ -166,6 +165,8 @@ print(
# the computation of the new features. # the computation of the new features.
# #
import torch
mfg_0_src = mfgs[0].srcdata[dgl.NID] mfg_0_src = mfgs[0].srcdata[dgl.NID]
mfg_0_dst = mfgs[0].dstdata[dgl.NID] mfg_0_dst = mfgs[0].dstdata[dgl.NID]
print(mfg_0_src) print(mfg_0_src)
...@@ -183,6 +184,7 @@ print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst)) ...@@ -183,6 +184,7 @@ print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst))
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.nn import SAGEConv from dgl.nn import SAGEConv
...@@ -287,8 +289,9 @@ valid_dataloader = dgl.dataloading.DataLoader( ...@@ -287,8 +289,9 @@ valid_dataloader = dgl.dataloading.DataLoader(
# It also saves the model with the best validation accuracy into a file. # It also saves the model with the best validation accuracy into a file.
# #
import tqdm import numpy as np
import sklearn.metrics import sklearn.metrics
import tqdm
best_accuracy = 0 best_accuracy = 0
best_model_path = "model.pt" best_model_path = "model.pt"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment