OpenDAS / dgl · Commits

Commit a6b44e72 (unverified), authored Aug 12, 2020 by Quan (Andy) Gan; committed by GitHub on Aug 12, 2020.

[Model] Fix GCMC broken code (#2001)

* [Model] Fix GCMC debugging code
* use the one with apply_edges
Parent: 4efa3320
Showing 2 changed files with 57 additions and 58 deletions (+57 −58):

  examples/pytorch/gcmc/model.py           +15 −6
  examples/pytorch/gcmc/train_sampling.py  +42 −52
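Note: the second commit-message bullet, "use the one with apply_edges", refers to DGL's graph.apply_edges API, which the graph-aware BiDecoder relies on instead of dense pair indexing. Below is a minimal sketch of that pattern; BiDecoderSketch and every name in it are illustrative, assuming a uni-bipartite user→movie graph, and are not the commit's exact code.

import torch as th
import torch.nn as nn
import dgl.function as fn

class BiDecoderSketch(nn.Module):
    """Bi-linear edge scoring via apply_edges (illustrative sketch)."""
    def __init__(self, in_units, num_classes, num_basis=2):
        super().__init__()
        # One learnable bi-linear form per basis, as in the patched BiDecoder.
        self.Ps = nn.ParameterList(
            nn.Parameter(th.randn(in_units, in_units)) for _ in range(num_basis))
        self.combine_basis = nn.Linear(num_basis, num_classes, bias=False)

    def forward(self, graph, ufeat, ifeat):
        with graph.local_scope():
            graph.dstdata['h'] = ifeat
            basis_out = []
            for P in self.Ps:
                # Score each edge (u, v) as (ufeat[u] @ P) . ifeat[v].
                graph.srcdata['h'] = ufeat @ P
                graph.apply_edges(fn.u_dot_v('h', 'h', 's'))
                basis_out.append(graph.edata['s'])   # shape (E, 1)
            out = th.cat(basis_out, dim=1)           # shape (E, num_basis)
            return self.combine_basis(out)           # per-edge class logits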
--- a/examples/pytorch/gcmc/model.py
+++ b/examples/pytorch/gcmc/model.py
@@ -304,7 +304,9 @@ class BiDecoder(nn.Module):
         super(BiDecoder, self).__init__()
         self._num_basis = num_basis
         self.dropout = nn.Dropout(dropout_rate)
-        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.Ps = nn.ParameterList(
+            nn.Parameter(th.randn(in_units, in_units))
+            for _ in range(num_basis))
         self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
         self.reset_parameters()
@@ -343,7 +345,7 @@ class BiDecoder(nn.Module):
         out = self.combine_basis(out)
         return out

-class DenseBiDecoder(BiDecoder):
+class DenseBiDecoder(nn.Module):
     r"""Dense bi-linear decoder.

     Dense implementation of the bi-linear decoder used in GCMC. Suitable when
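(After this change the two decoders no longer share an interface: BiDecoder scores edges on the graph itself, taking (graph, ufeat, ifeat) in forward, while DenseBiDecoder keeps the dense (ufeat, ifeat) signature, so DenseBiDecoder stops inheriting from BiDecoder and rebuilds its own parameters in the next hunk.)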
@@ -366,10 +368,17 @@ class DenseBiDecoder(BiDecoder):
                  num_classes,
                  num_basis=2,
                  dropout_rate=0.0):
-        super(DenseBiDecoder, self).__init__(in_units,
-                                             num_classes,
-                                             num_basis,
-                                             dropout_rate)
+        super().__init__()
+        self._num_basis = num_basis
+        self.dropout = nn.Dropout(dropout_rate)
+        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)

     def forward(self, ufeat, ifeat):
         """Forward function.
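Note: DenseBiDecoder keeps the stacked basis tensor P of shape (num_basis, in_units, in_units) because dense scoring over aligned user/item rows can be done in a single einsum. A standalone sketch with made-up shapes (not repository code):

import torch as th

num_pairs, in_units, num_basis = 4, 8, 2
ufeat = th.randn(num_pairs, in_units)        # user embedding for each rated pair
ifeat = th.randn(num_pairs, in_units)        # item embedding aligned with ufeat
P = th.randn(num_basis, in_units, in_units)  # one bi-linear form per basis

# score[a, b] = ufeat[a] @ P[b] @ ifeat[a] for every pair a and basis b
scores = th.einsum('ai,bij,aj->ab', ufeat, P, ifeat)
assert scores.shape == (num_pairs, num_basis)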
--- a/examples/pytorch/gcmc/train_sampling.py
+++ b/examples/pytorch/gcmc/train_sampling.py
@@ -11,6 +11,7 @@ import random
 import string
 import traceback
 import numpy as np
+import tqdm
 import torch as th
 import torch.nn as nn
 import torch.multiprocessing as mp
@@ -20,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
 from _thread import start_new_thread
 from functools import wraps
 from data import MovieLens
-from model import GCMCLayer, DenseBiDecoder
+from model import GCMCLayer, DenseBiDecoder, BiDecoder
 from utils import get_activation, get_optimizer, torch_total_param_num, \
     torch_net_info, MetricLogger, to_etype_name
 import dgl
@@ -45,19 +46,14 @@ class Net(nn.Module):
         else:
             self.encoder.to(dev_id)
-        self.decoder = DenseBiDecoder(in_units=args.gcn_out_units,
-                                      num_classes=len(args.rating_vals),
-                                      num_basis=args.gen_r_num_basis_func)
+        self.decoder = BiDecoder(in_units=args.gcn_out_units,
+                                 num_classes=len(args.rating_vals),
+                                 num_basis=args.gen_r_num_basis_func)
         self.decoder.to(dev_id)

     def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
         user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
-        head, tail = compact_g.edges(order='eid')
-        head_emb = user_out[head]
-        tail_emb = movie_out[tail]
-        pred_ratings = self.decoder(head_emb, tail_emb)
+        pred_ratings = self.decoder(compact_g, user_out, movie_out)
         return pred_ratings

 def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
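Note: the removed forward body gathered endpoint embeddings by hand via compact_g.edges(order='eid'); handing the graph to the decoder lets apply_edges do the same gather-and-score in one call. A toy self-check (graph and sizes are made up) that the two paths agree for a plain dot-product score:

import torch as th
import dgl
import dgl.function as fn

# A tiny user->movie bipartite graph with 3 rating edges.
src = th.tensor([0, 1, 1])
dst = th.tensor([0, 0, 1])
g = dgl.heterograph({('user', 'rates', 'movie'): (src, dst)})

ufeat = th.randn(2, 4)   # 2 users, 4-dim embeddings
ifeat = th.randn(2, 4)   # 2 movies, 4-dim embeddings

# Old style: index endpoint embeddings manually.
head, tail = g.edges(order='eid')
manual = (ufeat[head] * ifeat[tail]).sum(dim=1)

# New style: let apply_edges compute the per-edge score.
g.srcdata['h'] = ufeat
g.dstdata['h'] = ifeat
g.apply_edges(fn.u_dot_v('h', 'h', 's'))
assert th.allclose(manual, g.edata['s'].squeeze(1))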
@@ -289,48 +285,42 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if epoch > 1:
             t0 = time.time()
         net.train()
-        for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
-            head_feat, tail_feat, blocks = load_subtensor(
-                input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
-            frontier = blocks[0]
-            compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
-            true_relation_labels = compact_g.edata['label']
-            true_relation_ratings = compact_g.edata['rating']
-            head_feat = head_feat.to(dev_id)
-            tail_feat = tail_feat.to(dev_id)
-            frontier = frontier.to(dev_id)
-
-            pred_ratings = net(compact_g, frontier, head_feat, tail_feat,
-                               dataset.possible_rating_values)
-            loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
-            count_loss += loss.item()
-            optimizer.zero_grad()
-            loss.backward()
-            nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
-            optimizer.step()
-
-            if proc_id == 0 and iter_idx == 1:
-                print("Total #Param of net: %d" % (torch_total_param_num(net)))
-
-            real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
-                                 nd_possible_rating_values.view(1, -1)).sum(dim=1)
-            rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
-            count_rmse += rmse.item()
-            count_num += pred_ratings.shape[0]
-
-            if iter_idx % args.train_log_interval == 0:
-                logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}".format(
-                    iter_idx, count_loss / iter_idx, count_rmse / count_num)
-                count_rmse = 0
-                count_num = 0
-
-            if iter_idx % args.train_log_interval == 0:
-                print("[{}] {}".format(proc_id, logging_str))
-            iter_idx += 1
-            if step == 20:
-                return
+        with tqdm.tqdm(dataloader) as tq:
+            for step, (input_nodes, pair_graph, blocks) in enumerate(tq):
+                head_feat, tail_feat, blocks = load_subtensor(
+                    input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
+                frontier = blocks[0]
+                compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
+                true_relation_labels = compact_g.edata['label']
+                true_relation_ratings = compact_g.edata['rating']
+                head_feat = head_feat.to(dev_id)
+                tail_feat = tail_feat.to(dev_id)
+                frontier = frontier.to(dev_id)
+
+                pred_ratings = net(compact_g, frontier, head_feat, tail_feat,
+                                   dataset.possible_rating_values)
+                loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
+                count_loss += loss.item()
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
+                optimizer.step()
+
+                if proc_id == 0 and iter_idx == 1:
+                    print("Total #Param of net: %d" % (torch_total_param_num(net)))
+
+                real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
+                                     nd_possible_rating_values.view(1, -1)).sum(dim=1)
+                rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
+                count_rmse += rmse.item()
+                count_num += pred_ratings.shape[0]
+
+                tq.set_postfix({'loss': '{:.4f}'.format(count_loss / iter_idx),
+                                'rmse': '{:.4f}'.format(count_rmse / count_num)},
+                               refresh=False)
+                iter_idx += 1

         if epoch > 1:
             epoch_time = time.time() - t0
             print("Epoch {} time {}".format(epoch, epoch_time))
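Note: in the loop above, real_pred_ratings turns the decoder's per-class logits into a scalar prediction: softmax yields a distribution over the discrete rating values, and the weighted sum is the expected rating, which the RMSE is then computed against. A standalone sketch with made-up numbers:

import torch as th

possible_rating_values = th.tensor([1., 2., 3., 4., 5.])
logits = th.randn(3, 5)            # 3 edges, 5 rating classes
probs = th.softmax(logits, dim=1)  # per-edge distribution over ratings
expected = (probs * possible_rating_values.view(1, -1)).sum(dim=1)
print(expected)                    # one expected rating per edge, in [1, 5]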