OpenDAS / dgl · Commits

Commit d20db1ec (unverified)
Improving the RGCN_HETERO example (#6060)

Authored Aug 03, 2023 by Andrei Ivanov, committed via GitHub on Aug 04, 2023
Parent: 566719b1

2 changed files, 73 additions and 55 deletions (+73 −55):
  examples/pytorch/rgcn-hetero/entity_classify_mb.py  (+61 −44)
  examples/pytorch/rgcn-hetero/model.py               (+12 −11)
examples/pytorch/rgcn-hetero/entity_classify_mb.py
@@ -28,18 +28,19 @@ def evaluate(model, loader, node_embed, labels, category, device):
     total_loss = 0
     total_acc = 0
     count = 0
-    for input_nodes, seeds, blocks in loader:
-        blocks = [blk.to(device) for blk in blocks]
-        seeds = seeds[category]
-        emb = extract_embed(node_embed, input_nodes)
-        emb = {k: e.to(device) for k, e in emb.items()}
-        lbl = labels[seeds].to(device)
-        logits = model(emb, blocks)[category]
-        loss = F.cross_entropy(logits, lbl)
-        acc = th.sum(logits.argmax(dim=1) == lbl).item()
-        total_loss += loss.item() * len(seeds)
-        total_acc += acc
-        count += len(seeds)
+    with loader.enable_cpu_affinity():
+        for input_nodes, seeds, blocks in loader:
+            blocks = [blk.to(device) for blk in blocks]
+            seeds = seeds[category]
+            emb = extract_embed(node_embed, input_nodes)
+            emb = {k: e.to(device) for k, e in emb.items()}
+            lbl = labels[seeds].to(device)
+            logits = model(emb, blocks)[category]
+            loss = F.cross_entropy(logits, lbl)
+            acc = th.sum(logits.argmax(dim=1) == lbl).item()
+            total_loss += loss.item() * len(seeds)
+            total_acc += acc
+            count += len(seeds)
     return total_loss / count, total_acc / count
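The pattern introduced above recurs throughout this commit: iteration over the DGL dataloader is wrapped in its enable_cpu_affinity() context manager, which pins the loader's worker processes to CPU cores while leaving the loop body untouched. Below is a minimal, self-contained sketch of that pattern, not part of the commit; the toy graph, node type, and ID tensor are invented for illustration, and on platforms that spawn dataloader workers it should run under an if __name__ == "__main__" guard.

import dgl
import torch as th

# Toy heterograph standing in for the example's dataset graph (illustrative only).
g = dgl.heterograph(
    {("paper", "cites", "paper"): (th.tensor([0, 1, 2]), th.tensor([1, 2, 0]))}
)
category = "paper"
train_idx = th.tensor([0, 1, 2])

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
loader = dgl.dataloading.DataLoader(
    g,
    {category: train_idx},
    sampler,
    batch_size=2,
    shuffle=True,
    num_workers=4,  # CPU affinity needs worker processes, hence the > 0 check this commit adds in main()
)

# enable_cpu_affinity() pins the workers for the duration of the block;
# the (input_nodes, seeds, blocks) iteration protocol is unchanged.
with loader.enable_cpu_affinity():
    for input_nodes, seeds, blocks in loader:
        pass  # forward pass, loss and metrics go here, exactly as in evaluate()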
@@ -86,6 +87,12 @@ def main(args):
     labels = labels.to(device)
     embed_layer = embed_layer.to(device)
 
+    if args.num_workers <= 0:
+        raise ValueError(
+            "The '--num_workers' parameter value is expected "
+            "to be >0, but got {}.".format(args.num_workers)
+        )
+
     node_embed = embed_layer()
 
     # create model
     model = EntityClassify(
@@ -111,7 +118,7 @@ def main(args):
         sampler,
         batch_size=args.batch_size,
         shuffle=True,
-        num_workers=0,
+        num_workers=args.num_workers,
     )
 
     # validation sampler
@@ -125,7 +132,7 @@ def main(args):
         val_sampler,
         batch_size=args.batch_size,
         shuffle=True,
-        num_workers=0,
+        num_workers=args.num_workers,
     )
 
     # optimizer
@@ -134,53 +141,59 @@ def main(args):
     # training loop
     print("start training...")
     dur = []
+    mean = 0
     for epoch in range(args.n_epochs):
         model.train()
         optimizer.zero_grad()
         if epoch > 3:
             t0 = time.time()
 
-        for i, (input_nodes, seeds, blocks) in enumerate(loader):
-            blocks = [blk.to(device) for blk in blocks]
-            seeds = seeds[category]  # we only predict the nodes with type "category"
-            batch_tic = time.time()
-            emb = extract_embed(node_embed, input_nodes)
-            lbl = labels[seeds]
-            if use_cuda:
-                emb = {k: e.cuda() for k, e in emb.items()}
-                lbl = lbl.cuda()
-            logits = model(emb, blocks)[category]
-            loss = F.cross_entropy(logits, lbl)
-            loss.backward()
-            optimizer.step()
-            train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
-            print(
-                "Epoch {:05d} | Batch {:03d} | Train Acc: {:.4f} | Train Loss: {:.4f} | Time: {:.4f}".format(
-                    epoch,
-                    i,
-                    train_acc,
-                    loss.item(),
-                    time.time() - batch_tic,
-                )
-            )
+        with loader.enable_cpu_affinity():
+            for i, (input_nodes, seeds, blocks) in enumerate(loader):
+                blocks = [blk.to(device) for blk in blocks]
+                seeds = seeds[category]  # we only predict the nodes with type "category"
+                batch_tic = time.time()
+                emb = extract_embed(node_embed, input_nodes)
+                lbl = labels[seeds]
+                if use_cuda:
+                    emb = {k: e.cuda() for k, e in emb.items()}
+                    lbl = lbl.cuda()
+                logits = model(emb, blocks)[category]
+                loss = F.cross_entropy(logits, lbl)
+                loss.backward()
+                optimizer.step()
+                train_acc = th.sum(logits.argmax(dim=1) == lbl).item() / len(seeds)
+                print(
+                    f"Epoch {epoch:05d} | Batch {i:03d} | Train Acc: "
+                    f"{train_acc:.4f} | Train Loss: {loss.item():.4f} | Time: "
+                    f"{time.time() - batch_tic:.4f}"
+                )
 
         if epoch > 3:
             dur.append(time.time() - t0)
+            mean = (mean * (epoch - 3) + (time.time() - t0)) / (epoch - 2)
 
         val_loss, val_acc = evaluate(
             model, val_loader, node_embed, labels, category, device
         )
-        print(
-            "Epoch {:05d} | Valid Acc: {:.4f} | Valid loss: {:.4f} | Time: {:.4f}".format(
-                epoch,
-                val_acc,
-                val_loss,
-                np.average(dur),
-            )
-        )
+        print(
+            f"Epoch {epoch:05d} | Valid Acc: {val_acc:.4f} | Valid loss: "
+            f"{val_loss:.4f} | Time: {mean:.4f}"
+        )
         print()
 
     if args.model_path is not None:
         th.save(model.state_dict(), args.model_path)
 
     output = model.inference(
-        g, args.batch_size, "cuda" if use_cuda else "cpu", 0, node_embed
+        g,
+        args.batch_size,
+        "cuda" if use_cuda else "cpu",
+        args.num_workers,
+        node_embed,
     )
     test_pred = output[category][test_idx]
     test_labels = labels[test_idx].to(test_pred.device)
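Besides the affinity wrapper, this hunk moves the log lines to f-strings and reports epoch time as a running mean kept in mean rather than np.average(dur) over the stored list. For reference, the standard incremental-mean recurrence behind this kind of bookkeeping is m_k = (m_{k-1} * (k - 1) + x_k) / k after k observations; the short self-contained check below uses made-up timings that do not come from the commit.

import numpy as np

# Made-up per-epoch durations standing in for the (time.time() - t0) readings
# that the training loop records once the warm-up epochs are past.
timings = [1.31, 1.27, 1.29, 1.26, 1.28]

mean = 0.0
for k, elapsed in enumerate(timings, start=1):
    # After k observations: mean_k = (mean_{k-1} * (k - 1) + x_k) / k,
    # so the full list of past durations never has to be kept for the log line.
    mean = (mean * (k - 1) + elapsed) / k

assert abs(mean - np.average(timings)) < 1e-12
print(f"running mean over {len(timings)} epochs: {mean:.4f}")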
@@ -245,6 +258,10 @@ if __name__ == "__main__":
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that."
,
)
parser
.
add_argument
(
"--num_workers"
,
type
=
int
,
default
=
4
,
help
=
"Number of node dataloader"
)
fp
=
parser
.
add_mutually_exclusive_group
(
required
=
False
)
fp
.
add_argument
(
"--validation"
,
dest
=
"validation"
,
action
=
"store_true"
)
fp
.
add_argument
(
"--testing"
,
dest
=
"validation"
,
action
=
"store_false"
)
...
...
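The new flag is wired through end to end: validated in main(), passed as num_workers to both dataloaders, and forwarded to model.inference(), as the hunks above show. A stripped-down sketch of just the argument parsing and the guard, not the example's full parser:

import argparse

parser = argparse.ArgumentParser(description="RGCN-hetero minibatch example (sketch)")
parser.add_argument(
    "--num_workers", type=int, default=4, help="Number of node dataloader"
)
args = parser.parse_args()

# Mirrors the guard added in main(): CPU affinity only works with at least one worker.
if args.num_workers <= 0:
    raise ValueError(
        "The '--num_workers' parameter value is expected "
        "to be >0, but got {}.".format(args.num_workers)
    )
print("using {} dataloader workers".format(args.num_workers))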
examples/pytorch/rgcn-hetero/model.py
@@ -423,17 +423,18 @@ class EntityClassify(nn.Module):
                 num_workers=num_workers,
             )
 
-            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
-                block = blocks[0].to(device)
-                h = {
-                    k: x[k][input_nodes[k]].to(device)
-                    for k in input_nodes.keys()
-                }
-                h = layer(block, h)
-
-                for k in output_nodes.keys():
-                    y[k][output_nodes[k]] = h[k].cpu()
+            with dataloader.enable_cpu_affinity():
+                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+                    block = blocks[0].to(device)
+                    h = {
+                        k: x[k][input_nodes[k]].to(device)
+                        for k in input_nodes.keys()
+                    }
+                    h = layer(block, h)
+
+                    for k in output_nodes.keys():
+                        y[k][output_nodes[k]] = h[k].cpu()
 
             x = y
         return y
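In model.py the same wrapper goes around the layer-wise inference loop, which runs every node through one layer at a time with a full-neighbor sampler and stages the per-type outputs on CPU. The sketch below reproduces that loop shape on a toy graph with a single HeteroGraphConv layer; every name in it is illustrative, and the real method additionally handles multiple layers, the first layer's input embeddings, and the output-buffer allocation that this hunk does not show. Like the earlier sketch, run it under an if __name__ == "__main__" guard if your platform spawns dataloader workers.

import dgl
import dgl.nn as dglnn
import torch as th
import tqdm

# Toy single-relation heterograph and input features (illustrative only).
g = dgl.heterograph(
    {("user", "follows", "user"): (th.tensor([0, 1, 2]), th.tensor([1, 2, 0]))}
)
x = {"user": th.randn(g.num_nodes("user"), 8)}
layer = dglnn.HeteroGraphConv({"follows": dglnn.GraphConv(8, 8)})

# One full-neighbor hop per layer, over every node of every type.
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.DataLoader(
    g,
    {"user": th.arange(g.num_nodes("user"))},
    sampler,
    batch_size=2,
    shuffle=False,
    num_workers=1,  # > 0 so enable_cpu_affinity() has workers to pin
)

# Output buffer kept on CPU, matching the h[k].cpu() writes above.
y = {"user": th.zeros(g.num_nodes("user"), 8)}
with dataloader.enable_cpu_affinity():
    for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
        block = blocks[0]
        h = {k: x[k][input_nodes[k]] for k in input_nodes.keys()}
        h = layer(block, h)
        for k in output_nodes.keys():
            y[k][output_nodes[k]] = h[k].cpu()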