Unverified Commit 91cb3477 authored by Tianqi Zhang (张天启), committed by GitHub

[Example][Bug Fix] Improve DiffPool (#2730)



* change DiffPoolBatchedGraphLayer

* fix bug and add benchmark

* upt

* upt

* upt

* upt
Co-authored-by: Tong He <hetong007@gmail.com>
parent 33175226
@@ -16,8 +16,8 @@
How to run
----------
```bash
-python train.py --dataset ENZYMES --pool_ratio 0.10 --num_pool 1
-python train.py --dataset DD --pool_ratio 0.15 --num_pool 1
+python train.py --dataset ENZYMES --pool_ratio 0.10 --num_pool 1 --epochs 1000
+python train.py --dataset DD --pool_ratio 0.15 --num_pool 1 --batch-size 10
```
Performance
-----------
@@ -25,5 +25,35 @@
ENZYMES 63.33% (with early stopping)
DD 79.31% (with early stopping)
-## Dependencies
+## Update (2021-03-09)
+
+**Changes:**
+* Fix a bug in DiffPool: an incorrect `assign_dim` parameter
+* Improve the efficiency of DiffPool: make the model independent of the batch size and remove redundant computation
+
+**Efficiency:**
+
+Measured on a V100-SXM2 16GB GPU.
+
+|                    | Train time/epoch (original) (s) | Train time/epoch (improved) (s) |
+| ------------------ | ------------------------------: | ------------------------------: |
+| DD (batch_size=10) | 21.302                          | **17.282**                      |
+| DD (batch_size=20) | OOM                             | **44.682**                      |
+| ENZYMES            | 1.749                           | **1.685**                       |
+
+|                    | Memory usage (original) (MB) | Memory usage (improved) (MB) |
+| ------------------ | ---------------------------: | ---------------------------: |
+| DD (batch_size=10) | 5274.620                     | **2928.568**                 |
+| DD (batch_size=20) | OOM                          | **10088.889**                |
+| ENZYMES            | 25.685                       | **21.909**                   |
+
+**Accuracy:**
+
+Each experiment with the improved model was run only once, so the results may be noisy.
+
+|         | Original   | Improved   |
+| ------- | ---------: | ---------: |
+| DD      | **79.31%** | 78.33%     |
+| ENZYMES | 63.33%     | **68.33%** |
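The core of the efficiency change can be illustrated with a small, self-contained sketch (the shapes and values below are made up; only `F.softmax`, `torch.split`, and `torch.block_diag` reflect what the updated layer actually relies on): per-graph soft assignments are assembled into one block-diagonal matrix, so pooling runs on the whole batched graph without any batch-size-dependent masking.

```python
import torch
import torch.nn.functional as F

batch_num_nodes = [4, 6, 5]   # nodes per graph in a batched graph (illustrative)
n_assign = 3                  # nodes per pooled graph, i.e. assign_dim (illustrative)

scores = torch.randn(sum(batch_num_nodes), n_assign)   # raw assignment scores, (sum_N, n_assign)
assign = F.softmax(scores, dim=1)                       # soft assignment per node
per_graph = torch.split(assign, batch_num_nodes)        # one assignment block per graph
assign_block = torch.block_diag(*per_graph)             # (sum_N, num_graphs * n_assign)

print(assign_block.shape)  # torch.Size([15, 9])
```

Because the block structure comes from the per-graph node counts, the same computation works for any number of graphs in a batch, which is why `assign_dim` no longer needs to be multiplied by the batch size.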
import torch
import torch.nn as nn
+import torch.nn.functional as F
import numpy as np
from scipy.linalg import block_diag
@@ -101,27 +102,13 @@ class DiffPoolBatchedGraphLayer(nn.Module):
        self.reg_loss.append(EntropyLoss())

    def forward(self, g, h):
-        feat = self.feat_gc(g, h)
-        assign_tensor = self.pool_gc(g, h)
+        feat = self.feat_gc(g, h)  # size = (sum_N, F_out), sum_N is num of nodes in this batch
        device = feat.device
-        assign_tensor_masks = []
-        batch_size = len(g.batch_num_nodes())
-        for g_n_nodes in g.batch_num_nodes():
-            mask = torch.ones((g_n_nodes,
-                               int(assign_tensor.size()[1] / batch_size)))
-            assign_tensor_masks.append(mask)
-        """
-        The first pooling layer is computed on batched graph.
-        We first take the adjacency matrix of the batched graph, which is block-wise diagonal.
-        We then compute the assignment matrix for the whole batch graph, which will also be block diagonal
-        """
-        mask = torch.FloatTensor(
-            block_diag(*assign_tensor_masks)).to(device=device)
-        assign_tensor = masked_softmax(assign_tensor, mask,
-                                       memory_efficient=False)
+        assign_tensor = self.pool_gc(g, h)  # size = (sum_N, N_a), N_a is num of nodes in pooled graph.
+        assign_tensor = F.softmax(assign_tensor, dim=1)
+        assign_tensor = torch.split(assign_tensor, g.batch_num_nodes().tolist())
+        assign_tensor = torch.block_diag(*assign_tensor)  # size = (sum_N, batch_size * N_a)
        h = torch.matmul(torch.t(assign_tensor), feat)
        adj = g.adjacency_matrix(transpose=False, ctx=device)
        adj_new = torch.sparse.mm(adj, assign_tensor)
...
@@ -180,7 +180,7 @@ class DiffPool(nn.Module):
        out_all.append(readout)
        adj, h = self.first_diffpool_layer(g, g_embedding)
-        node_per_pool_graph = int(adj.size()[0] / self.batch_size)
+        node_per_pool_graph = int(adj.size()[0] / len(g.batch_num_nodes()))
        h, adj = batch2tensor(adj, h, node_per_pool_graph)
        h = self.gcn_forward_tensorized(
...
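A minimal sketch of why this change matters (the numbers are hypothetical and `batch2tensor` is not reproduced here): with `drop_last` removed from the data loader, the final batch of an epoch can hold fewer graphs than the configured batch size, so the number of pooled nodes per graph has to be derived from the batch actually seen rather than from a fixed `batch_size`.

```python
import torch

batch_size = 20        # configured batch size (illustrative)
graphs_in_batch = 13   # the last batch of an epoch may be smaller
n_pool = 5             # pooled nodes per graph (illustrative assign_dim)

# block-diagonal pooled adjacency for the whole batch
adj = torch.zeros(graphs_in_batch * n_pool, graphs_in_batch * n_pool)

node_per_pool_graph = adj.size(0) // graphs_in_batch  # 5, correct
wrong = adj.size(0) // batch_size                     # 3, wrong for a partial batch
print(node_per_pool_graph, wrong)
```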
@@ -17,14 +17,15 @@ class BatchedGraphSAGE(nn.Module):
                                gain=nn.init.calculate_gain('relu'))

    def forward(self, x, adj):
+        num_node_per_graph = adj.size(1)
        if self.use_bn and not hasattr(self, 'bn'):
-            self.bn = nn.BatchNorm1d(adj.size(1)).to(adj.device)
+            self.bn = nn.BatchNorm1d(num_node_per_graph).to(adj.device)
        if self.add_self:
-            adj = adj + torch.eye(adj.size(0)).to(adj.device)
+            adj = adj + torch.eye(num_node_per_graph).to(adj.device)
        if self.mean:
-            adj = adj / adj.sum(1, keepdim=True)
+            adj = adj / adj.sum(-1, keepdim=True)
        h_k_N = torch.matmul(adj, x)
        h_k = self.W(h_k_N)
...
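A quick sanity check of the `adj.sum(-1, keepdim=True)` change (toy tensor, illustrative only): after `batch2tensor` the adjacency is a dense 3-D tensor of shape `(batch, N, N)`, so mean aggregation must normalize over the last dimension; `sum(1)` would reduce the wrong axis once the adjacency is batched.

```python
import torch

adj = torch.rand(2, 4, 4)                   # (batch, N, N), random toy adjacency
row_norm = adj / adj.sum(-1, keepdim=True)  # each node's neighbor weights sum to 1
print(row_norm.sum(-1))                     # ~1.0 everywhere, for every graph in the batch
```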
@@ -17,6 +17,7 @@ from dgl.data import tu
from model.encoder import DiffPool
from data_utils import pre_process

+global_train_time_per_epoch = []

def arg_parse():
    '''
@@ -68,7 +69,7 @@ def arg_parse():
        '--save_dir',
        dest='save_dir',
        help='model saving directory: SAVE_DICT/DATASET')
-    parser.add_argument('--load_epoch', dest='load_epoch', help='load trained model params from\
+    parser.add_argument('--load_epoch', dest='load_epoch', type=int, help='load trained model params from\
                        SAVE_DICT/DATASET/model-LOAD_EPOCH')
    parser.add_argument('--data_mode', dest='data_mode', help='data\
                        preprocessing mode: default, id, degree, or one-hot\
@@ -113,7 +114,6 @@ def prepare_data(dataset, prog_args, train=False, pre_process=None):
    return dgl.dataloading.GraphDataLoader(dataset,
                                           batch_size=prog_args.batch_size,
                                           shuffle=shuffle,
-                                          drop_last=True,
                                           num_workers=prog_args.n_worker)
@@ -148,8 +148,7 @@ def graph_classify_task(prog_args):
    # calculate assignment dimension: pool_ratio * largest graph's maximum
    # number of nodes in the dataset
-    assign_dim = int(max_num_node * prog_args.pool_ratio) * \
-        prog_args.batch_size
+    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
@@ -187,7 +186,7 @@ def graph_classify_task(prog_args):
        prog_args,
        val_dataset=val_dataloader)
    result = evaluate(test_dataloader, model, prog_args, logger)
-    print("test accuracy {}%".format(result * 100))
+    print("test accuracy {:.2f}%".format(result * 100))

def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
@@ -209,7 +208,7 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
        model.train()
        accum_correct = 0
        total = 0
-        print("EPOCH ###### {} ######".format(epoch))
+        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for (batch_idx, (batch_graph, graph_labels)) in enumerate(dataloader):
            for (key, value) in batch_graph.ndata.items():
@@ -234,21 +233,22 @@ def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
            optimizer.step()
        train_accu = accum_correct / total
-        print("train accuracy for this epoch {} is {}%".format(epoch,
+        print("train accuracy for this epoch {} is {:.2f}%".format(epoch,
                                                                train_accu * 100))
        elapsed_time = time.time() - begin_time
-        print("loss {} with epoch time {} s & computation time {} s ".format(
+        print("loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
            loss.item(), elapsed_time, computation_time))
+        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
-            print("validation accuracy {}%".format(result * 100))
+            print("validation accuracy {:.2f}%".format(result * 100))
            if result >= early_stopping_logger['val_acc'] and result <=\
                    train_accu:
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if prog_args.save_dir is not None:
                    torch.save(model.state_dict(), prog_args.save_dir + "/" + prog_args.dataset
                               + "/model.iter-" + str(early_stopping_logger['best_epoch']))
-            print("best epoch is EPOCH {}, val_acc is {}%".format(early_stopping_logger['best_epoch'],
+            print("best epoch is EPOCH {}, val_acc is {:.2f}%".format(early_stopping_logger['best_epoch'],
                                                                   early_stopping_logger['val_acc'] * 100))
        torch.cuda.empty_cache()
    return early_stopping_logger
@@ -287,6 +287,9 @@ def main():
    print(prog_args)
    graph_classify_task(prog_args)
+    print("Train time per epoch: {:.4f}".format(sum(global_train_time_per_epoch) / len(global_train_time_per_epoch)))
+    print("Max memory usage: {:.4f}".format(torch.cuda.max_memory_allocated(0) / (1024 * 1024)))

if __name__ == "__main__":
    main()