Unverified Commit a6b44e72 authored by Quan (Andy) Gan, committed by GitHub

[Model] Fix GCMC broken code (#2001)

* [Model] Fix GCMC debugging code

* use the one with apply_edges
parent 4efa3320
@@ -304,7 +304,9 @@ class BiDecoder(nn.Module):
         super(BiDecoder, self).__init__()
         self._num_basis = num_basis
         self.dropout = nn.Dropout(dropout_rate)
-        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.Ps = nn.ParameterList(
+            nn.Parameter(th.randn(in_units, in_units))
+            for _ in range(num_basis))
         self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
         self.reset_parameters()
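For context, the ParameterList change above pairs with an edge-wise forward pass: each basis matrix projects the user features, and DGL's apply_edges computes one dot product per graph edge, which is what the commit message's "use the one with apply_edges" refers to. A minimal sketch, assuming 'user' and 'movie' node types on the flattened graph (names follow this diff; the exact upstream method may differ slightly):

    import torch as th
    import dgl.function as fn

    def forward(self, graph, ufeat, ifeat):
        with graph.local_scope():
            ufeat = self.dropout(ufeat)
            ifeat = self.dropout(ifeat)
            graph.nodes['movie'].data['h'] = ifeat
            basis_out = []
            for i in range(self._num_basis):
                # project user features through the i-th basis matrix,
                # then take a per-edge dot product with movie features
                graph.nodes['user'].data['h'] = ufeat @ self.Ps[i]
                graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
                basis_out.append(graph.edata['sr'])   # (num_edges, 1)
            out = th.cat(basis_out, dim=1)            # (num_edges, num_basis)
        return self.combine_basis(out)                # (num_edges, num_classes)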
@@ -343,7 +345,7 @@ class BiDecoder(nn.Module):
         out = self.combine_basis(out)
         return out

-class DenseBiDecoder(BiDecoder):
+class DenseBiDecoder(nn.Module):
     r"""Dense bi-linear decoder.

     Dense implementation of the bi-linear decoder used in GCMC. Suitable when
@@ -366,10 +368,17 @@ class DenseBiDecoder(BiDecoder):
                  num_classes,
                  num_basis=2,
                  dropout_rate=0.0):
-        super(DenseBiDecoder, self).__init__(in_units,
-                                             num_classes,
-                                             num_basis,
-                                             dropout_rate)
+        super().__init__()
+        self._num_basis = num_basis
+        self.dropout = nn.Dropout(dropout_rate)
+        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)

     def forward(self, ufeat, ifeat):
         """Forward function.
......
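The dense variant keeps the stacked 3-D tensor self.P and scores aligned (user, item) feature rows directly, so its forward reduces to a single einsum. A sketch under the same assumptions, not necessarily the exact upstream body:

    import torch as th

    def forward(self, ufeat, ifeat):
        ufeat = self.dropout(ufeat)
        ifeat = self.dropout(ifeat)
        # out[a, b] = ufeat[a] @ P[b] @ ifeat[a]: one bilinear score per basis
        out = th.einsum('ai,bij,aj->ab', ufeat, self.P, ifeat)
        return self.combine_basis(out)   # (batch, num_classes)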
@@ -11,6 +11,7 @@ import random
 import string
 import traceback
 import numpy as np
+import tqdm
 import torch as th
 import torch.nn as nn
 import torch.multiprocessing as mp
@@ -20,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
 from _thread import start_new_thread
 from functools import wraps
 from data import MovieLens
-from model import GCMCLayer, DenseBiDecoder
+from model import GCMCLayer, DenseBiDecoder, BiDecoder
 from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
 import dgl
@@ -45,19 +46,14 @@ class Net(nn.Module):
         else:
             self.encoder.to(dev_id)
-        self.decoder = DenseBiDecoder(in_units=args.gcn_out_units,
-                                      num_classes=len(args.rating_vals),
-                                      num_basis=args.gen_r_num_basis_func)
+        self.decoder = BiDecoder(in_units=args.gcn_out_units,
+                                 num_classes=len(args.rating_vals),
+                                 num_basis=args.gen_r_num_basis_func)
         self.decoder.to(dev_id)

     def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
         user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
-        head, tail = compact_g.edges(order='eid')
-        head_emb = user_out[head]
-        tail_emb = movie_out[tail]
-        pred_ratings = self.decoder(head_emb, tail_emb)
+        pred_ratings = self.decoder(compact_g, user_out, movie_out)
         return pred_ratings

 def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
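The Net.forward change above moves the endpoint gathering out of the training code and into the decoder, which reads the (user, movie) pairs straight off the graph. A toy check (hypothetical 2x2 bipartite graph, identity projection) showing that the manual-gather route and the apply_edges route agree on per-edge dot products:

    import torch as th
    import dgl
    import dgl.function as fn

    g = dgl.heterograph({('user', 'rate', 'movie'): ([0, 0, 1], [0, 1, 1])})
    uf = th.randn(2, 4)   # user embeddings
    mf = th.randn(2, 4)   # movie embeddings

    # old route: index endpoint embeddings by edge id, then dot
    head, tail = g.edges(order='eid')
    manual = (uf[head] * mf[tail]).sum(dim=1, keepdim=True)

    # new route: let DGL compute the same dot product on each edge
    g.nodes['user'].data['h'] = uf
    g.nodes['movie'].data['h'] = mf
    g.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
    assert th.allclose(manual, g.edata['sr'])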
@@ -289,48 +285,42 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if epoch > 1:
             t0 = time.time()
         net.train()
-        for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
-            head_feat, tail_feat, blocks = load_subtensor(
-                input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
-            frontier = blocks[0]
-            compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
-            true_relation_labels = compact_g.edata['label']
-            true_relation_ratings = compact_g.edata['rating']
-            head_feat = head_feat.to(dev_id)
-            tail_feat = tail_feat.to(dev_id)
-            frontier = frontier.to(dev_id)
-            pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
-            loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
-            count_loss += loss.item()
-            optimizer.zero_grad()
-            loss.backward()
-            nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
-            optimizer.step()
-            if proc_id == 0 and iter_idx == 1:
-                print("Total #Param of net: %d" % (torch_total_param_num(net)))
-            real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
-                                 nd_possible_rating_values.view(1, -1)).sum(dim=1)
-            rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
-            count_rmse += rmse.item()
-            count_num += pred_ratings.shape[0]
-            if iter_idx % args.train_log_interval == 0:
-                logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}".format(
-                    iter_idx, count_loss/iter_idx, count_rmse/count_num)
-                count_rmse = 0
-                count_num = 0
-            if iter_idx % args.train_log_interval == 0:
-                print("[{}] {}".format(proc_id, logging_str))
-            iter_idx += 1
-            if step == 20:
-                return
+        with tqdm.tqdm(dataloader) as tq:
+            for step, (input_nodes, pair_graph, blocks) in enumerate(tq):
+                head_feat, tail_feat, blocks = load_subtensor(
+                    input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
+                frontier = blocks[0]
+                compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
+                true_relation_labels = compact_g.edata['label']
+                true_relation_ratings = compact_g.edata['rating']
+                head_feat = head_feat.to(dev_id)
+                tail_feat = tail_feat.to(dev_id)
+                frontier = frontier.to(dev_id)
+                pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
+                loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
+                count_loss += loss.item()
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
+                optimizer.step()
+                if proc_id == 0 and iter_idx == 1:
+                    print("Total #Param of net: %d" % (torch_total_param_num(net)))
+                real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
+                                     nd_possible_rating_values.view(1, -1)).sum(dim=1)
+                rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
+                count_rmse += rmse.item()
+                count_num += pred_ratings.shape[0]
+                tq.set_postfix({'loss': '{:.4f}'.format(count_loss / iter_idx),
+                                'rmse': '{:.4f}'.format(count_rmse / count_num)},
+                               refresh=False)
+                iter_idx += 1

         if epoch > 1:
             epoch_time = time.time() - t0
             print("Epoch {} time {}".format(epoch, epoch_time))
......
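The replacement loop also swaps the manual interval logging (and the stray "if step == 20: return" debugging exit) for a tqdm progress bar. The pattern is plain tqdm usage, illustrated standalone below with a dummy iterable; refresh=False queues the postfix update instead of forcing an immediate redraw on every step:

    import tqdm

    running_loss = 0.0
    with tqdm.tqdm(range(1, 101)) as tq:
        for step, batch in enumerate(tq, 1):
            running_loss += float(batch)   # stand-in for loss.item()
            tq.set_postfix({'loss': '{:.4f}'.format(running_loss / step)},
                           refresh=False)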