Unverified commit a6b44e72, authored by Quan (Andy) Gan, committed by GitHub

[Model] Fix GCMC broken code (#2001)

* [Model] Remove leftover GCMC debugging code

* Use the BiDecoder implementation based on apply_edges
parent 4efa3320
@@ -304,7 +304,9 @@ class BiDecoder(nn.Module):
         super(BiDecoder, self).__init__()
         self._num_basis = num_basis
         self.dropout = nn.Dropout(dropout_rate)
-        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.Ps = nn.ParameterList(
+            nn.Parameter(th.randn(in_units, in_units))
+            for _ in range(num_basis))
         self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
         self.reset_parameters()
@@ -343,7 +345,7 @@ class BiDecoder(nn.Module):
         out = self.combine_basis(out)
         return out

-class DenseBiDecoder(BiDecoder):
+class DenseBiDecoder(nn.Module):
     r"""Dense bi-linear decoder.

     Dense implementation of the bi-linear decoder used in GCMC. Suitable when
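For context, the hunks above only show the `self.Ps` change and the tail of `BiDecoder.forward`. A minimal sketch of how a per-edge bilinear decoder is typically written with DGL's `apply_edges`, consistent with the `nn.ParameterList` introduced above; the node-type names `'user'`/`'movie'` and the feature keys are assumptions for illustration, not part of this commit:

```python
import torch as th
import torch.nn as nn
import dgl.function as fn

class BiDecoderSketch(nn.Module):
    """Per-edge bilinear decoder: score_k(u, v) = u^T P_k v for each basis k."""
    def __init__(self, in_units, num_classes, num_basis=2, dropout_rate=0.0):
        super().__init__()
        self._num_basis = num_basis
        self.dropout = nn.Dropout(dropout_rate)
        self.Ps = nn.ParameterList(
            nn.Parameter(th.randn(in_units, in_units))
            for _ in range(num_basis))
        self.combine_basis = nn.Linear(num_basis, num_classes, bias=False)

    def forward(self, graph, ufeat, ifeat):
        # graph holds one edge per (user, movie) pair that needs a score.
        with graph.local_scope():
            ufeat = self.dropout(ufeat)
            ifeat = self.dropout(ifeat)
            graph.nodes['movie'].data['h'] = ifeat
            basis_out = []
            for i in range(self._num_basis):
                # Project user features through basis i, then take the
                # per-edge dot product with the movie features.
                graph.nodes['user'].data['h'] = ufeat @ self.Ps[i]
                graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
                basis_out.append(graph.edata['sr'])
            out = th.cat(basis_out, dim=1)    # (num_edges, num_basis)
            return self.combine_basis(out)    # (num_edges, num_classes)
```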
@@ -366,10 +368,17 @@ class DenseBiDecoder(BiDecoder):
                  num_classes,
                  num_basis=2,
                  dropout_rate=0.0):
-        super(DenseBiDecoder, self).__init__(in_units,
-                                             num_classes,
-                                             num_basis,
-                                             dropout_rate)
+        super().__init__()
+        self._num_basis = num_basis
+        self.dropout = nn.Dropout(dropout_rate)
+        self.P = nn.Parameter(th.randn(num_basis, in_units, in_units))
+        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)

     def forward(self, ufeat, ifeat):
         """Forward function.
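The diff truncates `DenseBiDecoder.forward`. Since the class now owns the stacked parameter `self.P` of shape `(num_basis, in_units, in_units)`, its forward pass is presumably a batched bilinear form over aligned user/item rows; a hedged sketch of one way to write it with `einsum` (not shown in this commit):

```python
def forward(self, ufeat, ifeat):
    """Score every (user, item) pair given as aligned rows of ufeat/ifeat."""
    ufeat = self.dropout(ufeat)
    ifeat = self.dropout(ifeat)
    # For each pair index a and basis b: ufeat[a] @ P[b] @ ifeat[a]
    out = th.einsum('ai,bij,aj->ab', ufeat, self.P, ifeat)
    return self.combine_basis(out)   # (num_pairs, num_classes)
```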
@@ -11,6 +11,7 @@ import random
 import string
 import traceback
 import numpy as np
+import tqdm
 import torch as th
 import torch.nn as nn
 import torch.multiprocessing as mp
@@ -20,7 +21,7 @@ from torch.nn.parallel import DistributedDataParallel
 from _thread import start_new_thread
 from functools import wraps
 from data import MovieLens
-from model import GCMCLayer, DenseBiDecoder
+from model import GCMCLayer, DenseBiDecoder, BiDecoder
 from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger, to_etype_name
 import dgl
@@ -45,19 +46,14 @@ class Net(nn.Module):
         else:
             self.encoder.to(dev_id)

-        self.decoder = DenseBiDecoder(in_units=args.gcn_out_units,
-                                      num_classes=len(args.rating_vals),
-                                      num_basis=args.gen_r_num_basis_func)
+        self.decoder = BiDecoder(in_units=args.gcn_out_units,
+                                 num_classes=len(args.rating_vals),
+                                 num_basis=args.gen_r_num_basis_func)
         self.decoder.to(dev_id)

     def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
         user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
-
-        head, tail = compact_g.edges(order='eid')
-        head_emb = user_out[head]
-        tail_emb = movie_out[tail]
-        pred_ratings = self.decoder(head_emb, tail_emb)
+        pred_ratings = self.decoder(compact_g, user_out, movie_out)
         return pred_ratings

 def load_subtensor(input_nodes, pair_graph, blocks, dataset, parent_graph):
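The removed lines gathered endpoint embeddings by hand before scoring. A graph-based decoder makes that gather redundant, since `apply_edges` visits each edge's endpoints itself. Roughly (a simplified sketch with the basis projection omitted; `fn.u_dot_v` as in the decoder sketch above):

```python
# Manual gather (what Net.forward used to do):
head, tail = compact_g.edges(order='eid')          # endpoints, in edge-ID order
score = (user_out[head] * movie_out[tail]).sum(1)  # per-edge dot product
# Graph-based (what BiDecoder does internally, once per basis matrix):
#   graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
# computes the same per-edge dot product without materializing the
# head_emb/tail_emb tensors in the training script.
```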
@@ -289,48 +285,42 @@ def run(proc_id, n_gpus, args, devices, dataset):
         if epoch > 1:
             t0 = time.time()
         net.train()
-        for step, (input_nodes, pair_graph, blocks) in enumerate(dataloader):
-            head_feat, tail_feat, blocks = load_subtensor(
-                input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
-            frontier = blocks[0]
-            compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
-            true_relation_labels = compact_g.edata['label']
-            true_relation_ratings = compact_g.edata['rating']
-
-            head_feat = head_feat.to(dev_id)
-            tail_feat = tail_feat.to(dev_id)
-            frontier = frontier.to(dev_id)
-
-            pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
-            loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
-            count_loss += loss.item()
-            optimizer.zero_grad()
-            loss.backward()
-            nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
-            optimizer.step()
-
-            if proc_id == 0 and iter_idx == 1:
-                print("Total #Param of net: %d" % (torch_total_param_num(net)))
-
-            real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
-                                 nd_possible_rating_values.view(1, -1)).sum(dim=1)
-            rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
-            count_rmse += rmse.item()
-            count_num += pred_ratings.shape[0]
-
-            if iter_idx % args.train_log_interval == 0:
-                logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}".format(
-                    iter_idx, count_loss/iter_idx, count_rmse/count_num)
-                count_rmse = 0
-                count_num = 0
-
-            iter_idx += 1
-            if iter_idx % args.train_log_interval == 0:
-                print("[{}] {}".format(proc_id, logging_str))
-            iter_idx += 1
-            if step == 20:
-                return
+        with tqdm.tqdm(dataloader) as tq:
+            for step, (input_nodes, pair_graph, blocks) in enumerate(tq):
+                head_feat, tail_feat, blocks = load_subtensor(
+                    input_nodes, pair_graph, blocks, dataset, dataset.train_enc_graph)
+                frontier = blocks[0]
+                compact_g = flatten_etypes(pair_graph, dataset, 'train').to(dev_id)
+                true_relation_labels = compact_g.edata['label']
+                true_relation_ratings = compact_g.edata['rating']
+
+                head_feat = head_feat.to(dev_id)
+                tail_feat = tail_feat.to(dev_id)
+                frontier = frontier.to(dev_id)
+
+                pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
+                loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
+                count_loss += loss.item()
+                optimizer.zero_grad()
+                loss.backward()
+                nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
+                optimizer.step()
+
+                if proc_id == 0 and iter_idx == 1:
+                    print("Total #Param of net: %d" % (torch_total_param_num(net)))
+
+                real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
+                                     nd_possible_rating_values.view(1, -1)).sum(dim=1)
+                rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
+                count_rmse += rmse.item()
+                count_num += pred_ratings.shape[0]
+
+                tq.set_postfix({'loss': '{:.4f}'.format(count_loss / iter_idx),
+                                'rmse': '{:.4f}'.format(count_rmse / count_num)},
+                               refresh=False)
+
+                iter_idx += 1

         if epoch > 1:
             epoch_time = time.time() - t0
             print("Epoch {} time {}".format(epoch, epoch_time))
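One step in the loop deserves a note: `real_pred_ratings` turns the per-class logits into a scalar prediction by taking the expectation of the possible rating values under the softmax distribution. A self-contained example of that computation, with toy numbers:

```python
import torch as th

logits = th.tensor([[0.1, 2.0, 0.3]])       # decoder output for one edge
rating_values = th.tensor([1.0, 3.0, 5.0])  # e.g. possible MovieLens ratings

probs = th.softmax(logits, dim=1)                          # (1, 3)
expected = (probs * rating_values.view(1, -1)).sum(dim=1)  # (1,)
print(expected)  # weighted average of 1/3/5, dominated by the middle class
```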