Unverified Commit 28117cd9 authored by xiang song (charlie.song), committed by GitHub

[Example] GCMC with sampling (#1296)



* gcmc example

* Update Readme

* Add multiprocess support

* Fix

* Multigpu + dataloader

* Delete some dead code

* Delete more

* upd

* Add README

* upd

* Upd

* combine full batch and sample GCMCLayer, use HeteroCov

* Fix

* Update Readme

* udp

* Fix typo

* Add cpu run

* some fix and docstring
Co-authored-by: Ubuntu <ubuntu@ip-172-31-51-214.ec2.internal>
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-63-71.ec2.internal>
parent bd1e48a5
......@@ -11,41 +11,217 @@ Credit: Jiani Zhang ([@jennyzhang0215](https://github.com/jennyzhang0215))
## Dependencies
* PyTorch 1.2+
* pandas
* torchtext 0.4+ (if using user and item contents as node features)
## Data
Supported datasets: ml-100k, ml-1m, ml-10m
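For reference, all three datasets are loaded through the `MovieLens` class in `data.py` (shown in the diff below). The following is a minimal sketch of how the training scripts construct it; the argument values are illustrative, not required:
```python
# Sketch: loading ml-100k with the MovieLens wrapper from data.py.
# The data set is downloaded automatically on first use.
import torch as th
from data import MovieLens

dataset = MovieLens('ml-100k',
                    th.device('cpu'),        # device holding the full graph
                    mix_cpu_gpu=False,       # True keeps node features in CPU memory
                    use_one_hot_fea=True,    # None features -> one-hot identity
                    symm=True,               # symmetric degree normalization
                    test_ratio=0.1,
                    valid_ratio=0.1)
print(dataset.user_feature_shape, dataset.movie_feature_shape)
print(dataset.possible_rating_values)
```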
## How to run
### Train with full-graph
ml-100k, no feature
```bash
python3 train.py --data_name=ml-100k --use_one_hot_fea --gcn_agg_accum=stack
```
Results: RMSE=0.9088 (0.910 reported)
Speed: 0.0410s/epoch (vanilla implementation: 0.1008s/epoch)
ml-100k, with feature
```bash
python3 train.py --data_name=ml-100k --gcn_agg_accum=stack
```
Results: RMSE=0.9448 (0.905 reported)
ml-1m, no feature
```bash
python3 train.py --data_name=ml-1m --gcn_agg_accum=sum --use_one_hot_fea
```
Results: RMSE=0.8377 (0.832 reported)
Speed: 0.0844s/epoch (vanilla implementation: 1.538s/epoch)
ml-10m, no feature
```bash
python3 train.py --data_name=ml-10m --gcn_agg_accum=stack --gcn_dropout=0.3 \
--train_lr=0.001 --train_min_lr=0.0001 --train_max_iter=15000 \
--use_one_hot_fea --gen_r_num_basis_func=4
```
Results: RMSE=0.7800 (0.777 reported)
Speed: 1.1982s/epoch (vanilla implementation: OOM)
Testbed: EC2 p3.2xlarge instance (Amazon Linux 2)
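The reported RMSE is computed from the decoder's per-class logits: the predicted rating is the expectation of the possible rating values under the softmax distribution (see the `evaluate` function in `train_sampling.py` further down). A minimal sketch with illustrative shapes:
```python
# Sketch: converting per-class rating logits into a real-valued prediction
# and an RMSE, mirroring evaluate() in this commit. Values are random.
import torch as th

possible_rating_values = th.tensor([1., 2., 3., 4., 5.])  # ml-100k ratings
logits = th.randn(8, 5)                                   # (batch, num rating classes)
true_ratings = th.randint(1, 6, (8,)).float()

pred_ratings = (th.softmax(logits, dim=1) *
                possible_rating_values.view(1, -1)).sum(dim=1)
rmse = ((pred_ratings - true_ratings) ** 2).mean().sqrt()
print(rmse.item())
```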
### Train with minibatch on a single GPU
ml-100k, no feature
```bash
python3 train_sampling.py --data_name=ml-100k \
--use_one_hot_fea \
--gcn_agg_accum=stack \
--gpu 0
```
ml-100k, no feature, mix_cpu_gpu run. For a mix_cpu_gpu run with no feature, W_r is stored on the CPU by default rather than on the GPU.
```bash
python3 train_sampling.py --data_name=ml-100k \
--use_one_hot_fea \
--gcn_agg_accum=stack \
--mix_cpu_gpu \
--gpu 0
```
Results: RMSE=0.9380
Speed: 1.059s/epoch (run with 70 epochs)
Speed: 1.046s/epoch (mix_cpu_gpu)
ml-100k, with feature
```bash
python3 train_sampling.py --data_name=ml-100k \
--gcn_agg_accum=stack \
--train_max_epoch 90 \
--gpu 0
```
Results: RMSE=0.9574
ml-1m, no feature
```bash
python3 train_sampling.py --data_name=ml-1m \
--gcn_agg_accum=sum \
--use_one_hot_fea \
--train_max_epoch 160 \
--gpu 0
```
ml-1m, no feature with mix_cpu_gpu run
```bash
python3 train_sampling.py --data_name=ml-1m \
--gcn_agg_accum=sum \
--use_one_hot_fea \
--train_max_epoch 60 \
--mix_cpu_gpu \
--gpu 0
```
Results: RMSE=0.8632
Speed: 7.852s/epoch (run with 60 epochs)
Speed: 7.788s/epoch (mix_cpu_gpu)
ml-10m, no feature
```bash
python3 train_sampling.py --data_name=ml-10m \
--gcn_agg_accum=stack \
--gcn_dropout=0.3 \
--train_lr=0.001 \
--train_min_lr=0.0001 \
--train_max_epoch=60 \
--use_one_hot_fea \
--gen_r_num_basis_func=4 \
--gpu 0
```
ml-10m, no feature with mix_cpu_gpu run
```bash
python3 train_sampling.py --data_name=ml-10m \
--gcn_agg_accum=stack \
--gcn_dropout=0.3 \
--train_lr=0.001 \
--train_min_lr=0.0001 \
--train_max_epoch=60 \
--use_one_hot_fea \
--gen_r_num_basis_func=4 \
--mix_cpu_gpu \
--gpu 0
```
Results: RMSE=0.8050
Speed: 394.304s/epoch (run with 60 epochs)
Speed: 408.749s/epoch (mix_cpu_gpu)
Testbed: EC2 p3.2xlarge instance
### Train with minibatch on multi-GPU
ml-100k, no feature
```bash
python train_sampling.py --data_name=ml-100k \
--gcn_agg_accum=stack \
--train_max_epoch 30 \
--train_lr 0.02 \
--use_one_hot_fea \
--gpu 0,1,2,3,4,5,6,7
```
ml-100k, no feature with mix_cpu_gpu run
```bash
python train_sampling.py --data_name=ml-100k \
--gcn_agg_accum=stack \
--train_max_epoch 30 \
--train_lr 0.02 \
--use_one_hot_fea \
--mix_cpu_gpu \
--gpu 0,1,2,3,4,5,6,7
```
Result: RMSE=0.9397
Speed: 1.202s/epoch (run with only 30 epochs)
Speed: 1.245s/epoch (mix_cpu_gpu)
ml-100k, with feature
```bash
python train_sampling.py --data_name=ml-100k \
--gcn_agg_accum=stack \
--train_max_epoch 30 \
--gpu 0,1,2,3,4,5,6,7
```
Result: RMSE=0.9655
Speed: 1.265s/epoch (run with 30 epochs)
ml-1m, no feature
```bash
python train_sampling.py --data_name=ml-1m \
--gcn_agg_accum=sum \
--train_max_epoch 40 \
--use_one_hot_fea \
--gpu 0,1,2,3,4,5,6,7
```
ml-1m, no feature with mix_cpu_gpu run
```bash
python train_sampling.py --data_name=ml-1m \
--gcn_agg_accum=sum \
--train_max_epoch 40 \
--use_one_hot_fea \
--mix_cpu_gpu \
--gpu 0,1,2,3,4,5,6,7
```
Results: RMSE=0.8621
Speed: 11.612s/epoch (run with 40 epochs)
Speed: 12.483s/epoch (mix_cpu_gpu)
ml-10m, no feature
```bash
python train_sampling.py --data_name=ml-10m \
--gcn_agg_accum=stack \
--gcn_dropout=0.3 \
--train_lr=0.001 \
--train_min_lr=0.0001 \
--train_max_epoch=30 \
--use_one_hot_fea \
--gen_r_num_basis_func=4 \
--gpu 0,1,2,3,4,5,6,7
```
ml-10m, no feature with mix_cpu_gpu run
```bash
python train_sampling.py --data_name=ml-10m \
--gcn_agg_accum=stack \
--gcn_dropout=0.3 \
--train_lr=0.001 \
--train_min_lr=0.0001 \
--train_max_epoch=30 \
--use_one_hot_fea \
--gen_r_num_basis_func=4 \
--mix_cpu_gpu \
--gpu 0,1,2,3,4,5,6,7
```
Results: RMSE=0.8084
Speed: 632.868s/epoch (run with 30 epochs)
Speed: 633.397s/epoch (mix_cpu_gpu)
Testbed: EC2 p3.16xlarge instance
### Train with minibatch on CPU
ml-100k, no feature
```bash
python3 train_sampling.py --data_name=ml-100k \
--use_one_hot_fea \
--gcn_agg_accum=stack \
--gpu -1
```
Speed: 1.591s/epoch
Testbed: EC2 r5.xlarge instance
......@@ -5,8 +5,6 @@ import re
import pandas as pd
import scipy.sparse as sp
import torch as th
from torchtext import data
from torchtext.vocab import GloVe
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir
......@@ -84,6 +82,8 @@ class MovieLens(object):
Dataset name. Could be "ml-100k", "ml-1m", "ml-10m"
device : torch.device
Device context
    mix_cpu_gpu : bool, optional
        If true, the ``user_feature`` attribute is stored in CPU memory.
use_one_hot_fea : bool, optional
        If true, the ``user_feature`` attribute is None, representing a one-hot identity
matrix. (Default: False)
......@@ -96,7 +96,8 @@ class MovieLens(object):
Ratio of validation data
"""
    def __init__(self, name, device, mix_cpu_gpu=False,
                 use_one_hot_fea=False, symm=True,
test_ratio=0.1, valid_ratio=0.1):
self._name = name
self._device = device
......@@ -164,8 +165,13 @@ class MovieLens(object):
self.user_feature = None
self.movie_feature = None
else:
# if mix_cpu_gpu, we put features in CPU
if mix_cpu_gpu:
self.user_feature = th.FloatTensor(self._process_user_fea())
self.movie_feature = th.FloatTensor(self._process_movie_fea())
else:
self.user_feature = th.FloatTensor(self._process_user_fea()).to(self._device)
self.movie_feature = th.FloatTensor(self._process_movie_fea()).to(self._device)
if self.user_feature is None:
self.user_feature_shape = (self.num_user, self.num_user)
self.movie_feature_shape = (self.num_movie, self.num_movie)
......@@ -204,6 +210,7 @@ class MovieLens(object):
def _npairs(graph):
rst = 0
for r in self.possible_rating_values:
r = str(r).replace('.', '_')
rst += graph.number_of_edges(str(r))
return rst
......@@ -245,9 +252,10 @@ class MovieLens(object):
ridx = np.where(rating_values == rating)
rrow = rating_row[ridx]
rcol = rating_col[ridx]
            rating = str(rating).replace('.', '_')
            bg = dgl.bipartite((rrow, rcol), 'user', rating, 'movie',
num_nodes=(self._num_user, self._num_movie))
            rev_bg = dgl.bipartite((rcol, rrow), 'movie', 'rev-%s' % rating, 'user',
num_nodes=(self._num_movie, self._num_user))
rating_graphs.append(bg)
rating_graphs.append(rev_bg)
......@@ -267,7 +275,7 @@ class MovieLens(object):
movie_ci = []
movie_cj = []
for r in self.possible_rating_values:
            r = str(r).replace('.', '_')
user_ci.append(graph['rev-%s' % r].in_degrees())
movie_ci.append(graph[r].in_degrees())
if self._symm:
......@@ -494,6 +502,8 @@ class MovieLens(object):
Generate movie features by concatenating embedding and the year
"""
import torchtext
if self._name == 'ml-100k':
GENRES = GENRES_ML_100K
elif self._name == 'ml-1m':
......@@ -503,8 +513,8 @@ class MovieLens(object):
else:
raise NotImplementedError
        TEXT = torchtext.data.Field(tokenize='spacy')
        embedding = torchtext.vocab.GloVe(name='840B', dim=300)
title_embedding = np.zeros(shape=(self.movie_info.shape[0], 300), dtype=np.float32)
release_years = np.zeros(shape=(self.movie_info.shape[0], 1), dtype=np.float32)
......
"""NN modules"""
import torch as th
import torch.nn as nn
from torch.nn import init
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl import DGLError
from utils import get_activation
class GCMCGraphConv(nn.Module):
"""Graph convolution module used in the GCMC model.
Parameters
----------
in_feats : int
Input feature size.
out_feats : int
Output feature size.
weight : bool, optional
        If True, apply a linear layer. Otherwise, aggregate the messages
        without a weight matrix, or with a shared weight provided by the caller.
device: str, optional
Which device to put data in. Useful in mix_cpu_gpu training and
multi-gpu training
"""
def __init__(self,
in_feats,
out_feats,
weight=True,
device=None,
dropout_rate=0.0):
super(GCMCGraphConv, self).__init__()
self._in_feats = in_feats
self._out_feats = out_feats
self.device = device
self.dropout = nn.Dropout(dropout_rate)
if weight:
self.weight = nn.Parameter(th.Tensor(in_feats, out_feats))
else:
self.register_parameter('weight', None)
self.reset_parameters()
def reset_parameters(self):
"""Reinitialize learnable parameters."""
if self.weight is not None:
init.xavier_uniform_(self.weight)
def forward(self, graph, feat, weight=None):
"""Compute graph convolution.
Normalizer constant :math:`c_{ij}` is stored as two node data "ci"
and "cj".
Parameters
----------
graph : DGLGraph
The graph.
feat : torch.Tensor
The input feature
weight : torch.Tensor, optional
Optional external weight tensor.
Returns
-------
torch.Tensor
The output feature
"""
with graph.local_scope():
cj = graph.srcdata['cj']
ci = graph.dstdata['ci']
if self.device is not None:
cj = cj.to(self.device)
ci = ci.to(self.device)
if weight is not None:
if self.weight is not None:
raise DGLError('External weight is provided while at the same time the'
' module has defined its own weight parameter. Please'
' create the module with flag weight=False.')
else:
weight = self.weight
if weight is not None:
feat = dot_or_identity(feat, weight, self.device)
feat = feat * self.dropout(cj)
graph.srcdata['h'] = feat
graph.update_all(fn.copy_src(src='h', out='m'),
fn.sum(msg='m', out='h'))
rst = graph.dstdata['h']
rst = rst * ci
return rst
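For reference, the normalizers `ci` and `cj` read above are not computed by this module; they are precomputed per node when the dataset builds the encoder graph (see `_generate_enc_graph` in `data.py`). A rough sketch of the symmetric case, assuming non-zero degrees and a single degree vector summed over all rating relations:
```python
# Illustrative sketch of the degree-based normalizers consumed by GCMCGraphConv.
# The real computation lives in data.py and sums degrees over every rating etype.
import torch as th

deg = th.tensor([3., 1., 2., 5.])        # per-node degree over all rating relations
ci = th.pow(deg, -0.5).unsqueeze(1)      # right norm, multiplied into the aggregated result
cj = ci                                  # symmetric normalization: same factor on the source side
# with --gcn_agg_norm_symm disabled, the source-side factor cj is all ones instead
print(ci.shape)                          # (num_nodes, 1)
```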
class GCMCLayer(nn.Module):
r"""GCMC layer
......@@ -49,6 +138,9 @@ class GCMCLayer(nn.Module):
If true, user node and movie node share the same set of parameters.
        Require ``user_in_units`` and ``movie_in_units`` to be the same.
(Default: False)
device: str, optional
Which device to put data in. Useful in mix_cpu_gpu training and
multi-gpu training
"""
def __init__(self,
rating_vals,
......@@ -60,7 +152,8 @@ class GCMCLayer(nn.Module):
agg='stack', # or 'sum'
agg_act=None,
out_act=None,
                 share_user_item_param=False,
                 device=None):
super(GCMCLayer, self).__init__()
self.rating_vals = rating_vals
self.agg = agg
......@@ -77,19 +170,57 @@ class GCMCLayer(nn.Module):
msg_units = msg_units // len(rating_vals)
self.dropout = nn.Dropout(dropout_rate)
self.W_r = nn.ParameterDict()
subConv = {}
for rating in rating_vals:
# PyTorch parameter name can't contain "."
rating = str(rating).replace('.', '_')
rev_rating = 'rev-%s' % rating
if share_user_item_param and user_in_units == movie_in_units:
self.W_r[rating] = nn.Parameter(th.randn(user_in_units, msg_units))
self.W_r['rev-%s' % rating] = self.W_r[rating]
subConv[rating] = GCMCGraphConv(user_in_units,
msg_units,
weight=False,
device=device,
dropout_rate=dropout_rate)
subConv[rev_rating] = GCMCGraphConv(user_in_units,
msg_units,
weight=False,
device=device,
dropout_rate=dropout_rate)
else:
                self.W_r = None
subConv[rating] = GCMCGraphConv(user_in_units,
msg_units,
weight=True,
device=device,
dropout_rate=dropout_rate)
subConv[rev_rating] = GCMCGraphConv(movie_in_units,
msg_units,
weight=True,
device=device,
dropout_rate=dropout_rate)
self.conv = dglnn.HeteroGraphConv(subConv, aggregate=agg)
self.agg_act = get_activation(agg_act)
self.out_act = get_activation(out_act)
self.device = device
self.reset_parameters()
def partial_to(self, device):
        Put the parameters onto the given device, except W_r.
Parameters
----------
device : torch device
Which device the parameters are put in.
"""
assert device == self.device
if device is not None:
self.ufc.cuda(device)
if self.share_user_item_param is False:
self.ifc.cuda(device)
self.dropout.cuda(device)
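The sketch below shows how `train_sampling.py` uses `partial_to` for a mix_cpu_gpu run with one-hot features: the layer is built with the training GPU as its `device`, but the large per-rating projection weights stay in CPU memory while the small dense layers move to the GPU. Constructor arguments follow the positional order used by the training scripts; the sizes are the ml-100k one-hot dimensions and a GPU is required:
```python
# Minimal sketch (GPU required): a GCMCLayer configured the way train_sampling.py
# does for --mix_cpu_gpu --use_one_hot_fea, then moved partially to the GPU.
from model import GCMCLayer

dev_id = 0                            # GPU index, as in train_sampling.py
encoder = GCMCLayer([1, 2, 3, 4, 5],  # rating_vals
                    943, 1682,        # user / movie input units (one-hot sizes)
                    500, 75,          # aggregation / output units
                    0.7,              # dropout
                    'stack',          # aggregation accumulator
                    device=dev_id)
encoder.partial_to(dev_id)            # projection weights stay in CPU memory
```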
def reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
......@@ -98,9 +229,6 @@ class GCMCLayer(nn.Module):
def forward(self, graph, ufeat=None, ifeat=None):
"""Forward function
Normalizer constant :math:`c_{ij}` is stored as two node data "ci"
and "cj".
Parameters
----------
graph : DGLHeteroGraph
......@@ -118,28 +246,19 @@ class GCMCLayer(nn.Module):
new_ifeat : torch.Tensor
New movie features
"""
        in_feats = {'user' : ufeat, 'movie' : ifeat}
        mod_args = {}
        for i, rating in enumerate(self.rating_vals):
            rating = str(rating).replace('.', '_')
            rev_rating = 'rev-%s' % rating
            mod_args[rating] = (self.W_r[rating] if self.W_r is not None else None,)
            mod_args[rev_rating] = (self.W_r[rev_rating] if self.W_r is not None else None,)
out_feats = self.conv(graph, in_feats, mod_args=mod_args)
ufeat = out_feats['user']
ifeat = out_feats['movie']
ufeat = ufeat.view(ufeat.shape[0], -1)
ifeat = ifeat.view(ifeat.shape[0], -1)
# fc and non-linear
ufeat = self.agg_act(ufeat)
ifeat = self.agg_act(ifeat)
......@@ -150,7 +269,10 @@ class GCMCLayer(nn.Module):
return self.out_act(ufeat), self.out_act(ifeat)
class BiDecoder(nn.Module):
r"""Bilinear decoder.
r"""Bi-linear decoder.
Given a bipartite graph G, for each edge (i, j) ~ G, compute the likelihood
of it being class r by:
.. math::
p(M_{ij}=r) = \text{softmax}(u_i^TQ_rv_j)
......@@ -163,28 +285,27 @@ class BiDecoder(nn.Module):
Parameters
----------
in_units : int
Size of input user and movie features
num_classes : int
Number of classes.
num_basis : int, optional
        Number of basis functions. (Default: 2)
dropout_rate : float, optional
        Dropout rate (Default: 0.0)
"""
def __init__(self,
                 in_units,
                 num_classes,
                 num_basis=2,
dropout_rate=0.0):
super(BiDecoder, self).__init__()
        self._num_basis = num_basis
self.dropout = nn.Dropout(dropout_rate)
self.Ps = nn.ParameterList()
        for i in range(num_basis):
self.Ps.append(nn.Parameter(th.randn(in_units, in_units)))
        self.combine_basis = nn.Linear(self._num_basis, num_classes, bias=False)
self.reset_parameters()
def reset_parameters(self):
......@@ -209,22 +330,83 @@ class BiDecoder(nn.Module):
th.Tensor
Predicting scores for each user-movie edge.
"""
        with graph.local_scope():
ufeat = self.dropout(ufeat)
ifeat = self.dropout(ifeat)
graph.nodes['movie'].data['h'] = ifeat
basis_out = []
for i in range(self._num_basis):
graph.nodes['user'].data['h'] = ufeat @ self.Ps[i]
graph.apply_edges(fn.u_dot_v('h', 'h', 'sr'))
basis_out.append(graph.edata['sr'].unsqueeze(1))
out = th.cat(basis_out, dim=1)
out = self.combine_basis(out)
return out
class DenseBiDecoder(BiDecoder):
r"""Dense bi-linear decoder.
Dense implementation of the bi-linear decoder used in GCMC. Suitable when
the graph can be efficiently represented by a pair of arrays (one for source
nodes; one for destination nodes).
Parameters
----------
in_units : int
Size of input user and movie features
num_classes : int
Number of classes.
num_basis : int, optional
        Number of basis functions. (Default: 2)
dropout_rate : float, optional
        Dropout rate (Default: 0.0)
"""
def __init__(self,
in_units,
num_classes,
num_basis=2,
dropout_rate=0.0):
super(DenseBiDecoder, self).__init__(in_units,
num_classes,
num_basis,
dropout_rate)
def forward(self, ufeat, ifeat):
"""Forward function.
Compute logits for each pair ``(ufeat[i], ifeat[i])``.
Parameters
----------
ufeat : th.Tensor
User embeddings. Shape: (B, D)
ifeat : th.Tensor
Movie embeddings. Shape: (B, D)
Returns
-------
th.Tensor
Predicting scores for each user-movie edge. Shape: (B, num_classes)
"""
ufeat = self.dropout(ufeat)
ifeat = self.dropout(ifeat)
basis_out = []
        for i in range(self._num_basis):
ufeat_i = ufeat @ self.Ps[i]
out = th.einsum('ab,ab->a', ufeat_i, ifeat)
basis_out.append(out.unsqueeze(1))
out = th.cat(basis_out, dim=1)
        out = self.combine_basis(out)
return out
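A usage sketch for the dense decoder, with illustrative shapes: row `i` of `ufeat` is paired with row `i` of `ifeat`, and the output carries one logit per rating class.
```python
# Illustrative usage of DenseBiDecoder: B user/movie embedding pairs in,
# B x num_classes rating logits out.
import torch as th
from model import DenseBiDecoder

decoder = DenseBiDecoder(in_units=75, num_classes=5, num_basis=2)
ufeat = th.randn(32, 75)   # user embeddings for 32 sampled edges
ifeat = th.randn(32, 75)   # movie embeddings for the same edges
logits = decoder(ufeat, ifeat)
print(logits.shape)        # torch.Size([32, 5])
```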
def dot_or_identity(A, B, device=None):
# if A is None, treat as identity matrix
if A is None:
return B
elif len(A.shape) == 1:
if device is None:
return B[A]
else:
return B[A].to(device)
else:
return A @ B
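An illustrative sketch of the three cases handled by `dot_or_identity` (the shapes are arbitrary): a 1-D tensor of node IDs selects rows of `B` (the one-hot feature case), a 2-D feature matrix is projected by `B`, and `None` returns `B` unchanged.
```python
# Sketch of dot_or_identity's behaviour; B plays the role of a projection weight.
import torch as th
from model import dot_or_identity

B = th.randn(10, 4)                           # e.g. a per-relation weight matrix
node_ids = th.tensor([0, 3, 7])               # one-hot case: "features" are node IDs
dense_feat = th.randn(3, 10)                  # real node features

print(dot_or_identity(node_ids, B).shape)     # torch.Size([3, 4])  (row lookup)
print(dot_or_identity(dense_feat, B).shape)   # torch.Size([3, 4])  (matmul)
print(dot_or_identity(None, B) is B)          # True                (identity)
```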
"""Training script"""
"""Training GCMC model on the MovieLens data set.
The script loads the full graph to the training device.
"""
import os, time
import argparse
import logging
......@@ -8,7 +11,7 @@ import numpy as np
import torch as th
import torch.nn as nn
from data import MovieLens
from model import BiDecoder, GCMCLayer
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger
class Net(nn.Module):
......@@ -23,10 +26,11 @@ class Net(nn.Module):
args.gcn_dropout,
args.gcn_agg_accum,
agg_act=self._act,
share_user_item_param=args.share_param,
device=args.device)
self.decoder = BiDecoder(in_units=args.gcn_out_units,
num_classes=len(args.rating_vals),
num_basis=args.gen_r_num_basis_func)
def forward(self, enc_graph, dec_graph, ufeat, ifeat):
user_out, movie_out = self.encoder(
......
"""Training GCMC model on the MovieLens data set by mini-batch sampling.
The script loads the full graph in CPU and samples subgraphs for computing
gradients on the training device. The script also supports multi-GPU for
further acceleration.
"""
import os, time
import argparse
import logging
import random
import string
import traceback
import numpy as np
import torch as th
import torch.nn as nn
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
from torch.multiprocessing import Queue
from torch.nn.parallel import DistributedDataParallel
from _thread import start_new_thread
from functools import wraps
from data import MovieLens
from model import GCMCLayer, DenseBiDecoder
from utils import get_activation, get_optimizer, torch_total_param_num, torch_net_info, MetricLogger
import dgl
class GCMCSampler:
"""Neighbor sampler in GCMC mini-batch training."""
def __init__(self, dataset, segment='train'):
self.dataset = dataset
if segment == 'train':
self.truths = dataset.train_truths
self.labels = dataset.train_labels
self.enc_graph = dataset.train_enc_graph
self.dec_graph = dataset.train_dec_graph
elif segment == 'valid':
self.truths = dataset.valid_truths
self.labels = None
self.enc_graph = dataset.valid_enc_graph
self.dec_graph = dataset.valid_dec_graph
elif segment == 'test':
self.truths = dataset.test_truths
self.labels = None
self.enc_graph = dataset.test_enc_graph
self.dec_graph = dataset.test_dec_graph
else:
assert False, "Unknow dataset {}".format(segment)
def sample_blocks(self, seeds):
"""Sample subgraphs from the entire graph.
The input ``seeds`` represents the edges to compute prediction for. The sampling
algorithm works as follows:
1. Get the head and tail nodes of the provided seed edges.
2. For each head and tail node, extract the entire in-coming neighborhood.
3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
"""
dataset = self.dataset
enc_graph = self.enc_graph
dec_graph = self.dec_graph
edge_ids = th.stack(seeds)
# generate frontiers for user and item
possible_rating_values = dataset.possible_rating_values
true_relation_ratings = self.truths[edge_ids]
true_relation_labels = None if self.labels is None else self.labels[edge_ids]
# 1. Get the head and tail nodes from both the decoder and encoder graphs.
head_id, tail_id = dec_graph.find_edges(edge_ids)
utype, _, vtype = enc_graph.canonical_etypes[0]
subg = []
true_rel_ratings = []
true_rel_labels = []
for possible_rating_value in possible_rating_values:
idx_loc = (true_relation_ratings == possible_rating_value)
head = head_id[idx_loc]
tail = tail_id[idx_loc]
true_rel_ratings.append(true_relation_ratings[idx_loc])
if self.labels is not None:
true_rel_labels.append(true_relation_labels[idx_loc])
subg.append(dgl.bipartite((head, tail),
utype=utype,
etype=str(possible_rating_value),
vtype=vtype,
num_nodes=(enc_graph.number_of_nodes(utype),
enc_graph.number_of_nodes(vtype))))
        # Convert the encoder subgraph to a more compact one by removing nodes that are
        # not covered by the seed edges.
g = dgl.hetero_from_relations(subg)
g = dgl.compact_graphs(g)
# 2. For each head and tail node, extract the entire in-coming neighborhood.
seed_nodes = {}
for ntype in g.ntypes:
seed_nodes[ntype] = g.nodes[ntype].data[dgl.NID]
frontier = dgl.in_subgraph(enc_graph, seed_nodes)
frontier = dgl.to_block(frontier, seed_nodes)
# 3. Copy the node features/embeddings from the full graph to the sampled subgraphs.
frontier.dstnodes['user'].data['ci'] = \
enc_graph.nodes['user'].data['ci'][frontier.dstnodes['user'].data[dgl.NID]]
frontier.srcnodes['movie'].data['cj'] = \
enc_graph.nodes['movie'].data['cj'][frontier.srcnodes['movie'].data[dgl.NID]]
frontier.srcnodes['user'].data['cj'] = \
enc_graph.nodes['user'].data['cj'][frontier.srcnodes['user'].data[dgl.NID]]
frontier.dstnodes['movie'].data['ci'] = \
enc_graph.nodes['movie'].data['ci'][frontier.dstnodes['movie'].data[dgl.NID]]
# handle features
head_feat = frontier.srcnodes['user'].data[dgl.NID].long() \
if dataset.user_feature is None else \
dataset.user_feature[frontier.srcnodes['user'].data[dgl.NID]]
tail_feat = frontier.srcnodes['movie'].data[dgl.NID].long()\
if dataset.movie_feature is None else \
dataset.movie_feature[frontier.srcnodes['movie'].data[dgl.NID]]
true_rel_labels = None if self.labels is None else th.cat(true_rel_labels, dim=0)
true_rel_ratings = th.cat(true_rel_ratings, dim=0)
return (g, frontier, head_feat, tail_feat, true_rel_labels, true_rel_ratings)
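For reference, `sample_blocks` is used as the `collate_fn` of a standard PyTorch `DataLoader` over edge IDs, as done in `run()` below. A minimal single-process sketch (the script itself uses `--minibatch_size` and `--num_workers_per_gpu`, defaulting to 20000 and 8):
```python
# Sketch: wiring GCMCSampler into a DataLoader over training edge IDs.
# Assumes this script is importable as train_sampling and the data sits on CPU.
import torch as th
from torch.utils.data import DataLoader
from data import MovieLens
from train_sampling import GCMCSampler

dataset = MovieLens('ml-100k', 'cpu', use_one_hot_fea=True)
sampler = GCMCSampler(dataset, 'train')
seeds = th.arange(dataset.train_truths.shape[0])  # one seed per training edge
dataloader = DataLoader(dataset=seeds,
                        batch_size=20000,
                        collate_fn=sampler.sample_blocks,
                        shuffle=True,
                        num_workers=0)
for g, frontier, head_feat, tail_feat, labels, ratings in dataloader:
    break  # each item is ready to be fed into Net(), as in run()
```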
class Net(nn.Module):
def __init__(self, args, dev_id):
super(Net, self).__init__()
self._act = get_activation(args.model_activation)
self.encoder = GCMCLayer(args.rating_vals,
args.src_in_units,
args.dst_in_units,
args.gcn_agg_units,
args.gcn_out_units,
args.gcn_dropout,
args.gcn_agg_accum,
agg_act=self._act,
share_user_item_param=args.share_param,
device=dev_id)
        if args.mix_cpu_gpu and args.use_one_hot_fea:
            # if use_one_hot_fea, the user and movie features are None;
            # W can be extremely large, so with mix_cpu_gpu W is kept in CPU memory
self.encoder.partial_to(dev_id)
else:
self.encoder.to(dev_id)
self.decoder = DenseBiDecoder(in_units=args.gcn_out_units,
num_classes=len(args.rating_vals),
num_basis=args.gen_r_num_basis_func)
self.decoder.to(dev_id)
def forward(self, compact_g, frontier, ufeat, ifeat, possible_rating_values):
user_out, movie_out = self.encoder(frontier, ufeat, ifeat)
head_emb = []
tail_emb = []
for possible_rating_value in possible_rating_values:
head, tail = compact_g.all_edges(etype=str(possible_rating_value))
head_emb.append(user_out[head])
tail_emb.append(movie_out[tail])
head_emb = th.cat(head_emb, dim=0)
tail_emb = th.cat(tail_emb, dim=0)
pred_ratings = self.decoder(head_emb, tail_emb)
return pred_ratings
def evaluate(args, dev_id, net, dataset, dataloader, segment='valid'):
possible_rating_values = dataset.possible_rating_values
nd_possible_rating_values = th.FloatTensor(possible_rating_values).to(dev_id)
real_pred_ratings = []
true_rel_ratings = []
for sample_data in dataloader:
compact_g, frontier, head_feat, tail_feat, \
_, true_relation_ratings = sample_data
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
with th.no_grad():
pred_ratings = net(compact_g, frontier,
head_feat, tail_feat, possible_rating_values)
batch_pred_ratings = (th.softmax(pred_ratings, dim=1) *
nd_possible_rating_values.view(1, -1)).sum(dim=1)
real_pred_ratings.append(batch_pred_ratings)
true_rel_ratings.append(true_relation_ratings)
real_pred_ratings = th.cat(real_pred_ratings, dim=0)
true_rel_ratings = th.cat(true_rel_ratings, dim=0).to(dev_id)
rmse = ((real_pred_ratings - true_rel_ratings) ** 2.).mean().item()
rmse = np.sqrt(rmse)
return rmse
# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
def thread_wrapped_func(func):
"""
Wraps a process entry point to make it work with OpenMP.
"""
@wraps(func)
def decorated_function(*args, **kwargs):
queue = Queue()
def _queue_result():
exception, trace, res = None, None, None
try:
res = func(*args, **kwargs)
except Exception as e:
exception = e
trace = traceback.format_exc()
queue.put((res, exception, trace))
start_new_thread(_queue_result, ())
result, exception, trace = queue.get()
if exception is None:
return result
else:
assert isinstance(exception, Exception)
raise exception.__class__(trace)
return decorated_function
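The decorator is applied to the `run` function below, and each GPU then gets its own process, so the wrapped call executes in a fresh thread of the child process. A minimal illustrative stub of that pattern (assuming the decorator defined above is in scope):
```python
# Illustrative pattern only: decorate the per-process entry point, then spawn
# one process per worker, as the __main__ block of this script does for GPUs.
import torch.multiprocessing as mp

@thread_wrapped_func
def _run_stub(proc_id, n_procs):
    print('process %d of %d started' % (proc_id, n_procs))

if __name__ == '__main__':
    procs = [mp.Process(target=_run_stub, args=(i, 2)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```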
def prepare_mp(g):
"""
Explicitly materialize the CSR, CSC and COO representation of the given graph
so that they could be shared via copy-on-write to sampler workers and GPU
trainers.
This is a workaround before full shared memory support on heterogeneous graphs.
"""
for etype in g.canonical_etypes:
g.in_degree(0, etype=etype)
g.out_degree(0, etype=etype)
g.find_edges([0], etype=etype)
def config():
parser = argparse.ArgumentParser(description='GCMC')
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--save_dir', type=str, help='The saving directory')
parser.add_argument('--save_id', type=int, help='The saving log id')
parser.add_argument('--silent', action='store_true')
parser.add_argument('--data_name', default='ml-1m', type=str,
help='The dataset name: ml-100k, ml-1m, ml-10m')
    parser.add_argument('--data_test_ratio', type=float, default=0.1) ## for ml-100k the test ratio is 0.2
parser.add_argument('--data_valid_ratio', type=float, default=0.1)
parser.add_argument('--use_one_hot_fea', action='store_true', default=False)
parser.add_argument('--model_activation', type=str, default="leaky")
parser.add_argument('--gcn_dropout', type=float, default=0.7)
parser.add_argument('--gcn_agg_norm_symm', type=bool, default=True)
parser.add_argument('--gcn_agg_units', type=int, default=500)
parser.add_argument('--gcn_agg_accum', type=str, default="sum")
parser.add_argument('--gcn_out_units', type=int, default=75)
parser.add_argument('--gen_r_num_basis_func', type=int, default=2)
parser.add_argument('--train_max_epoch', type=int, default=1000)
parser.add_argument('--train_log_interval', type=int, default=1)
parser.add_argument('--train_valid_interval', type=int, default=1)
parser.add_argument('--train_optimizer', type=str, default="adam")
parser.add_argument('--train_grad_clip', type=float, default=1.0)
parser.add_argument('--train_lr', type=float, default=0.01)
parser.add_argument('--train_min_lr', type=float, default=0.0001)
parser.add_argument('--train_lr_decay_factor', type=float, default=0.5)
parser.add_argument('--train_decay_patience', type=int, default=25)
parser.add_argument('--train_early_stopping_patience', type=int, default=50)
parser.add_argument('--share_param', default=False, action='store_true')
parser.add_argument('--mix_cpu_gpu', default=False, action='store_true')
parser.add_argument('--minibatch_size', type=int, default=20000)
parser.add_argument('--num_workers_per_gpu', type=int, default=8)
args = parser.parse_args()
    ### configure save_dir to save all the info
if args.save_dir is None:
args.save_dir = args.data_name+"_" + ''.join(random.choices(string.ascii_uppercase + string.digits, k=2))
if args.save_id is None:
args.save_id = np.random.randint(20)
args.save_dir = os.path.join("log", args.save_dir)
if not os.path.isdir(args.save_dir):
os.makedirs(args.save_dir)
return args
@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, dataset):
dev_id = devices[proc_id]
train_labels = dataset.train_labels
train_truths = dataset.train_truths
num_edges = train_truths.shape[0]
sampler = GCMCSampler(dataset,
'train')
seeds = th.arange(num_edges)
dataloader = DataLoader(
dataset=seeds,
batch_size=args.minibatch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
if proc_id == 0:
valid_sampler = GCMCSampler(dataset,
'valid')
valid_seeds = th.arange(dataset.valid_truths.shape[0])
valid_dataloader = DataLoader(dataset=valid_seeds,
batch_size=args.minibatch_size,
collate_fn=valid_sampler.sample_blocks,
shuffle=False,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
test_sampler = GCMCSampler(dataset,
'test')
test_seeds = th.arange(dataset.test_truths.shape[0])
test_dataloader = DataLoader(dataset=test_seeds,
batch_size=args.minibatch_size,
collate_fn=test_sampler.sample_blocks,
shuffle=False,
pin_memory=True,
drop_last=False,
num_workers=args.num_workers_per_gpu)
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
if n_gpus > 0:
th.cuda.set_device(dev_id)
nd_possible_rating_values = \
th.FloatTensor(dataset.possible_rating_values)
nd_possible_rating_values = nd_possible_rating_values.to(dev_id)
net = Net(args=args, dev_id=dev_id)
net = net.to(dev_id)
if n_gpus > 1:
net = DistributedDataParallel(net, device_ids=[dev_id], output_device=dev_id)
rating_loss_net = nn.CrossEntropyLoss()
learning_rate = args.train_lr
optimizer = get_optimizer(args.train_optimizer)(net.parameters(), lr=learning_rate)
print("Loading network finished ...\n")
### declare the loss information
best_valid_rmse = np.inf
no_better_valid = 0
best_epoch = -1
count_rmse = 0
count_num = 0
count_loss = 0
print("Start training ...")
dur = []
iter_idx = 1
for epoch in range(1, args.train_max_epoch):
if epoch > 1:
t0 = time.time()
net.train()
for step, sample_data in enumerate(dataloader):
compact_g, frontier, head_feat, tail_feat, \
true_relation_labels, true_relation_ratings = sample_data
head_feat = head_feat.to(dev_id)
tail_feat = tail_feat.to(dev_id)
pred_ratings = net(compact_g, frontier, head_feat, tail_feat, dataset.possible_rating_values)
loss = rating_loss_net(pred_ratings, true_relation_labels.to(dev_id)).mean()
count_loss += loss.item()
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(net.parameters(), args.train_grad_clip)
optimizer.step()
if proc_id == 0 and iter_idx == 1:
print("Total #Param of net: %d" % (torch_total_param_num(net)))
real_pred_ratings = (th.softmax(pred_ratings, dim=1) *
nd_possible_rating_values.view(1, -1)).sum(dim=1)
rmse = ((real_pred_ratings - true_relation_ratings.to(dev_id)) ** 2).sum()
count_rmse += rmse.item()
count_num += pred_ratings.shape[0]
if iter_idx % args.train_log_interval == 0:
logging_str = "Iter={}, loss={:.4f}, rmse={:.4f}".format(
iter_idx, count_loss/iter_idx, count_rmse/count_num)
count_rmse = 0
count_num = 0
if iter_idx % args.train_log_interval == 0:
print("[{}] {}".format(proc_id, logging_str))
iter_idx += 1
if epoch > 1:
epoch_time = time.time() - t0
print("Epoch {} time {}".format(epoch, epoch_time))
if epoch % args.train_valid_interval == 0:
if n_gpus > 1:
th.distributed.barrier()
if proc_id == 0:
valid_rmse = evaluate(args=args,
dev_id=dev_id,
net=net,
dataset=dataset,
dataloader=valid_dataloader,
segment='valid')
logging_str += ',\tVal RMSE={:.4f}'.format(valid_rmse)
if valid_rmse < best_valid_rmse:
best_valid_rmse = valid_rmse
no_better_valid = 0
best_epoch = epoch
test_rmse = evaluate(args=args,
dev_id=dev_id,
net=net,
dataset=dataset,
dataloader=test_dataloader,
segment='test')
best_test_rmse = test_rmse
logging_str += ', Test RMSE={:.4f}'.format(test_rmse)
else:
no_better_valid += 1
if no_better_valid > args.train_early_stopping_patience\
and learning_rate <= args.train_min_lr:
logging.info("Early stopping threshold reached. Stop training.")
break
if no_better_valid > args.train_decay_patience:
new_lr = max(learning_rate * args.train_lr_decay_factor, args.train_min_lr)
if new_lr < learning_rate:
logging.info("\tChange the LR to %g" % new_lr)
learning_rate = new_lr
for p in optimizer.param_groups:
p['lr'] = learning_rate
no_better_valid = 0
print("Change the LR to %g" % new_lr)
            # sync on evaluation
if n_gpus > 1:
th.distributed.barrier()
print(logging_str)
if proc_id == 0:
print('Best epoch Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}'.format(
best_epoch, best_valid_rmse, best_test_rmse))
if __name__ == '__main__':
args = config()
devices = list(map(int, args.gpu.split(',')))
n_gpus = len(devices)
    # For sampling-based GCMC, we require each node to have its own features.
    # Otherwise (when the node ID is used as the feature), the model cannot scale.
dataset = MovieLens(args.data_name,
'cpu',
mix_cpu_gpu=args.mix_cpu_gpu,
use_one_hot_fea=args.use_one_hot_fea,
symm=args.gcn_agg_norm_symm,
test_ratio=args.data_test_ratio,
valid_ratio=args.data_valid_ratio)
print("Loading data finished ...\n")
args.src_in_units = dataset.user_feature_shape[1]
args.dst_in_units = dataset.movie_feature_shape[1]
args.rating_vals = dataset.possible_rating_values
# cpu
if devices[0] == -1:
run(0, 0, args, ['cpu'], dataset)
# gpu
elif n_gpus == 1:
run(0, n_gpus, args, devices, dataset)
# multi gpu
else:
prepare_mp(dataset.train_enc_graph)
prepare_mp(dataset.train_dec_graph)
procs = []
for proc_id in range(n_gpus):
p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, dataset))
p.start()
procs.append(p)
for p in procs:
p.join()