Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
dgl
Commits
ba61a566
Unverified
Commit
ba61a566
authored
Aug 29, 2023
by
Muhammed Fatih BALIN
Committed by
GitHub
Aug 29, 2023
Browse files
[Bug] fix inference for labor example (#6148)
parent
3fb81fca
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
24 deletions
+21
-24
examples/pytorch/labor/model.py
examples/pytorch/labor/model.py
+13
-8
examples/pytorch/labor/train_lightning.py
examples/pytorch/labor/train_lightning.py
+8
-16
No files found.
examples/pytorch/labor/model.py
View file @
ba61a566
...
...
@@ -47,13 +47,14 @@ class SAGE(nn.Module):
h
=
self
.
dropout
(
h
)
return
h
def
inference
(
self
,
g
,
device
,
batch_size
,
num_workers
,
buffer_device
=
None
):
def
inference
(
self
,
g
,
device
,
batch_size
,
use_uva
,
num_workers
):
# The difference between this inference function and the one in the official
# example is that the intermediate results can also benefit from prefetching.
g
.
ndata
[
"h"
]
=
g
.
ndata
[
"features"
]
sampler
=
dgl
.
dataloading
.
MultiLayerFullNeighborSampler
(
1
,
prefetch_node_feats
=
[
"h"
]
)
pin_memory
=
g
.
device
!=
device
and
use_uva
dataloader
=
dgl
.
dataloading
.
DataLoader
(
g
,
th
.
arange
(
g
.
num_nodes
(),
dtype
=
g
.
idtype
,
device
=
g
.
device
),
...
...
@@ -62,26 +63,30 @@ class SAGE(nn.Module):
batch_size
=
batch_size
,
shuffle
=
False
,
drop_last
=
False
,
use_uva
=
use_uva
,
num_workers
=
num_workers
,
persistent_workers
=
(
num_workers
>
0
),
)
if
buffer_device
is
None
:
buffer_device
=
device
self
.
train
(
False
)
self
.
eval
(
)
for
l
,
layer
in
enumerate
(
self
.
layers
):
y
=
th
.
zeros
(
y
=
th
.
empty
(
g
.
num_nodes
(),
self
.
n_hidden
if
l
!=
len
(
self
.
layers
)
-
1
else
self
.
n_classes
,
device
=
buffer_device
,
dtype
=
g
.
ndata
[
"h"
].
dtype
,
device
=
g
.
device
,
pin_memory
=
pin_memory
,
)
for
input_nodes
,
output_nodes
,
blocks
in
tqdm
.
tqdm
(
dataloader
):
x
=
blocks
[
0
].
srcdata
[
"h"
]
h
=
layer
(
blocks
[
0
],
x
)
if
l
!=
len
(
self
.
layers
)
-
1
:
if
l
<
len
(
self
.
layers
)
-
1
:
h
=
self
.
activation
(
h
)
h
=
self
.
dropout
(
h
)
y
[
output_nodes
]
=
h
.
to
(
buffer_device
)
# by design, our output nodes are contiguous
y
[
output_nodes
[
0
].
item
()
:
output_nodes
[
-
1
].
item
()
+
1
]
=
h
.
to
(
y
.
device
)
g
.
ndata
[
"h"
]
=
y
return
y
examples/pytorch/labor/train_lightning.py
View file @
ba61a566
...
...
@@ -269,18 +269,9 @@ class DataModule(LightningDataModule):
dataloader_device
=
device
self
.
g
=
g
if
cast_to_int
:
self
.
train_nid
,
self
.
val_nid
,
self
.
test_nid
=
(
train_nid
.
int
(),
val_nid
.
int
(),
test_nid
.
int
(),
)
else
:
self
.
train_nid
,
self
.
val_nid
,
self
.
test_nid
=
(
train_nid
,
val_nid
,
test_nid
,
)
self
.
train_nid
=
train_nid
.
to
(
g
.
idtype
)
self
.
val_nid
=
val_nid
.
to
(
g
.
idtype
)
self
.
test_nid
=
test_nid
.
to
(
g
.
idtype
)
self
.
sampler
=
sampler
self
.
device
=
dataloader_device
self
.
use_uva
=
use_uva
...
...
@@ -385,7 +376,7 @@ if __name__ == "__main__":
argparser
.
add_argument
(
"--gpu"
,
type
=
int
,
default
=
0
,
default
=
0
if
th
.
cuda
.
is_available
()
else
-
1
,
help
=
"GPU device ID. Use -1 for CPU training"
,
)
argparser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"reddit"
)
...
...
@@ -493,7 +484,7 @@ if __name__ == "__main__":
logger
=
TensorBoardLogger
(
args
.
logdir
,
name
=
subdir
)
trainer
=
Trainer
(
accelerator
=
"gpu"
if
args
.
gpu
!=
-
1
else
"cpu"
,
devices
=
[
args
.
gpu
],
devices
=
[
args
.
gpu
]
if
args
.
gpu
!=
-
1
else
"auto"
,
max_epochs
=
args
.
num_epochs
,
max_steps
=
args
.
num_steps
,
min_steps
=
args
.
min_steps
,
...
...
@@ -521,15 +512,16 @@ if __name__ == "__main__":
graph
,
f
"cuda:
{
args
.
gpu
}
"
if
args
.
gpu
!=
-
1
else
"cpu"
,
4096
,
args
.
use_uva
,
args
.
num_workers
,
graph
.
device
,
)
for
nid
,
split_name
in
zip
(
[
datamodule
.
train_nid
,
datamodule
.
val_nid
,
datamodule
.
test_nid
],
[
"Train"
,
"Validation"
,
"Test"
],
):
nid
=
nid
.
to
(
pred
.
device
).
long
()
pred_nid
=
pred
[
nid
]
label
=
graph
.
ndata
[
"labels"
][
nid
]
f1score
=
model
.
f1score_class
().
to
(
pred
.
device
)
acc
=
f1score
(
pred_nid
,
label
)
print
(
f
"
{
split_name
}
accuracy:
"
,
acc
.
item
())
print
(
f
"
{
split_name
}
accuracy:
{
acc
.
item
()
}
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment