Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
91cfcaf8
Unverified
Commit
91cfcaf8
authored
Oct 14, 2022
by
Hongzhi (Steve), Chen
Committed by
GitHub
Oct 14, 2022
Browse files
black (#4707)
Co-authored-by:
Steve
<
ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal
>
parent
a5d21c2b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
49 additions
and
33 deletions
+49
-33
tutorials/large/L1_large_node_classification.py
tutorials/large/L1_large_node_classification.py
+49
-33
No files found.
tutorials/large/L1_large_node_classification.py
View file @
91cfcaf8
...
@@ -30,8 +30,8 @@ import torch
...
@@ -30,8 +30,8 @@ import torch
import
numpy
as
np
import
numpy
as
np
from
ogb.nodeproppred
import
DglNodePropPredDataset
from
ogb.nodeproppred
import
DglNodePropPredDataset
dataset
=
DglNodePropPredDataset
(
'
ogbn-arxiv
'
)
dataset
=
DglNodePropPredDataset
(
"
ogbn-arxiv
"
)
device
=
'
cpu
'
# change to 'cuda' for GPU
device
=
"
cpu
"
# change to 'cuda' for GPU
######################################################################
######################################################################
...
@@ -43,14 +43,14 @@ device = 'cpu' # change to 'cuda' for GPU
...
@@ -43,14 +43,14 @@ device = 'cpu' # change to 'cuda' for GPU
graph
,
node_labels
=
dataset
[
0
]
graph
,
node_labels
=
dataset
[
0
]
# Add reverse edges since ogbn-arxiv is unidirectional.
# Add reverse edges since ogbn-arxiv is unidirectional.
graph
=
dgl
.
add_reverse_edges
(
graph
)
graph
=
dgl
.
add_reverse_edges
(
graph
)
graph
.
ndata
[
'
label
'
]
=
node_labels
[:,
0
]
graph
.
ndata
[
"
label
"
]
=
node_labels
[:,
0
]
print
(
graph
)
print
(
graph
)
print
(
node_labels
)
print
(
node_labels
)
node_features
=
graph
.
ndata
[
'
feat
'
]
node_features
=
graph
.
ndata
[
"
feat
"
]
num_features
=
node_features
.
shape
[
1
]
num_features
=
node_features
.
shape
[
1
]
num_classes
=
(
node_labels
.
max
()
+
1
).
item
()
num_classes
=
(
node_labels
.
max
()
+
1
).
item
()
print
(
'
Number of classes:
'
,
num_classes
)
print
(
"
Number of classes:
"
,
num_classes
)
######################################################################
######################################################################
...
@@ -59,9 +59,9 @@ print('Number of classes:', num_classes)
...
@@ -59,9 +59,9 @@ print('Number of classes:', num_classes)
#
#
idx_split
=
dataset
.
get_idx_split
()
idx_split
=
dataset
.
get_idx_split
()
train_nids
=
idx_split
[
'
train
'
]
train_nids
=
idx_split
[
"
train
"
]
valid_nids
=
idx_split
[
'
valid
'
]
valid_nids
=
idx_split
[
"
valid
"
]
test_nids
=
idx_split
[
'
test
'
]
test_nids
=
idx_split
[
"
test
"
]
######################################################################
######################################################################
...
@@ -110,15 +110,15 @@ test_nids = idx_split['test']
...
@@ -110,15 +110,15 @@ test_nids = idx_split['test']
sampler
=
dgl
.
dataloading
.
NeighborSampler
([
4
,
4
])
sampler
=
dgl
.
dataloading
.
NeighborSampler
([
4
,
4
])
train_dataloader
=
dgl
.
dataloading
.
DataLoader
(
train_dataloader
=
dgl
.
dataloading
.
DataLoader
(
# The following arguments are specific to DGL's DataLoader.
# The following arguments are specific to DGL's DataLoader.
graph
,
# The graph
graph
,
# The graph
train_nids
,
# The node IDs to iterate over in minibatches
train_nids
,
# The node IDs to iterate over in minibatches
sampler
,
# The neighbor sampler
sampler
,
# The neighbor sampler
device
=
device
,
# Put the sampled MFGs on CPU or GPU
device
=
device
,
# Put the sampled MFGs on CPU or GPU
# The following arguments are inherited from PyTorch DataLoader.
# The following arguments are inherited from PyTorch DataLoader.
batch_size
=
1024
,
# Batch size
batch_size
=
1024
,
# Batch size
shuffle
=
True
,
# Whether to shuffle the nodes for every epoch
shuffle
=
True
,
# Whether to shuffle the nodes for every epoch
drop_last
=
False
,
# Whether to drop the last incomplete batch
drop_last
=
False
,
# Whether to drop the last incomplete batch
num_workers
=
0
# Number of sampler processes
num_workers
=
0
,
# Number of sampler processes
)
)
...
@@ -135,9 +135,15 @@ train_dataloader = dgl.dataloading.DataLoader(
...
@@ -135,9 +135,15 @@ train_dataloader = dgl.dataloading.DataLoader(
# You can iterate over the data loader and see what it yields.
# You can iterate over the data loader and see what it yields.
#
#
input_nodes
,
output_nodes
,
mfgs
=
example_minibatch
=
next
(
iter
(
train_dataloader
))
input_nodes
,
output_nodes
,
mfgs
=
example_minibatch
=
next
(
iter
(
train_dataloader
)
)
print
(
example_minibatch
)
print
(
example_minibatch
)
print
(
"To compute {} nodes' outputs, we need {} nodes' input features"
.
format
(
len
(
output_nodes
),
len
(
input_nodes
)))
print
(
"To compute {} nodes' outputs, we need {} nodes' input features"
.
format
(
len
(
output_nodes
),
len
(
input_nodes
)
)
)
######################################################################
######################################################################
...
@@ -164,7 +170,7 @@ mfg_0_src = mfgs[0].srcdata[dgl.NID]
...
@@ -164,7 +170,7 @@ mfg_0_src = mfgs[0].srcdata[dgl.NID]
mfg_0_dst
=
mfgs
[
0
].
dstdata
[
dgl
.
NID
]
mfg_0_dst
=
mfgs
[
0
].
dstdata
[
dgl
.
NID
]
print
(
mfg_0_src
)
print
(
mfg_0_src
)
print
(
mfg_0_dst
)
print
(
mfg_0_dst
)
print
(
torch
.
equal
(
mfg_0_src
[:
mfgs
[
0
].
num_dst_nodes
()],
mfg_0_dst
))
print
(
torch
.
equal
(
mfg_0_src
[:
mfgs
[
0
].
num_dst_nodes
()],
mfg_0_dst
))
######################################################################
######################################################################
...
@@ -179,23 +185,25 @@ import torch.nn as nn
...
@@ -179,23 +185,25 @@ import torch.nn as nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
dgl.nn
import
SAGEConv
from
dgl.nn
import
SAGEConv
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
,
in_feats
,
h_feats
,
num_classes
):
def
__init__
(
self
,
in_feats
,
h_feats
,
num_classes
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
self
.
conv1
=
SAGEConv
(
in_feats
,
h_feats
,
aggregator_type
=
'
mean
'
)
self
.
conv1
=
SAGEConv
(
in_feats
,
h_feats
,
aggregator_type
=
"
mean
"
)
self
.
conv2
=
SAGEConv
(
h_feats
,
num_classes
,
aggregator_type
=
'
mean
'
)
self
.
conv2
=
SAGEConv
(
h_feats
,
num_classes
,
aggregator_type
=
"
mean
"
)
self
.
h_feats
=
h_feats
self
.
h_feats
=
h_feats
def
forward
(
self
,
mfgs
,
x
):
def
forward
(
self
,
mfgs
,
x
):
# Lines that are changed are marked with an arrow: "<---"
# Lines that are changed are marked with an arrow: "<---"
h_dst
=
x
[:
mfgs
[
0
].
num_dst_nodes
()]
# <---
h_dst
=
x
[:
mfgs
[
0
].
num_dst_nodes
()]
# <---
h
=
self
.
conv1
(
mfgs
[
0
],
(
x
,
h_dst
))
# <---
h
=
self
.
conv1
(
mfgs
[
0
],
(
x
,
h_dst
))
# <---
h
=
F
.
relu
(
h
)
h
=
F
.
relu
(
h
)
h_dst
=
h
[:
mfgs
[
1
].
num_dst_nodes
()]
# <---
h_dst
=
h
[:
mfgs
[
1
].
num_dst_nodes
()]
# <---
h
=
self
.
conv2
(
mfgs
[
1
],
(
h
,
h_dst
))
# <---
h
=
self
.
conv2
(
mfgs
[
1
],
(
h
,
h_dst
))
# <---
return
h
return
h
model
=
Model
(
num_features
,
128
,
num_classes
).
to
(
device
)
model
=
Model
(
num_features
,
128
,
num_classes
).
to
(
device
)
...
@@ -263,12 +271,14 @@ opt = torch.optim.Adam(model.parameters())
...
@@ -263,12 +271,14 @@ opt = torch.optim.Adam(model.parameters())
#
#
valid_dataloader
=
dgl
.
dataloading
.
DataLoader
(
valid_dataloader
=
dgl
.
dataloading
.
DataLoader
(
graph
,
valid_nids
,
sampler
,
graph
,
valid_nids
,
sampler
,
batch_size
=
1024
,
batch_size
=
1024
,
shuffle
=
False
,
shuffle
=
False
,
drop_last
=
False
,
drop_last
=
False
,
num_workers
=
0
,
num_workers
=
0
,
device
=
device
device
=
device
,
)
)
...
@@ -281,15 +291,15 @@ import tqdm
...
@@ -281,15 +291,15 @@ import tqdm
import
sklearn.metrics
import
sklearn.metrics
best_accuracy
=
0
best_accuracy
=
0
best_model_path
=
'
model.pt
'
best_model_path
=
"
model.pt
"
for
epoch
in
range
(
10
):
for
epoch
in
range
(
10
):
model
.
train
()
model
.
train
()
with
tqdm
.
tqdm
(
train_dataloader
)
as
tq
:
with
tqdm
.
tqdm
(
train_dataloader
)
as
tq
:
for
step
,
(
input_nodes
,
output_nodes
,
mfgs
)
in
enumerate
(
tq
):
for
step
,
(
input_nodes
,
output_nodes
,
mfgs
)
in
enumerate
(
tq
):
# feature copy from CPU to GPU takes place here
# feature copy from CPU to GPU takes place here
inputs
=
mfgs
[
0
].
srcdata
[
'
feat
'
]
inputs
=
mfgs
[
0
].
srcdata
[
"
feat
"
]
labels
=
mfgs
[
-
1
].
dstdata
[
'
label
'
]
labels
=
mfgs
[
-
1
].
dstdata
[
"
label
"
]
predictions
=
model
(
mfgs
,
inputs
)
predictions
=
model
(
mfgs
,
inputs
)
...
@@ -298,9 +308,15 @@ for epoch in range(10):
...
@@ -298,9 +308,15 @@ for epoch in range(10):
loss
.
backward
()
loss
.
backward
()
opt
.
step
()
opt
.
step
()
accuracy
=
sklearn
.
metrics
.
accuracy_score
(
labels
.
cpu
().
numpy
(),
predictions
.
argmax
(
1
).
detach
().
cpu
().
numpy
())
accuracy
=
sklearn
.
metrics
.
accuracy_score
(
labels
.
cpu
().
numpy
(),
predictions
.
argmax
(
1
).
detach
().
cpu
().
numpy
(),
)
tq
.
set_postfix
({
'loss'
:
'%.03f'
%
loss
.
item
(),
'acc'
:
'%.03f'
%
accuracy
},
refresh
=
False
)
tq
.
set_postfix
(
{
"loss"
:
"%.03f"
%
loss
.
item
(),
"acc"
:
"%.03f"
%
accuracy
},
refresh
=
False
,
)
model
.
eval
()
model
.
eval
()
...
@@ -308,13 +324,13 @@ for epoch in range(10):
...
@@ -308,13 +324,13 @@ for epoch in range(10):
labels
=
[]
labels
=
[]
with
tqdm
.
tqdm
(
valid_dataloader
)
as
tq
,
torch
.
no_grad
():
with
tqdm
.
tqdm
(
valid_dataloader
)
as
tq
,
torch
.
no_grad
():
for
input_nodes
,
output_nodes
,
mfgs
in
tq
:
for
input_nodes
,
output_nodes
,
mfgs
in
tq
:
inputs
=
mfgs
[
0
].
srcdata
[
'
feat
'
]
inputs
=
mfgs
[
0
].
srcdata
[
"
feat
"
]
labels
.
append
(
mfgs
[
-
1
].
dstdata
[
'
label
'
].
cpu
().
numpy
())
labels
.
append
(
mfgs
[
-
1
].
dstdata
[
"
label
"
].
cpu
().
numpy
())
predictions
.
append
(
model
(
mfgs
,
inputs
).
argmax
(
1
).
cpu
().
numpy
())
predictions
.
append
(
model
(
mfgs
,
inputs
).
argmax
(
1
).
cpu
().
numpy
())
predictions
=
np
.
concatenate
(
predictions
)
predictions
=
np
.
concatenate
(
predictions
)
labels
=
np
.
concatenate
(
labels
)
labels
=
np
.
concatenate
(
labels
)
accuracy
=
sklearn
.
metrics
.
accuracy_score
(
labels
,
predictions
)
accuracy
=
sklearn
.
metrics
.
accuracy_score
(
labels
,
predictions
)
print
(
'
Epoch {} Validation Accuracy {}
'
.
format
(
epoch
,
accuracy
))
print
(
"
Epoch {} Validation Accuracy {}
"
.
format
(
epoch
,
accuracy
))
if
best_accuracy
<
accuracy
:
if
best_accuracy
<
accuracy
:
best_accuracy
=
accuracy
best_accuracy
=
accuracy
torch
.
save
(
model
.
state_dict
(),
best_model_path
)
torch
.
save
(
model
.
state_dict
(),
best_model_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment