OpenDAS / dgl · Commits · 7f50a6da

Unverified commit 7f50a6da, authored Sep 27, 2022 by Chang Liu, committed by GitHub on Sep 27, 2022.

[Example][Bugfix] Fix mlp example (#4631)

* Fix mlp example
* Remove unused file

Parent: 188bc2bf
Changes: 2 changed files, with 13 additions and 50 deletions (+13, -50).

examples/pytorch/ogb/ogbn-products/mlp/mlp.py    +13  -17
examples/pytorch/ogb/ogbn-products/mlp/utils.py  +0   -33
examples/pytorch/ogb/ogbn-products/mlp/mlp.py

@@ -21,7 +21,6 @@ from torch import nn
 from tqdm import tqdm

 from models import MLP
-from utils import BatchSampler, DataLoaderWrapper

 epsilon = 1 - math.log(2)
@@ -107,6 +106,7 @@ def train(args, model, dataloader, labels, train_idx, criterion, optimizer, eval
         loss_sum += loss.item() * count
         total += count

+    preds = preds.to(train_idx.device)
     return (
         loss_sum / total,
         evaluator(preds[train_idx], labels[train_idx]),
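The added line above fixes a device mismatch: during batched inference `preds` is assembled on the CPU, while `train_idx` may live on the GPU, and PyTorch raises a runtime error when a CPU tensor is indexed with CUDA indices. A minimal standalone sketch of the failure mode this guards against (the names here are illustrative, not from the example):

import torch

# preds accumulated on CPU, as in the training loop's batched inference
preds = torch.zeros(4)
train_idx = torch.tensor([0, 2])
if torch.cuda.is_available():
    train_idx = train_idx.cuda()    # indices end up on the GPU
    # preds[train_idx] here would raise: indices should be on the
    # same device as the indexed tensor
preds = preds.to(train_idx.device)  # the fix: align devices first
print(preds[train_idx])             # now safe on either device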
@@ -152,14 +152,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
     train_batch_size = 4096
     train_sampler = MultiLayerNeighborSampler([0 for _ in range(args.n_layers)])  # do not sample neighbors
-    train_dataloader = DataLoaderWrapper(
-        DataLoader(
-            graph.cpu(),
-            train_idx.cpu(),
-            train_sampler,
-            batch_sampler=BatchSampler(len(train_idx), batch_size=train_batch_size, shuffle=True),
-            num_workers=4,
-        )
-    )
+    train_dataloader = DataLoader(
+        graph.cpu(),
+        train_idx.cpu(),
+        train_sampler,
+        batch_size=train_batch_size,
+        num_workers=4,
+        shuffle=True,
+    )

     eval_batch_size = 4096
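The dataloader change above is the heart of the fix and recurs in the two hunks below: dgl.dataloading.DataLoader accepts batch_size, shuffle, and num_workers directly (it is built on torch.utils.data.DataLoader as of DGL 0.8), so the custom BatchSampler/DataLoaderWrapper pair is redundant. A minimal self-contained sketch of the adopted pattern, assuming DGL >= 0.8; the toy graph and sizes are stand-ins for ogbn-products:

import dgl
import torch
from dgl.dataloading import DataLoader, MultiLayerNeighborSampler

g = dgl.rand_graph(100, 500)   # toy graph standing in for ogbn-products
train_idx = torch.arange(60)   # toy training node ids
# fan-out 0 per layer: sample no neighbors, since the MLP only reads node features
sampler = MultiLayerNeighborSampler([0, 0])

dataloader = DataLoader(
    g,
    train_idx,
    sampler,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)
for input_nodes, output_nodes, blocks in dataloader:
    pass  # one mini-batch of seed nodes per iteration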
@@ -168,14 +167,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
         eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu()])
     else:
         eval_idx = torch.cat([train_idx.cpu(), val_idx.cpu(), test_idx.cpu()])
-    eval_dataloader = DataLoaderWrapper(
-        DataLoader(
-            graph.cpu(),
-            eval_idx,
-            eval_sampler,
-            batch_sampler=BatchSampler(len(eval_idx), batch_size=eval_batch_size, shuffle=False),
-            num_workers=4,
-        )
-    )
+    eval_dataloader = DataLoader(
+        graph.cpu(),
+        eval_idx,
+        eval_sampler,
+        batch_size=eval_batch_size,
+        num_workers=4,
+        shuffle=False,
+    )

     model = gen_model(args).to(device)
@@ -195,7 +193,6 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
     for epoch in range(1, args.n_epochs + 1):
         tic = time.time()

         loss, score = train(args, model, train_dataloader, labels, train_idx, criterion, optimizer, evaluator_wrapper)
         toc = time.time()
@@ -233,14 +230,13 @@ def run(args, graph, labels, train_idx, val_idx, test_idx, evaluator, n_running)
     if args.eval_last:
         model.load_state_dict(best_model_state_dict)
-        eval_dataloader = DataLoaderWrapper(
-            DataLoader(
-                graph.cpu(),
-                test_idx.cpu(),
-                eval_sampler,
-                batch_sampler=BatchSampler(len(test_idx), batch_size=eval_batch_size, shuffle=False),
-                num_workers=4,
-            )
-        )
+        eval_dataloader = DataLoader(
+            graph.cpu(),
+            test_idx.cpu(),
+            eval_sampler,
+            batch_size=eval_batch_size,
+            num_workers=4,
+            shuffle=False,
+        )

         final_test_score = evaluate(
             args, model, eval_dataloader, labels, train_idx, val_idx, test_idx, criterion, evaluator_wrapper
examples/pytorch/ogb/ogbn-products/mlp/utils.py (deleted, 100644 → 0)

-import torch
-
-
-class DataLoaderWrapper(object):
-    def __init__(self, dataloader):
-        self.iter = iter(dataloader)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            return next(self.iter)
-        except Exception:
-            raise StopIteration() from None
-
-
-class BatchSampler(object):
-    def __init__(self, n, batch_size, shuffle=False):
-        self.n = n
-        self.batch_size = batch_size
-        self.shuffle = shuffle
-
-    def __iter__(self):
-        while True:
-            if self.shuffle:
-                perm = torch.randperm(self.n)
-            else:
-                perm = torch.arange(start=0, end=self.n)
-            shuf = perm.split(self.batch_size)
-            for shuf_batch in shuf:
-                yield shuf_batch
-            yield None
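For context on why this file could be dropped: BatchSampler loops forever, emitting a None sentinel after each full pass over the indices, and DataLoaderWrapper turns any error from the wrapped loader into StopIteration; together they apparently emulated per-epoch iteration, which the modern DGL DataLoader handles natively. A minimal sketch of that sentinel pattern (hypothetical standalone code, only for illustration):

import torch

def endless_batches(n, batch_size):
    # mimics the deleted BatchSampler: loop forever, None marks an epoch end
    while True:
        for batch in torch.randperm(n).split(batch_size):
            yield batch
        yield None

gen = endless_batches(10, 4)
# read until the None sentinel, i.e. consume exactly one epoch
epoch = list(iter(lambda: next(gen), None))
print(len(epoch))  # 3 batches: sizes 4, 4, 2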