"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "43927851378b0cd3b3f68699262a1e6d887e5f08"
Unverified Commit 9b62e8d0 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files

Revert "[Misc] Isort tutorials/large/L1_large_node_classification.py (#4721)" (#4741)

This reverts commit 9962b7bd.
parent 743516f3
...@@ -25,6 +25,9 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`. ...@@ -25,6 +25,9 @@ Sampling for GNN Training <L0_neighbor_sampling_overview>`.
# OGB already prepared the data as DGL graph. # OGB already prepared the data as DGL graph.
# #
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset("ogbn-arxiv") dataset = DglNodePropPredDataset("ogbn-arxiv")
...@@ -37,8 +40,6 @@ device = "cpu" # change to 'cuda' for GPU ...@@ -37,8 +40,6 @@ device = "cpu" # change to 'cuda' for GPU
# simply get the graph and its node labels like this: # simply get the graph and its node labels like this:
# #
import dgl
graph, node_labels = dataset[0] graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional. # Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph) graph = dgl.add_reverse_edges(graph)
...@@ -165,8 +166,6 @@ print( ...@@ -165,8 +166,6 @@ print(
# the computation of the new features. # the computation of the new features.
# #
import torch
mfg_0_src = mfgs[0].srcdata[dgl.NID] mfg_0_src = mfgs[0].srcdata[dgl.NID]
mfg_0_dst = mfgs[0].dstdata[dgl.NID] mfg_0_dst = mfgs[0].dstdata[dgl.NID]
print(mfg_0_src) print(mfg_0_src)
...@@ -184,7 +183,6 @@ print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst)) ...@@ -184,7 +183,6 @@ print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst))
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.nn import SAGEConv from dgl.nn import SAGEConv
...@@ -289,9 +287,8 @@ valid_dataloader = dgl.dataloading.DataLoader( ...@@ -289,9 +287,8 @@ valid_dataloader = dgl.dataloading.DataLoader(
# It also saves the model with the best validation accuracy into a file. # It also saves the model with the best validation accuracy into a file.
# #
import numpy as np
import sklearn.metrics
import tqdm import tqdm
import sklearn.metrics
best_accuracy = 0 best_accuracy = 0
best_model_path = "model.pt" best_model_path = "model.pt"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment