OpenDAS / dgl · Commits

Unverified commit a6b44e72, authored Aug 12, 2020 by Quan (Andy) Gan, committed by GitHub on Aug 12, 2020

[Model] Fix GCMC broken code (#2001)

* [Model] Fix GCMC debugging code

* use the one with apply_edges
parent 4efa3320

Showing 2 changed files with 57 additions and 58 deletions

examples/pytorch/gcmc/model.py          +15 -6
examples/pytorch/gcmc/train_sampling.py +42 -52
examples/pytorch/gcmc/model.py

@@ -304,7 +304,9 @@ class BiDecoder(nn.Module):
         super(BiDecoder, self).__init__()
         self._num_basis = num_basis
         self.dropout = nn.Dropout(dropout_rate)
-        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.Ps = nn.ParameterList(
+            nn.Parameter(th.randn(in_units, in_units))
+            for _ in range(num_basis))
         self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
         self.reset_parameters()
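Note: the replacement stores the same weights as before, just as num_basis separate square matrices rather than one 3-D tensor, which lets each basis matrix be applied on its own. A minimal standalone sketch of how the two layouts relate (sizes here are made up):

import torch as th
import torch.nn as nn

num_basis, in_units = 2, 8   # hypothetical sizes
Ps = nn.ParameterList(nn.Parameter(th.randn(in_units, in_units))
                      for _ in range(num_basis))

# Stacking the per-basis matrices recovers the old (num_basis, in_units,
# in_units) layout of self.P, so no expressiveness is lost.
P = th.stack(list(Ps), dim=0)
assert P.shape == (num_basis, in_units, in_units)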
@@ -343,7 +345,7 @@ class BiDecoder(nn.Module):
         out = self.combine_basis(out)
         return out

-class DenseBiDecoder(BiDecoder):
+class DenseBiDecoder(nn.Module):
     r"""Dense bi-linear decoder.

     Dense implementation of the bi-linear decoder used in GCMC. Suitable when
@@ -366,10 +368,17 @@ class DenseBiDecoder(BiDecoder):
                  num_classes,
                  num_basis=2,
                  dropout_rate=0.0):
-        super(DenseBiDecoder, self).__init__(in_units,
-                                             num_classes,
-                                             num_basis,
-                                             dropout_rate)
+        super().__init__()
+        self._num_basis = num_basis
+        self.dropout = nn.Dropout(dropout_rate)
+        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)

     def forward(self, ufeat, ifeat):
         """Forward function.
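Note: DenseBiDecoder's forward body is collapsed in this diff. Given the parameters defined above, a dense bi-linear decoder of this shape typically computes u^T P_b v for every basis b and mixes the per-basis scores with combine_basis; a hedged sketch of that pattern, not the literal file contents:

import torch as th

def dense_bilinear_scores(ufeat, ifeat, P, combine_basis, dropout):
    # ufeat, ifeat: (N, in_units) user/item embeddings for N pairs.
    # P: (num_basis, in_units, in_units); combine_basis: nn.Linear(num_basis, num_classes).
    ufeat = dropout(ufeat)
    ifeat = dropout(ifeat)
    # basis_out[n, b] = ufeat[n] @ P[b] @ ifeat[n]
    basis_out = th.einsum('ni,bij,nj->nb', ufeat, P, ifeat)
    return combine_basis(basis_out)   # (N, num_classes) rating-class logits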
examples/pytorch/gcmc/train_sampling.py

@@ -11,6 +11,7 @@ import random
 import string
 import traceback
 import numpy as np
+import tqdm
 import torch as th
 import torch.nn as nn
 import torch.multiprocessing as mp
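Note: the new tqdm import backs the progress bar introduced in the training loop below. A standalone sketch of the pattern with a dummy loop:

import time
import tqdm

# tqdm.tqdm wraps any iterable in a live progress bar; set_postfix appends
# running metrics, and refresh=False defers redrawing to the bar's own schedule.
with tqdm.tqdm(range(100)) as tq:
    for step in tq:
        loss = 1.0 / (step + 1)   # stand-in for a real training loss
        tq.set_postfix({'loss': '{:.4f}'.format(loss)}, refresh=False)
        time.sleep(0.01)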
@@ -20,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
 from _thread import start_new_thread
 from functools import wraps
 from data import MovieLens
-from model import GCMCLayer, DenseBiDecoder
+from model import GCMCLayer, DenseBiDecoder, BiDecoder
 from utils import get_activation, get_optimizer, torch_total_param_num, \
     torch_net_info, MetricLogger, to_etype_name
 import dgl
@@ -45,19 +46,14 @@ class Net(nn.Module):
         else:
             self.encoder.to(dev_id)
-        self.decoder = DenseBiDecoder(in_units=args.gcn_out_units,
-                                      num_classes=len(args.rating_vals),
-                                      num_basis=args.gen_r_num_basis_func)
+        self.decoder = BiDecoder(in_units=args.gcn_out_units,
+                                 num_classes=len(args.rating_vals),
+                                 num_basis=args.gen_r_num_basis_func)
         self.decoder.to(dev_id)

     def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
         user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
-        head, tail = compact_g.edges(order='eid')
-        head_emb = user_out[head]
-        tail_emb = movie_out[tail]
-        pred_ratings = self.decoder(head_emb, tail_emb)
+        pred_ratings = self.decoder(compact_g, user_out, movie_out)
         return pred_ratings

 def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
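Note: this hunk is the "use the one with apply_edges" part of the commit message: BiDecoder receives the pair graph and scores all of its edges in bulk, instead of gathering head/tail embeddings by hand. BiDecoder's body is not shown in this diff; a hedged sketch of the general apply_edges pattern for one basis matrix:

import dgl.function as fn
import torch as th

def bilinear_edge_scores(graph, ufeat, ifeat, P):
    # graph: user->movie pair graph; P: one (in_units, in_units) basis matrix.
    with graph.local_scope():
        graph.srcdata['h'] = ufeat @ P                  # transform user features
        graph.dstdata['h'] = ifeat
        graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))   # one dot product per edge
        return graph.edata['sr']                        # (num_edges, 1) scores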
@@ -289,48 +285,42 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if epoch > 1:
             t0 = time.time()
         net.train()
-        for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
-            head_feat, tail_feat, blocks = load_subtensor(
-                input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
-            frontier = blocks[0]
-            compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
-            true_relation_labels = compact_g.edata['label']
-            true_relation_ratings = compact_g.edata['rating']
-
-            head_feat = head_feat.to(dev_id)
-            tail_feat = tail_feat.to(dev_id)
-            frontier = frontier.to(dev_id)
-
-            pred_ratings = net(compact_g, frontier, head_feat, tail_feat,
-                               dataset.possible_rating_values)
-            loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
-            count_loss += loss.item()
-            optimizer.zero_grad()
-            loss.backward()
-            nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
-            optimizer.step()
-
-            if proc_id == 0 and iter_idx == 1:
-                print("Total #Param of net: %d" % (torch_total_param_num(net)))
-
-            real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
-                                 nd_possible_rating_values.view(1, -1)).sum(dim=1)
-            rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
-            count_rmse += rmse.item()
-            count_num += pred_ratings.shape[0]
-
-            if iter_idx % args.train_log_interval == 0:
-                logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}".format(
-                    iter_idx, count_loss/iter_idx, count_rmse/count_num)
-                count_rmse = 0
-                count_num = 0
-            if iter_idx % args.train_log_interval == 0:
-                print("[{}] {}".format(proc_id, logging_str))
-            iter_idx += 1
-            if step == 20:
-                return
+        with tqdm.tqdm(dataloader) as tq:
+            for step, (input_nodes, pair_graph, blocks) in enumerate(tq):
+                head_feat, tail_feat, blocks = load_subtensor(
+                    input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
+                frontier = blocks[0]
+                compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
+                true_relation_labels = compact_g.edata['label']
+                true_relation_ratings = compact_g.edata['rating']
+
+                head_feat = head_feat.to(dev_id)
+                tail_feat = tail_feat.to(dev_id)
+                frontier = frontier.to(dev_id)
+
+                pred_ratings = net(compact_g, frontier, head_feat, tail_feat,
+                                   dataset.possible_rating_values)
+                loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
+                count_loss += loss.item()
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
+                optimizer.step()
+
+                if proc_id == 0 and iter_idx == 1:
+                    print("Total #Param of net: %d" % (torch_total_param_num(net)))
+
+                real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
+                                     nd_possible_rating_values.view(1, -1)).sum(dim=1)
+                rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
+                count_rmse += rmse.item()
+                count_num += pred_ratings.shape[0]
+
+                tq.set_postfix({'loss': '{:.4f}'.format(count_loss / iter_idx),
+                                'rmse': '{:.4f}'.format(count_rmse / count_num)},
+                               refresh=False)
+                iter_idx += 1

         if epoch > 1:
             epoch_time = time.time() - t0
             print("Epoch {} time {}".format(epoch, epoch_time))
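Note: real_pred_ratings above converts per-class logits into a scalar rating by taking the expectation over the possible rating levels; a standalone sketch with made-up numbers:

import torch as th

pred_ratings = th.randn(4, 5)                        # 4 edges, 5 rating classes
possible_rating_values = th.tensor([1., 2., 3., 4., 5.])

probs = th.softmax(pred_ratings, dim=1)              # per-edge class distribution
real_pred = (probs * possible_rating_values.view(1, -1)).sum(dim=1)
# real_pred[i] is the expected rating E[r] for edge i, used for the RMSE above.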