Unverified Commit a6b44e72 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Model] Fix GCMC broken code (#2001)

* [Model] Fix GCMC debugging code

* use the one with apply_edges
parent 4efa3320
......@@ -304,7 +304,9 @@ class BiDecoder(nn.Module):
super(BiDecoder, self).__init__()
self._num_basis = num_basis
self.dropout = nn.Dropout(dropout_rate)
self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
self.Ps = nn.ParameterList(
nn.Parameter(th.randn(in_units, in_units))
for _ in range(num_basis))
self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
self.reset_parameters()
......@@ -343,7 +345,7 @@ class BiDecoder(nn.Module):
out = self.combine_basis(out)
return out
class DenseBiDecoder(BiDecoder):
class DenseBiDecoder(nn.Module):
r"""Dense bi-linear decoder.
Dense implementation of the bi-linear decoder used in GCMC. Suitable when
......@@ -366,10 +368,17 @@ class DenseBiDecoder(BiDecoder):
num_classes,
num_basis=2,
dropout_rate=0.0):
super(DenseBiDecoder, self).__init__(in_units,
num_classes,
num_basis,
dropout_rate)
super().__init__()
self._num_basis = num_basis
self.dropout = nn.Dropout(dropout_rate)
self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
self.reset_parameters()
def reset_parameters(self):
    """Re-initialize all learnable parameters of this decoder.

    Applies Xavier (Glorot) uniform initialization to every parameter
    tensor with more than one dimension (i.e. the basis matrices ``self.P``
    and the ``combine_basis`` weight); 1-D parameters such as biases are
    left at their construction-time values.
    """
    for p in self.parameters():
        # Xavier init is only defined for tensors with fan-in/fan-out,
        # so skip 1-D parameters (biases).
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
def forward(self, ufeat, ifeat):
"""Forward function.
......
......@@ -11,6 +11,7 @@ import random
import string
import traceback
import numpy as np
import tqdm
import torch as th
import torch.nn as nn
import torch.multiprocessing as mp
......@@ -20,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
from _thread import start_new_thread
from functools import wraps
from data import MovieLens
from model import GCMCLayer, DenseBiDecoder
from model import GCMCLayer, DenseBiDecoder, BiDecoder
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
import dgl
......@@ -45,19 +46,14 @@ class Net(nn.Module):
else:
self.encoder.to(dev_id)
self.decoder = DenseBiDecoder(in_units=args.gcn_out_units,
self.decoder = BiDecoder(in_units=args.gcn_out_units,
num_classes=len(args.rating_vals),
num_basis=args.gen_r_num_basis_func)
self.decoder.to(dev_id)
def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
    """Encode node features and decode edge rating predictions.

    NOTE(review): this span is flattened diff output — the four lines
    computing ``head_emb``/``tail_emb`` and the first ``pred_ratings``
    assignment are the *removed* (pre-fix) code, superseded by the
    ``self.decoder(compact_g, user_out, movie_out)`` call added by this
    commit; only one of the two decoder calls belongs in the real file.
    ``possible_rating_values`` is accepted but unused here — presumably
    kept for interface compatibility with the caller; verify upstream.
    """
    # Encode user/movie features over the sampled frontier graph.
    user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
    # --- removed by commit: manual per-edge gather + pairwise decode ---
    head, tail = compact_g.edges(order='eid')
    head_emb = user_out[head]
    tail_emb = movie_out[tail]
    pred_ratings = self.decoder(head_emb, tail_emb)
    # --- added by commit: decoder now takes the graph and uses apply_edges ---
    pred_ratings = self.decoder(compact_g, user_out, movie_out)
    return pred_ratings
def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
......@@ -289,7 +285,8 @@ def run(proc_id, n_gpus, args, devices, dataset):
if epoch > 1:
t0 = time.time()
net.train()
for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
with tqdm.tqdm(dataloader) as tq:
for step, (input_nodes, pair_graph, blocks) in enumerate(tq):
head_feat, tail_feat, blocks = load_subtensor(
input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
frontier = blocks[0]
......@@ -318,19 +315,12 @@ def run(proc_id, n_gpus, args, devices, dataset):
count_rmse += rmse.item()
count_num += pred_ratings.shape[0]
if iter_idx % args.train_log_interval == 0:
logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}".format(
iter_idx, count_loss/iter_idx, count_rmse/count_num)
count_rmse = 0
count_num = 0
if iter_idx % args.train_log_interval == 0:
print("[{}] {}".format(proc_id, logging_str))
tq.set_postfix({'loss': '{:.4f}'.format(count_loss / iter_idx),
'rmse': '{:.4f}'.format(count_rmse / count_num)},
refresh=False)
iter_idx += 1
if step == 20:
return
if epoch > 1:
epoch_time = time.time() - t0
print("Epoch {} time {}".format(epoch, epoch_time))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment