Unverified Commit 91cfcaf8 authored by Hongzhi (Steve), Chen's avatar Hongzhi (Steve), Chen Committed by GitHub
Browse files
parent a5d21c2b
...@@ -30,8 +30,8 @@ import torch ...@@ -30,8 +30,8 @@ import torch
import numpy as np import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset from ogb.nodeproppred import DglNodePropPredDataset
# Load the ogbn-arxiv node-property-prediction dataset (downloads on first use).
dataset = DglNodePropPredDataset("ogbn-arxiv")
device = "cpu"  # change to 'cuda' for GPU
###################################################################### ######################################################################
...@@ -43,14 +43,14 @@ device = 'cpu' # change to 'cuda' for GPU ...@@ -43,14 +43,14 @@ device = 'cpu' # change to 'cuda' for GPU
graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
# Labels arrive as a (num_nodes, 1) tensor; store the flattened column.
graph.ndata["label"] = node_labels[:, 0]
print(graph)
print(node_labels)

node_features = graph.ndata["feat"]
num_features = node_features.shape[1]
# Classes are 0-based consecutive integers, so max()+1 is the class count.
num_classes = (node_labels.max() + 1).item()
print("Number of classes:", num_classes)
###################################################################### ######################################################################
...@@ -59,9 +59,9 @@ print('Number of classes:', num_classes) ...@@ -59,9 +59,9 @@ print('Number of classes:', num_classes)
# #
# OGB provides a standard train/valid/test node-ID split.
idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]
###################################################################### ######################################################################
...@@ -118,7 +118,7 @@ train_dataloader = dgl.dataloading.DataLoader( ...@@ -118,7 +118,7 @@ train_dataloader = dgl.dataloading.DataLoader(
batch_size=1024, # Batch size batch_size=1024, # Batch size
shuffle=True, # Whether to shuffle the nodes for every epoch shuffle=True, # Whether to shuffle the nodes for every epoch
drop_last=False, # Whether to drop the last incomplete batch drop_last=False, # Whether to drop the last incomplete batch
num_workers=0 # Number of sampler processes num_workers=0, # Number of sampler processes
) )
...@@ -135,9 +135,15 @@ train_dataloader = dgl.dataloading.DataLoader( ...@@ -135,9 +135,15 @@ train_dataloader = dgl.dataloading.DataLoader(
# You can iterate over the data loader and see what it yields. # You can iterate over the data loader and see what it yields.
# #
# Pull one minibatch from the loader: (input node IDs, output node IDs,
# message-flow graphs), one MFG per sampled layer.
input_nodes, output_nodes, mfgs = example_minibatch = next(
    iter(train_dataloader)
)
print(example_minibatch)
print(
    "To compute {} nodes' outputs, we need {} nodes' input features".format(
        len(output_nodes), len(input_nodes)
    )
)
###################################################################### ######################################################################
...@@ -164,7 +170,7 @@ mfg_0_src = mfgs[0].srcdata[dgl.NID] ...@@ -164,7 +170,7 @@ mfg_0_src = mfgs[0].srcdata[dgl.NID]
mfg_0_dst = mfgs[0].dstdata[dgl.NID]
print(mfg_0_src)
print(mfg_0_dst)
# DGL guarantees the destination nodes appear first among the source nodes.
print(torch.equal(mfg_0_src[: mfgs[0].num_dst_nodes()], mfg_0_dst))
###################################################################### ######################################################################
...@@ -179,23 +185,25 @@ import torch.nn as nn ...@@ -179,23 +185,25 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from dgl.nn import SAGEConv from dgl.nn import SAGEConv
class Model(nn.Module):
    """Two-layer GraphSAGE model operating on sampled message-flow graphs.

    Parameters
    ----------
    in_feats : int
        Size of each input node feature vector.
    h_feats : int
        Hidden layer width.
    num_classes : int
        Number of output classes.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        """Run both SAGE layers over the per-layer MFGs.

        The destination nodes of each MFG are a prefix of its source
        nodes, so slicing off the first ``num_dst_nodes()`` rows yields
        the destination-node features expected by ``SAGEConv``.
        """
        # Lines that are changed are marked with an arrow: "<---"
        h_dst = x[: mfgs[0].num_dst_nodes()]  # <---
        h = self.conv1(mfgs[0], (x, h_dst))  # <---
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]  # <---
        h = self.conv2(mfgs[1], (h, h_dst))  # <---
        return h
# 128-unit hidden layer; move the model to the selected device.
model = Model(num_features, 128, num_classes).to(device)
...@@ -263,12 +271,14 @@ opt = torch.optim.Adam(model.parameters()) ...@@ -263,12 +271,14 @@ opt = torch.optim.Adam(model.parameters())
# #
# Validation loader: no shuffling, keep every batch so all nodes are scored.
valid_dataloader = dgl.dataloading.DataLoader(
    graph,
    valid_nids,
    sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0,
    device=device,
)
...@@ -281,15 +291,15 @@ import tqdm ...@@ -281,15 +291,15 @@ import tqdm
import sklearn.metrics import sklearn.metrics
# Track the best validation accuracy seen so far and where to checkpoint it.
best_accuracy = 0
best_model_path = "model.pt"
for epoch in range(10): for epoch in range(10):
model.train() model.train()
with tqdm.tqdm(train_dataloader) as tq: with tqdm.tqdm(train_dataloader) as tq:
for step, (input_nodes, output_nodes, mfgs) in enumerate(tq): for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
# feature copy from CPU to GPU takes place here # feature copy from CPU to GPU takes place here
inputs = mfgs[0].srcdata['feat'] inputs = mfgs[0].srcdata["feat"]
labels = mfgs[-1].dstdata['label'] labels = mfgs[-1].dstdata["label"]
predictions = model(mfgs, inputs) predictions = model(mfgs, inputs)
...@@ -298,9 +308,15 @@ for epoch in range(10): ...@@ -298,9 +308,15 @@ for epoch in range(10):
loss.backward() loss.backward()
opt.step() opt.step()
accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy()) accuracy = sklearn.metrics.accuracy_score(
labels.cpu().numpy(),
predictions.argmax(1).detach().cpu().numpy(),
)
tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False) tq.set_postfix(
{"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
refresh=False,
)
model.eval() model.eval()
...@@ -308,13 +324,13 @@ for epoch in range(10): ...@@ -308,13 +324,13 @@ for epoch in range(10):
labels = [] labels = []
with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad(): with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
for input_nodes, output_nodes, mfgs in tq: for input_nodes, output_nodes, mfgs in tq:
inputs = mfgs[0].srcdata['feat'] inputs = mfgs[0].srcdata["feat"]
labels.append(mfgs[-1].dstdata['label'].cpu().numpy()) labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy()) predictions.append(model(mfgs, inputs).argmax(1).cpu().numpy())
predictions = np.concatenate(predictions) predictions = np.concatenate(predictions)
labels = np.concatenate(labels) labels = np.concatenate(labels)
accuracy = sklearn.metrics.accuracy_score(labels, predictions) accuracy = sklearn.metrics.accuracy_score(labels, predictions)
print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy)) print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
if best_accuracy < accuracy: if best_accuracy < accuracy:
best_accuracy = accuracy best_accuracy = accuracy
torch.save(model.state_dict(), best_model_path) torch.save(model.state_dict(), best_model_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment