Unverified commit f19f05ce, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4651)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 977b1ba4
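The diff below is the output of running Black over the GraphWriter example. For context, here is a minimal sketch of reproducing one of its rewrites with Black's Python API; the line length is an assumption (the wrapping in this diff implies something near 79, not Black's default 88), and `black` must be installed:

# Minimal sketch: reproduce one rewrite from this diff with Black's
# Python API (assumes `pip install black`; line_length=79 is a guess).
import black

src = "metadata = dict(title='Movie Test', artist='Matplotlib', comment='Movie support!')"
print(black.format_str(src, mode=black.Mode(line_length=79)))
# Black normalizes quotes to double quotes and, past the line length,
# explodes the call across lines, exactly the changes seen below.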
import os
import cv2 as cv
import matplotlib
import matplotlib.animation as manimation
import matplotlib.pyplot as plt
import numpy as np

matplotlib.use("agg")


# Make video can be used to visualize test data
def make_video(xy, filename):
    os.system("rm -rf pics/*")
    FFMpegWriter = manimation.writers["ffmpeg"]
    metadata = dict(
        title="Movie Test", artist="Matplotlib", comment="Movie support!"
    )
    writer = FFMpegWriter(fps=15, metadata=metadata)
    fig = plt.figure()
    plt.xlim(-200, 200)
    plt.ylim(-200, 200)
    fig_num = len(xy)
    color = ["ro", "bo", "go", "ko", "yo", "mo", "co"]
    with writer.saving(fig, filename, len(xy)):
        for i in range(len(xy)):
            for j in range(len(xy[0])):
...
import torch
from modules import MSA, BiLSTM, GraphTrans
from torch import nn
from utlis import *
import dgl


class GraphWriter(nn.Module):
    def __init__(self, args):
        super(GraphWriter, self).__init__()
        self.args = args
        if args.title:
            self.title_emb = nn.Embedding(
                len(args.title_vocab), args.nhid, padding_idx=0
            )
            self.title_enc = BiLSTM(args, enc_type="title")
            self.title_attn = MSA(args)
        self.ent_emb = nn.Embedding(
            len(args.ent_text_vocab), args.nhid, padding_idx=0
        )
        self.tar_emb = nn.Embedding(
            len(args.text_vocab), args.nhid, padding_idx=0
        )
        if args.title:
            nn.init.xavier_normal_(self.title_emb.weight)
        nn.init.xavier_normal_(self.ent_emb.weight)
        self.rel_emb = nn.Embedding(
            len(args.rel_vocab), args.nhid, padding_idx=0
        )
        nn.init.xavier_normal_(self.rel_emb.weight)
        self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
        self.ent_enc = BiLSTM(args, enc_type="entity")
        self.graph_enc = GraphTrans(args)
        self.ent_attn = MSA(args)
        self.copy_attn = MSA(args, mode="copy")
        self.copy_fc = nn.Linear(args.dec_ninp, 1)
        self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))

    def enc_forward(
        self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask
    ):
        title_enc = None
        if self.args.title:
            title_enc = self.title_enc(
                self.title_emb(batch["title"]), title_mask
            )
        ent_enc = self.ent_enc(
            self.ent_emb(batch["ent_text"]),
            ent_text_mask,
            ent_len=batch["ent_len"],
        )
        rel_emb = self.rel_emb(batch["rel"])
        g_ent, g_root = self.graph_enc(
            ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch["graph"]
        )
        return g_ent, g_root, title_enc, ent_enc

    def forward(self, batch, beam_size=-1):
        ent_mask = len2mask(batch["ent_len"], self.args.device)
        ent_text_mask = batch["ent_text"] == 0
        rel_mask = batch["rel"] == 0  # 0 means the <PAD>
        title_mask = batch["title"] == 0
        g_ent, g_root, title_enc, ent_enc = self.enc_forward(
            batch,
            ent_mask,
            ent_text_mask,
            batch["ent_len"],
            rel_mask,
            title_mask,
        )
        _h, _c = g_root, g_root.clone().detach()
        ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
        if self.args.title:
            attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
            ctx = torch.cat([ctx, attn], 1)
        if beam_size < 1:
            # training
            outs = []
            tar_inp = self.tar_emb(batch["text"].transpose(0, 1))
            for t, xt in enumerate(tar_inp):
                _xt = torch.cat([ctx, xt], 1)
                _h, _c = self.decode_lstm(_xt, (_h, _c))
@@ -60,37 +86,59 @@ class GraphWriter(nn.Module):
                if self.args.title:
                    attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
                    ctx = torch.cat([ctx, attn], 1)
                outs.append(torch.cat([_h, ctx], 1))
            outs = torch.stack(outs, 1)
            copy_gate = torch.sigmoid(self.copy_fc(outs))
            EPSI = 1e-6
            # copy
            pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(
                self.pred_v_fc(outs), -1
            )
            pred_c = torch.log((1.0 - copy_gate) + EPSI) + torch.log_softmax(
                self.copy_attn(outs, ent_enc, mask=ent_mask), -1
            )
            pred = torch.cat([pred_v, pred_c], -1)
            return pred
        else:
            if beam_size == 1:
                # greedy
                device = g_ent.device
                B = g_ent.shape[0]
                ent_type = batch["ent_type"].view(B, -1)
                seq = (
                    torch.ones(
                        B,
                    )
                    .long()
                    .to(device)
                    * self.args.text_vocab("<BOS>")
                ).unsqueeze(1)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(
                        seq[:, -1], ent_type, len(self.args.text_vocab)
                    )
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(
                            _h, title_enc, mask=title_mask
                        )
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    pred_v = torch.log(copy_gate) + torch.log_softmax(
                        self.pred_v_fc(_y), -1
                    )
                    pred_c = torch.log((1.0 - copy_gate)) + torch.log_softmax(
                        self.copy_attn(
                            _y.unsqueeze(1), ent_enc, mask=ent_mask
                        ).squeeze(1),
                        -1,
                    )
                    pred = torch.cat([pred_v, pred_c], -1).view(B, -1)
                    for ban_item in ["<BOS>", "<PAD>", "<UNK>"]:
                        pred[:, self.args.text_vocab(ban_item)] = -1e8
                    _, word = pred.max(-1)
                    seq = torch.cat([seq, word.unsqueeze(1)], 1)
@@ -102,47 +150,92 @@ class GraphWriter(nn.Module):
                BSZ = B * beam_size
                _h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                _c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_mask = (
                    ent_mask.view(B, 1, -1)
                    .repeat(1, beam_size, 1)
                    .view(BSZ, -1)
                )
                if self.args.title:
                    title_mask = (
                        title_mask.view(B, 1, -1)
                        .repeat(1, beam_size, 1)
                        .view(BSZ, -1)
                    )
                    title_enc = (
                        title_enc.view(B, 1, title_enc.size(1), -1)
                        .repeat(1, beam_size, 1, 1)
                        .view(BSZ, title_enc.size(1), -1)
                    )
                ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
                ent_type = (
                    batch["ent_type"]
                    .view(B, 1, -1)
                    .repeat(1, beam_size, 1)
                    .view(BSZ, -1)
                )
                g_ent = (
                    g_ent.view(B, 1, g_ent.size(1), -1)
                    .repeat(1, beam_size, 1, 1)
                    .view(BSZ, g_ent.size(1), -1)
                )
                ent_enc = (
                    ent_enc.view(B, 1, ent_enc.size(1), -1)
                    .repeat(1, beam_size, 1, 1)
                    .view(BSZ, ent_enc.size(1), -1)
                )
                beam_best = torch.zeros(B).to(device) - 1e9
                beam_best_seq = [None] * B
                beam_seq = (
                    torch.ones(B, beam_size).long().to(device)
                    * self.args.text_vocab("<BOS>")
                ).unsqueeze(-1)
                beam_score = torch.zeros(B, beam_size).to(device)
                done_flag = torch.zeros(B, beam_size)
                for t in range(self.args.beam_max_len):
                    _inp = replace_ent(
                        beam_seq[:, :, -1].view(-1),
                        ent_type,
                        len(self.args.text_vocab),
                    )
                    xt = self.tar_emb(_inp)
                    _xt = torch.cat([ctx, xt], 1)
                    _h, _c = self.decode_lstm(_xt, (_h, _c))
                    ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
                    if self.args.title:
                        attn = _h + self.title_attn(
                            _h, title_enc, mask=title_mask
                        )
                        ctx = torch.cat([ctx, attn], 1)
                    _y = torch.cat([_h, ctx], 1)
                    copy_gate = torch.sigmoid(self.copy_fc(_y))
                    pred_v = torch.log(copy_gate) + torch.log_softmax(
                        self.pred_v_fc(_y), -1
                    )
                    pred_c = torch.log((1.0 - copy_gate)) + torch.log_softmax(
                        self.copy_attn(
                            _y.unsqueeze(1), ent_enc, mask=ent_mask
                        ).squeeze(1),
                        -1,
                    )
                    pred = torch.cat([pred_v, pred_c], -1).view(
                        B, beam_size, -1
                    )
                    for ban_item in ["<BOS>", "<PAD>", "<UNK>"]:
                        pred[:, :, self.args.text_vocab(ban_item)] = -1e8
                    if t == self.args.beam_max_len - 1:  # force ending
                        tt = pred[:, :, self.args.text_vocab("<EOS>")]
                        pred = pred * 0 - 1e8
                        pred[:, :, self.args.text_vocab("<EOS>")] = tt
                    cum_score = beam_score.view(B, beam_size, 1) + pred
                    score, word = cum_score.topk(
                        dim=-1, k=beam_size
                    )  # B, beam_size, beam_size
                    score, word = score.view(B, -1), word.view(B, -1)
                    eos_idx = self.args.text_vocab("<EOS>")
                    if beam_seq.size(2) == 1:
                        new_idx = torch.arange(beam_size).to(word)
                        new_idx = new_idx[None, :].repeat(B, 1)
                    else:
                        _, new_idx = score.topk(dim=-1, k=beam_size)
                    new_src, new_score, new_word, new_done = [], [], [], []
@@ -151,7 +244,7 @@ class GraphWriter(nn.Module):
                        for j in range(beam_size):
                            tmp_score = score[i][new_idx[i][j]]
                            tmp_word = word[i][new_idx[i][j]]
                            src_idx = new_idx[i][j] // beam_size
                            new_src.append(src_idx)
                            if tmp_word == eos_idx:
                                new_score.append(-1e8)
@@ -159,24 +252,45 @@ class GraphWriter(nn.Module):
                                new_score.append(tmp_score)
                            new_word.append(tmp_word)
                            if (
                                tmp_word == eos_idx
                                and done_flag[i][src_idx] == 0
                                and tmp_score / LP > beam_best[i]
                            ):
                                beam_best[i] = tmp_score / LP
                                beam_best_seq[i] = beam_seq[i][src_idx]
                            if tmp_word == eos_idx:
                                new_done.append(1)
                            else:
                                new_done.append(done_flag[i][src_idx])
                    new_score = (
                        torch.Tensor(new_score)
                        .view(B, beam_size)
                        .to(beam_score)
                    )
                    new_word = (
                        torch.Tensor(new_word).view(B, beam_size).to(beam_seq)
                    )
                    new_src = (
                        torch.LongTensor(new_src).view(B, beam_size).to(device)
                    )
                    new_done = (
                        torch.Tensor(new_done).view(B, beam_size).to(done_flag)
                    )
                    beam_score = new_score
                    done_flag = new_done
                    beam_seq = beam_seq.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ]
                    beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
                    _h = _h.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ].view(BSZ, -1)
                    _c = _c.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ].view(BSZ, -1)
                    ctx = ctx.view(B, beam_size, -1)[
                        torch.arange(B)[:, None].to(device), new_src
                    ].view(BSZ, -1)
        return beam_best_seq
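As an aside on the training branch of forward() above: the model mixes a vocabulary distribution with a copy distribution over entities through a sigmoid gate, concatenated in log space so that an index at or past the vocabulary size means "copy entity index - V". A toy, self-contained illustration (the shapes and the EPSI-free gating are simplifications of the code above):

# Toy illustration of the copy-gate mixture in GraphWriter.forward;
# the real training branch also adds EPSI = 1e-6 inside the logs.
import torch

B, V, E = 2, 5, 3  # batch, vocab size, number of entities
copy_gate = torch.sigmoid(torch.randn(B, 1))
vocab_logp = torch.log_softmax(torch.randn(B, V), -1)
copy_logp = torch.log_softmax(torch.randn(B, E), -1)
pred = torch.cat(
    [
        torch.log(copy_gate) + vocab_logp,  # generate from the vocabulary
        torch.log(1 - copy_gate) + copy_logp,  # copy entity (index - V)
    ],
    -1,
)
assert torch.allclose(pred.exp().sum(-1), torch.ones(B))  # still sums to 1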
import math
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from utlis import *
import dgl.function as fn
from dgl.nn.functional import edge_softmax


class MSA(nn.Module):
@@ -13,45 +15,57 @@ class MSA(nn.Module):
    # the first is the copy, determining which entity should be copied.
    # the second is the normal attention with two sequence inputs
    # the third is the attention but with one token and a sequence. (gather, attentive pooling)
    def __init__(self, args, mode="normal"):
        super(MSA, self).__init__()
        if mode == "copy":
            nhead, head_dim = 1, args.nhid
            qninp, kninp = args.dec_ninp, args.nhid
        if mode == "normal":
            nhead, head_dim = args.nhead, args.head_dim
            qninp, kninp = args.nhid, args.nhid
        self.attn_drop = nn.Dropout(0.1)
        self.WQ = nn.Linear(
            qninp, nhead * head_dim, bias=True if mode == "copy" else False
        )
        if mode != "copy":
            self.WK = nn.Linear(kninp, nhead * head_dim, bias=False)
            self.WV = nn.Linear(kninp, nhead * head_dim, bias=False)
        self.args, self.nhead, self.head_dim, self.mode = (
            args,
            nhead,
            head_dim,
            mode,
        )

    def forward(self, inp1, inp2, mask=None):
        B, L2, H = inp2.shape
        NH, HD = self.nhead, self.head_dim
        if self.mode == "copy":
            q, k, v = self.WQ(inp1), inp2, inp2
        else:
            q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2)
        L1 = 1 if inp1.ndim == 2 else inp1.shape[1]
        if self.mode != "copy":
            q = q / math.sqrt(H)
        q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)
        k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)
        v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3)
        pre_attn = torch.matmul(q, k)
        if mask is not None:
            pre_attn = pre_attn.masked_fill(mask[:, None, None, :], -1e8)
        if self.mode == "copy":
            return pre_attn.squeeze(1)
        else:
            alpha = self.attn_drop(torch.softmax(pre_attn, -1))
            attn = (
                torch.matmul(alpha, v)
                .permute(0, 2, 1, 3)
                .contiguous()
                .view(B, L1, NH * HD)
            )
            ret = attn
            if inp1.ndim == 2:
                return ret.squeeze(1)
            else:
                return ret
@@ -59,52 +73,63 @@ class MSA(nn.Module):
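A hedged smoke test for the MSA module above, with modules.py importable; the hyperparameter values are made up, and `types.SimpleNamespace` stands in for the real argparse args:

# Hypothetical MSA usage; nhid/nhead/head_dim/dec_ninp values are assumed.
import types
import torch

args = types.SimpleNamespace(nhid=8, nhead=2, head_dim=4, dec_ninp=16)
attn = MSA(args)  # "normal" mode
pooled = attn(torch.randn(3, 8), torch.randn(3, 5, 8))  # attentive pooling
assert pooled.shape == (3, 8)  # one vector per batch element
copy = MSA(args, mode="copy")  # raw copy scores, no softmax
scores = copy(torch.randn(3, 16).unsqueeze(1), torch.randn(3, 5, 8)).squeeze(1)
assert scores.shape == (3, 5)  # one logit per inp2 position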
class BiLSTM(nn.Module):
    # for entity encoding or the title encoding
    def __init__(self, args, enc_type="title"):
        super(BiLSTM, self).__init__()
        self.enc_type = enc_type
        self.drop = nn.Dropout(args.emb_drop)
        self.bilstm = nn.LSTM(
            args.nhid,
            args.nhid // 2,
            bidirectional=True,
            num_layers=args.enc_lstm_layers,
            batch_first=True,
        )

    def forward(self, inp, mask, ent_len=None):
        inp = self.drop(inp)
        lens = (mask == 0).sum(-1).long().tolist()
        pad_seq = pack_padded_sequence(
            inp, lens, batch_first=True, enforce_sorted=False
        )
        y, (_h, _c) = self.bilstm(pad_seq)
        if self.enc_type == "title":
            y = pad_packed_sequence(y, batch_first=True)[0]
            return y
        if self.enc_type == "entity":
            _h = _h.transpose(0, 1).contiguous()
            _h = _h[:, -2:].view(
                _h.size(0), -1
            )  # two directions of the top-layer
            ret = pad(_h.split(ent_len), out_type="tensor")
            return ret


class GAT(nn.Module):
    # a graph attention network with dot-product attention
    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads,
        ffn_drop=0.0,
        attn_drop=0.0,
        trans=True,
    ):
        super(GAT, self).__init__()
        self._num_heads = num_heads
        self._in_feats = in_feats
        self._out_feats = out_feats
        self.q_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.k_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.v_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
        self.attn_drop = nn.Dropout(0.1)
        self.ln1 = nn.LayerNorm(in_feats)
        self.ln2 = nn.LayerNorm(in_feats)
        if trans:
            self.FFN = nn.Sequential(
                nn.Linear(in_feats, 4 * in_feats),
                nn.PReLU(4 * in_feats),
                nn.Linear(4 * in_feats, in_feats),
                nn.Dropout(0.1),
            )
            # a strange FFN, see the author's code
@@ -117,40 +142,64 @@ class GAT(nn.Module):
        q = q.view(-1, self._num_heads, self._out_feats)
        k = k.view(-1, self._num_heads, self._out_feats)
        v = v.view(-1, self._num_heads, self._out_feats)
        graph.ndata.update(
            {"ft": v, "el": k, "er": q}
        )  # k,q instead of q,k, the edge_softmax is applied on incoming edges
        # compute edge attention
        graph.apply_edges(fn.u_dot_v("el", "er", "e"))
        e = graph.edata.pop("e") / math.sqrt(self._out_feats * self._num_heads)
        graph.edata["a"] = edge_softmax(graph, e)
        # message passing
        graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft2"))
        rst = graph.ndata["ft2"]
        # residual
        rst = rst.view(feat.shape) + feat
        if self._trans:
            rst = self.ln1(rst)
            rst = self.ln1(rst + self.FFN(rst))
            # use the same layer norm, see the author's code
        return rst
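The "k,q instead of q,k" comment above exists because DGL's edge_softmax normalizes scores over the incoming edges of each destination node. A minimal demonstration of that behavior:

# edge_softmax normalizes over the incoming edges of each destination node.
import dgl
import torch
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 1, 2], [3, 3, 3]))  # three edges, all pointing at node 3
logits = torch.zeros(3, 1)  # identical raw scores on every edge
print(edge_softmax(g, logits))  # each edge gets 1/3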
class GraphTrans(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        if args.graph_enc == "gat":
            # we only support gtrans, don't use this one
            self.gat = nn.ModuleList(
                [
                    GAT(
                        args.nhid,
                        args.nhid // 4,
                        4,
                        attn_drop=args.attn_drop,
                        trans=False,
                    )
                    for _ in range(args.prop)
                ]
            )  # untested
        else:
            self.gat = nn.ModuleList(
                [
                    GAT(
                        args.nhid,
                        args.nhid // 4,
                        4,
                        attn_drop=args.attn_drop,
                        ffn_drop=args.drop,
                        trans=True,
                    )
                    for _ in range(args.prop)
                ]
            )
        self.prop = args.prop

    def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):
        device = ent.device
        graphs = graphs.to(device)
        ent_mask = ent_mask == 0  # reverse mask
        rel_mask = rel_mask == 0
        init_h = []
        for i in range(graphs.batch_size):
            init_h.append(ent[i][ent_mask[i]])
@@ -159,7 +208,19 @@ class GraphTrans(nn.Module):
        feats = init_h
        for i in range(self.prop):
            feats = self.gat[i](graphs, feats)
        g_root = feats.index_select(
            0,
            graphs.filter_nodes(
                lambda x: x.data["type"] == NODE_TYPE["root"]
            ).to(device),
        )
        g_ent = pad(
            feats.index_select(
                0,
                graphs.filter_nodes(
                    lambda x: x.data["type"] == NODE_TYPE["entity"]
                ).to(device),
            ).split(ent_len),
            out_type="tensor",
        )
        return g_ent, g_root
import argparse
import torch


def fill_config(args):
    # dirty work
    args.device = torch.device(args.gpu)
    args.dec_ninp = args.nhid * 3 if args.title else args.nhid * 2
    args.fnames = [args.train_file, args.valid_file, args.test_file]
    return args


def vocab_config(
    args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
):
    # dirty work
    args.ent_vocab = ent_vocab
    args.rel_vocab = rel_vocab
@@ -21,35 +24,84 @@ def vocab_config(args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_v


def get_args():
    args = argparse.ArgumentParser(description="Graph Writer in DGL")
    args.add_argument("--nhid", default=500, type=int, help="hidden size")
    args.add_argument("--nhead", default=4, type=int, help="number of heads")
    args.add_argument("--head_dim", default=125, type=int, help="head dim")
    args.add_argument(
        "--weight_decay", default=0.0, type=float, help="weight decay"
    )
    args.add_argument(
        "--prop", default=6, type=int, help="number of layers of gnn"
    )
    args.add_argument("--title", action="store_true", help="use title input")
    args.add_argument("--test", action="store_true", help="inference mode")
    args.add_argument("--batch_size", default=32, type=int, help="batch_size")
    args.add_argument(
        "--beam_size", default=4, type=int, help="beam size, 1 for greedy"
    )
    args.add_argument("--epoch", default=20, type=int, help="training epoch")
    args.add_argument(
        "--beam_max_len",
        default=200,
        type=int,
        help="max length of the generated text",
    )
    args.add_argument(
        "--enc_lstm_layers",
        default=2,
        type=int,
        help="number of layers of lstm",
    )
    args.add_argument("--lr", default=1e-1, type=float, help="learning rate")
    # args.add_argument('--lr_decay', default=1e-8, type=float, help='')
    args.add_argument("--clip", default=1, type=float, help="gradient clip")
    args.add_argument(
        "--emb_drop", default=0.0, type=float, help="embedding dropout"
    )
    args.add_argument(
        "--attn_drop", default=0.1, type=float, help="attention dropout"
    )
    args.add_argument("--drop", default=0.1, type=float, help="dropout")
    args.add_argument("--lp", default=1.0, type=float, help="length penalty")
    args.add_argument(
        "--graph_enc",
        default="gtrans",
        type=str,
        help="gnn mode, we only support the graph transformer now",
    )
    args.add_argument(
        "--train_file",
        default="data/unprocessed.train.json",
        type=str,
        help="training file",
    )
    args.add_argument(
        "--valid_file",
        default="data/unprocessed.val.json",
        type=str,
        help="validation file",
    )
    args.add_argument(
        "--test_file",
        default="data/unprocessed.test.json",
        type=str,
        help="test file",
    )
    args.add_argument(
        "--save_dataset",
        default="data.pickle",
        type=str,
        help="save path of dataset",
    )
    args.add_argument(
        "--save_model",
        default="saved_model.pt",
        type=str,
        help="save path of model",
    )
    args.add_argument("--gpu", default=0, type=int, help="gpu mode")
    args = args.parse_args()
    args = fill_config(args)
    return args
import os
import sys
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from graphwriter import *
from opts import *
from tqdm import tqdm
from utlis import *

sys.path.append("./pycocoevalcap")
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge


def train_one_epoch(model, dataloader, optimizer, args, epoch):
    model.train()
    tloss = 0.0
    tcnt = 0.0
    st_time = time.time()
    with tqdm(dataloader, desc="Train Ep " + str(epoch), mininterval=60) as tq:
        for batch in tq:
            pred = model(batch)
            nll_loss = F.nll_loss(
                pred.view(-1, pred.shape[-1]),
                batch["tgt_text"].view(-1),
                ignore_index=0,
            )
            loss = nll_loss
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            loss = loss.item()
            if loss != loss:
                raise ValueError("NaN appear")
            tloss += loss * len(batch["tgt_text"])
            tcnt += len(batch["tgt_text"])
            tq.set_postfix({"loss": tloss / tcnt}, refresh=False)
    print(
        "Train Ep ",
        str(epoch),
        "AVG Loss ",
        tloss / tcnt,
        "Steps ",
        tcnt,
        "Time ",
        time.time() - st_time,
        "GPU",
        torch.cuda.max_memory_cached() / 1024.0 / 1024.0 / 1024.0,
    )
    torch.save(model, args.save_model + str(epoch % 100))


val_loss = 2**31


def eval_it(model, dataloader, args, epoch):
    global val_loss
    model.eval()
    tloss = 0.0
    tcnt = 0.0
    st_time = time.time()
    with tqdm(dataloader, desc="Eval Ep " + str(epoch), mininterval=60) as tq:
        for batch in tq:
            with torch.no_grad():
                pred = model(batch)
                nll_loss = F.nll_loss(
                    pred.view(-1, pred.shape[-1]),
                    batch["tgt_text"].view(-1),
                    ignore_index=0,
                )
                loss = nll_loss
                loss = loss.item()
            tloss += loss * len(batch["tgt_text"])
            tcnt += len(batch["tgt_text"])
            tq.set_postfix({"loss": tloss / tcnt}, refresh=False)
    print(
        "Eval Ep ",
        str(epoch),
        "AVG Loss ",
        tloss / tcnt,
        "Steps ",
        tcnt,
        "Time ",
        time.time() - st_time,
    )
    if tloss / tcnt < val_loss:
        print("Saving best model ", "Ep ", epoch, " loss ", tloss / tcnt)
        torch.save(model, args.save_model + "best")
        val_loss = tloss / tcnt


def test(model, dataloader, args):
@@ -71,39 +102,61 @@ def test(model, dataloader, args):
    hyp = []
    ref = []
    model.eval()
    gold_file = open("tmp_gold.txt", "w")
    pred_file = open("tmp_pred.txt", "w")
    with tqdm(dataloader, desc="Test ", mininterval=1) as tq:
        for batch in tq:
            with torch.no_grad():
                seq = model(batch, beam_size=args.beam_size)
            r = write_txt(batch, batch["tgt_text"], gold_file, args)
            h = write_txt(batch, seq, pred_file, args)
            hyp.extend(h)
            ref.extend(r)
    hyp = dict(zip(range(len(hyp)), hyp))
    ref = dict(zip(range(len(ref)), ref))
    print(hyp[0], ref[0])
    print("BLEU INP", len(hyp), len(ref))
    print("BLEU", scorer.compute_score(ref, hyp)[0])
    print("METEOR", m_scorer.compute_score(ref, hyp)[0])
    print("ROUGE_L", r_scorer.compute_score(ref, hyp)[0])
    gold_file.close()
    pred_file.close()


def main(args):
    if os.path.exists(args.save_dataset):
        train_dataset, valid_dataset, test_dataset = pickle.load(
            open(args.save_dataset, "rb")
        )
    else:
        train_dataset, valid_dataset, test_dataset = get_datasets(
            args.fnames, device=args.device, save=args.save_dataset
        )
    args = vocab_config(
        args,
        train_dataset.ent_vocab,
        train_dataset.rel_vocab,
        train_dataset.text_vocab,
        train_dataset.ent_text_vocab,
        train_dataset.title_vocab,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_sampler=BucketSampler(train_dataset, batch_size=args.batch_size),
        collate_fn=train_dataset.batch_fn,
    )
    valid_dataloader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=train_dataset.batch_fn,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=train_dataset.batch_fn,
    )
    model = GraphWriter(args)
    model.to(args.device)
@@ -113,13 +166,18 @@ def main(args):
        print(model)
        test(model, test_dataloader, args)
    else:
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=args.lr,
            weight_decay=args.weight_decay,
            momentum=0.9,
        )
        print(model)
        for epoch in range(args.epoch):
            train_one_epoch(model, train_dataloader, optimizer, args, epoch)
            eval_it(model, valid_dataloader, args, epoch)


if __name__ == "__main__":
    args = get_args()
    main(args)
import json
import pickle
import random
import numpy as np
import torch
import dgl

NODE_TYPE = {"entity": 0, "root": 1, "relation": 2}


def write_txt(batch, seqs, w_file, args):
@@ -16,59 +17,103 @@ def write_txt(batch, seqs, w_file, args):
        txt = []
        for token in seq:
            # copy the entity
            if token >= len(args.text_vocab):
                ent_text = batch["raw_ent_text"][b][
                    token - len(args.text_vocab)
                ]
                ent_text = filter(lambda x: x != "<PAD>", ent_text)
                txt.extend(ent_text)
            else:
                if int(token) not in [
                    args.text_vocab(x) for x in ["<PAD>", "<BOS>", "<EOS>"]
                ]:
                    txt.append(args.text_vocab(int(token)))
                if int(token) == args.text_vocab("<EOS>"):
                    break
        w_file.write(" ".join([str(x) for x in txt]) + "\n")
        ret.append([" ".join([str(x) for x in txt])])
    return ret


def replace_ent(x, ent, V):
    # replace the entity
    mask = x >= V
    if mask.sum() == 0:
        return x
    nz = mask.nonzero()
    fill_ent = ent[nz, x[mask] - V]
    x = x.masked_scatter(mask, fill_ent)
    return x


def len2mask(lens, device):
    max_len = max(lens)
    mask = (
        torch.arange(max_len, device=device)
        .unsqueeze(0)
        .expand(len(lens), max_len)
    )
    mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)
    return mask
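For reference, len2mask above marks padding positions with True; a quick check of the convention:

# len2mask: True marks positions at or beyond each sequence's length.
import torch

print(len2mask([2, 3], torch.device("cpu")))
# tensor([[False, False,  True],
#         [False, False, False]])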
def pad(var_len_list, out_type="list", flatten=False):
    if flatten:
        lens = [len(x) for x in var_len_list]
        var_len_list = sum(var_len_list, [])
    max_len = max([len(x) for x in var_len_list])
    if out_type == "list":
        if flatten:
            return [
                x + ["<PAD>"] * (max_len - len(x)) for x in var_len_list
            ], lens
        else:
            return [x + ["<PAD>"] * (max_len - len(x)) for x in var_len_list]
    if out_type == "tensor":
        if flatten:
            return (
                torch.stack(
                    [
                        torch.cat(
                            [
                                x,
                                torch.zeros(
                                    [max_len - len(x)] + list(x.shape[1:])
                                ).type_as(x),
                            ],
                            0,
                        )
                        for x in var_len_list
                    ],
                    0,
                ),
                lens,
            )
        else:
            return torch.stack(
                [
                    torch.cat(
                        [
                            x,
                            torch.zeros(
                                [max_len - len(x)] + list(x.shape[1:])
                            ).type_as(x),
                        ],
                        0,
                    )
                    for x in var_len_list
                ],
                0,
            )
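And pad() in its default "list" mode right-pads token lists with the literal "<PAD>" token:

# pad() in "list" mode pads every list to the longest one.
print(pad([["a", "b"], ["c"]]))
# [['a', 'b'], ['c', '<PAD>']]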
class Vocab(object):
    def __init__(
        self,
        max_vocab=2**31,
        min_freq=-1,
        sp=["<PAD>", "<BOS>", "<EOS>", "<UNK>"],
    ):
        self.i2s = []
        self.s2i = {}
        self.wf = {}
@@ -78,7 +123,7 @@ class Vocab(object):
        return len(self.i2s)

    def __str__(self):
        return "Total " + str(len(self.i2s)) + str(self.i2s[:10])

    def update(self, token):
        if isinstance(token, list):
@@ -89,9 +134,13 @@ class Vocab(object):

    def build(self):
        self.i2s.extend(self.sp)
        sort_kv = sorted(self.wf.items(), key=lambda x: x[1], reverse=True)
        for k, v in sort_kv:
            if (
                len(self.i2s) < self.max_vocab
                and v >= self.min_freq
                and k not in self.sp
            ):
                self.i2s.append(k)
        self.s2i.update(list(zip(self.i2s, range(len(self.i2s)))))
@@ -99,109 +148,167 @@ class Vocab(object):
        if isinstance(x, int):
            return self.i2s[x]
        else:
            return self.s2i.get(x, self.s2i["<UNK>"])

    def save(self, fname):
        pass

    def load(self, fname):
        pass
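A hedged usage sketch for Vocab: the update() body and the def line of the lookup method are elided in this diff, so the sketch assumes update() counts token frequencies into self.wf and that the int/string lookup shown above lives in __call__ (callable usage matches args.text_vocab("<BOS>") elsewhere in this commit):

# Assumed Vocab workflow; update() internals are elided in the diff above.
v = Vocab()
v.update(["graph", "writer", "graph"])
v.build()  # specials first, then tokens sorted by frequency
print(v("graph"))  # string -> index (right after the 4 special tokens)
print(v(0))  # int -> string, i.e. "<PAD>"
print(v("never-seen"))  # unknown strings map to the "<UNK>" index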
def at_least(x):
    # handling the illegal data
    if len(x) == 0:
        return ["<UNK>"]
    else:
        return x


class Example(object):
    def __init__(self, title, ent_text, ent_type, rel, text):
        # one object corresponds to a data sample
        self.raw_title = title.split()
        self.raw_ent_text = [at_least(x.split()) for x in ent_text]
        assert min([len(x) for x in self.raw_ent_text]) > 0, str(
            self.raw_ent_text
        )
        self.raw_ent_type = ent_type.split()  # <method> .. <>
        self.raw_rel = []
        for r in rel:
            rel_list = r.split()
            for i in range(len(rel_list)):
                if (
                    i > 0
                    and i < len(rel_list) - 1
                    and rel_list[i - 1] == "--"
                    and rel_list[i] != rel_list[i].lower()
                    and rel_list[i + 1] == "--"
                ):
                    self.raw_rel.append(
                        [
                            rel_list[: i - 1],
                            rel_list[i - 1] + rel_list[i] + rel_list[i + 1],
                            rel_list[i + 2 :],
                        ]
                    )
                    break
        self.raw_text = text.split()
        self.graph = self.build_graph()

    def __str__(self):
        return "\n".join(
            [str(k) + ":\t" + str(v) for k, v in self.__dict__.items()]
        )

    def __len__(self):
        return len(self.raw_text)

    @staticmethod
    def from_json(json_data):
        return Example(
            json_data["title"],
            json_data["entities"],
            json_data["types"],
            json_data["relations"],
            json_data["abstract"],
        )

    def build_graph(self):
        graph = dgl.DGLGraph()
        ent_len = len(self.raw_ent_text)
        rel_len = len(
            self.raw_rel
        )  # treat the repeated relation as different nodes, refer to the author's code

        graph.add_nodes(
            ent_len, {"type": torch.ones(ent_len) * NODE_TYPE["entity"]}
        )
        graph.add_nodes(1, {"type": torch.ones(1) * NODE_TYPE["root"]})
        graph.add_nodes(
            rel_len * 2,
            {"type": torch.ones(rel_len * 2) * NODE_TYPE["relation"]},
        )
        graph.add_edges(ent_len, torch.arange(ent_len))
        graph.add_edges(torch.arange(ent_len), ent_len)
        graph.add_edges(
            torch.arange(ent_len + 1 + rel_len * 2),
            torch.arange(ent_len + 1 + rel_len * 2),
        )
        adj_edges = []
        for i, r in enumerate(self.raw_rel):
            assert len(r) == 3, str(r)
            st, rt, ed = r
            st_ent, ed_ent = self.raw_ent_text.index(
                st
            ), self.raw_ent_text.index(ed)
            # according to the edge_softmax operator, we need to reverse the graph
            adj_edges.append([ent_len + 1 + 2 * i, st_ent])
            adj_edges.append([ed_ent, ent_len + 1 + 2 * i])
            adj_edges.append([ent_len + 1 + 2 * i + 1, ed_ent])
            adj_edges.append([st_ent, ent_len + 1 + 2 * i + 1])
        if len(adj_edges) > 0:
            graph.add_edges(*list(map(list, zip(*adj_edges))))
        return graph
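A hedged, self-contained sketch of the node layout build_graph() produces (it assumes a DGL version in which the deprecated dgl.DGLGraph() constructor used above still runs): entity nodes come first, then one root node connected to every entity, then a forward/inverse node pair per relation, with self-loops on all nodes.

# For 2 entities and 1 relation: ids 0-1 entities, 2 root, 3-4 relation pair.
ex = Example.from_json(
    {
        "title": "a toy title",
        "entities": ["alpha", "beta"],
        "types": "<method> <method>",
        "relations": ["alpha -- USED-FOR -- beta"],
        "abstract": "alpha is used for beta .",
    }
)
print(ex.graph.num_nodes())  # 5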
    def get_tensor(
        self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
    ):
        if hasattr(self, "_cached_tensor"):
            return self._cached_tensor
        else:
            title_data = ["<BOS>"] + self.raw_title + ["<EOS>"]
            title = [title_vocab(x) for x in title_data]
            ent_text = [
                [ent_text_vocab(y) for y in x] for x in self.raw_ent_text
            ]
            ent_type = [
                text_vocab(x) for x in self.raw_ent_type
            ]  # for inference
            rel_data = ["--root--"] + sum(
                [[x[1], x[1] + "_INV"] for x in self.raw_rel], []
            )
            rel = [rel_vocab(x) for x in rel_data]
            text_data = ["<BOS>"] + self.raw_text + ["<EOS>"]
            text = [text_vocab(x) for x in text_data]
            tgt_text = []
            # the input text and decoding target are different since the consideration of the copy mechanism.
            for i, str1 in enumerate(text_data):
                if str1[0] == "<" and str1[-1] == ">" and "_" in str1:
                    a, b = str1[1:-1].split("_")
                    text[i] = text_vocab("<" + a + ">")
                    tgt_text.append(len(text_vocab) + int(b))
                else:
                    tgt_text.append(text[i])
            self._cached_tensor = {
                "title": torch.LongTensor(title),
                "ent_text": [torch.LongTensor(x) for x in ent_text],
                "ent_type": torch.LongTensor(ent_type),
                "rel": torch.LongTensor(rel),
                "text": torch.LongTensor(text[:-1]),
                "tgt_text": torch.LongTensor(tgt_text[1:]),
                "graph": self.graph,
                "raw_ent_text": self.raw_ent_text,
            }
            return self._cached_tensor
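The copy-mechanism comment in get_tensor() above is easiest to see on one token; a worked sketch with an assumed vocabulary size:

# How "<method_2>" splits into decoder input and target (V is assumed).
V = 10000  # stand-in for len(text_vocab)
token = "<method_2>"  # entity placeholder in the raw abstract
a, b = token[1:-1].split("_")  # a = "method", b = "2"
inp_token = "<" + a + ">"  # decoder input: the generic type token
tgt_index = V + int(b)  # decoder target: "copy entity #2"
print(inp_token, tgt_index)  # <method> 10002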
    def update_vocab(
        self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
    ):
        ent_vocab.update(self.raw_ent_type)
        ent_text_vocab.update(self.raw_ent_text)
        title_vocab.update(self.raw_title)
        rel_vocab.update(
            ["--root--"]
            + [x[1] for x in self.raw_rel]
            + [x[1] + "_INV" for x in self.raw_rel]
        )
        text_vocab.update(self.raw_ent_type)
        text_vocab.update(self.raw_text)


class BucketSampler(torch.utils.data.Sampler):
    def __init__(self, data_source, batch_size=32, bucket=3):
        self.data_source = data_source
@@ -217,37 +324,64 @@ class BucketSampler(torch.utils.data.Sampler):
        t2 = []
        t3 = []
        for i, l in enumerate(lens):
            if l < 100:
                t1.append(perm[i])
            elif l > 100 and l < 220:
                t2.append(perm[i])
            else:
                t3.append(perm[i])
        datas = [t1, t2, t3]
        random.shuffle(datas)
        idxs = sum(datas, [])
        batch = []
        lens = torch.Tensor([len(x) for x in self.data_source])
        for idx in idxs:
            batch.append(idx)
            mlen = max([0] + [lens[x] for x in batch])
            if (
                (mlen < 100 and len(batch) == 32)
                or (mlen > 100 and mlen < 220 and len(batch) >= 24)
                or (mlen > 220 and len(batch) >= 8)
                or len(batch) == 32
            ):
                yield batch
                batch = []
        if len(batch) > 0:
            yield batch

    def __len__(self):
        return (len(self.data_source) + self.batch_size - 1) // self.batch_size
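
# A usage sketch: BucketSampler yields whole index lists, so it drives a
# DataLoader through `batch_sampler` rather than `sampler` (ds is an assumed
# GWdataset instance; batch_fn is defined below):
#
#     sampler = BucketSampler(ds, batch_size=32)
#     loader = torch.utils.data.DataLoader(
#         ds, batch_sampler=sampler, collate_fn=ds.batch_fn
#     )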

class GWdataset(torch.utils.data.Dataset):
    def __init__(
        self,
        exs,
        ent_vocab=None,
        rel_vocab=None,
        text_vocab=None,
        ent_text_vocab=None,
        title_vocab=None,
        device=None,
    ):
        super(GWdataset, self).__init__()
        self.exs = exs
        (
            self.ent_vocab,
            self.rel_vocab,
            self.text_vocab,
            self.ent_text_vocab,
            self.title_vocab,
            self.device,
        ) = (
            ent_vocab,
            rel_vocab,
            text_vocab,
            ent_text_vocab,
            title_vocab,
            device,
        )

    def __iter__(self):
        return iter(self.exs)
@@ -259,41 +393,71 @@ class GWdataset(torch.utils.data.Dataset):
        return len(self.exs)

    def batch_fn(self, batch_ex):
        (
            batch_title,
            batch_ent_text,
            batch_ent_type,
            batch_rel,
            batch_text,
            batch_tgt_text,
            batch_graph,
        ) = ([], [], [], [], [], [], [])
        batch_raw_ent_text = []
        for ex in batch_ex:
            ex_data = ex.get_tensor(
                self.ent_vocab,
                self.rel_vocab,
                self.text_vocab,
                self.ent_text_vocab,
                self.title_vocab,
            )
            batch_title.append(ex_data["title"])
            batch_ent_text.append(ex_data["ent_text"])
            batch_ent_type.append(ex_data["ent_type"])
            batch_rel.append(ex_data["rel"])
            batch_text.append(ex_data["text"])
            batch_tgt_text.append(ex_data["tgt_text"])
            batch_graph.append(ex_data["graph"])
            batch_raw_ent_text.append(ex_data["raw_ent_text"])
        batch_title = pad(batch_title, out_type="tensor")
        batch_ent_text, ent_len = pad(
            batch_ent_text, out_type="tensor", flatten=True
        )
        batch_ent_type = pad(batch_ent_type, out_type="tensor")
        batch_rel = pad(batch_rel, out_type="tensor")
        batch_text = pad(batch_text, out_type="tensor")
        batch_tgt_text = pad(batch_tgt_text, out_type="tensor")
        batch_graph = dgl.batch(batch_graph)
        batch_graph.to(self.device)
        return {
            "title": batch_title.to(self.device),
            "ent_text": batch_ent_text.to(self.device),
            "ent_len": ent_len,
            "ent_type": batch_ent_type.to(self.device),
            "rel": batch_rel.to(self.device),
            "text": batch_text.to(self.device),
            "tgt_text": batch_tgt_text.to(self.device),
            "graph": batch_graph,
            "raw_ent_text": batch_raw_ent_text,
        }


def get_datasets(
    fnames,
    min_freq=-1,
    sep=";",
    joint_vocab=True,
    device=None,
    save="tmp.pickle",
):
    # min_freq : not supported yet, since the final results are very sensitive to it; you can still set it by passing min_freq to the Vocab class.
    # sep : not supported yet
    # joint_vocab : not supported yet
    ent_vocab = Vocab(sp=["<PAD>", "<UNK>"])
    title_vocab = Vocab(min_freq=5)
    rel_vocab = Vocab(sp=["<PAD>", "<UNK>"])
    text_vocab = Vocab(min_freq=5)
    ent_text_vocab = Vocab(sp=["<PAD>", "<UNK>"])
    datasets = []
    for fname in fnames:
        exs = []
@@ -301,8 +465,14 @@ def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, sa
        for json_data in json_datas:
            # construct one data example
            ex = Example.from_json(json_data)
            if fname == fnames[0]:  # only training set
                ex.update_vocab(
                    ent_vocab,
                    rel_vocab,
                    text_vocab,
                    ent_text_vocab,
                    title_vocab,
                )
            exs.append(ex)
        datasets.append(exs)
    ent_vocab.build()
@@ -310,14 +480,40 @@ def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, sa
    text_vocab.build()
    ent_text_vocab.build()
    title_vocab.build()
    datasets = [
        GWdataset(
            exs,
            ent_vocab,
            rel_vocab,
            text_vocab,
            ent_text_vocab,
            title_vocab,
            device,
        )
        for exs in datasets
    ]
    with open(save, "wb") as f:
        pickle.dump(datasets, f)
    return datasets


if __name__ == "__main__":
    ds = get_datasets(
        [
            "data/unprocessed.val.json",
            "data/unprocessed.val.json",
            "data/unprocessed.test.json",
        ]
    )
    print(ds[0].exs[0])
    print(
        ds[0]
        .exs[0]
        .get_tensor(
            ds[0].ent_vocab,
            ds[0].rel_vocab,
            ds[0].text_vocab,
            ds[0].ent_text_vocab,
            ds[0].title_vocab,
        )
    )
import json
import logging
import os
import sys

import numpy as np
import torch
from dgl.data import LegacyTUDataset


def _load_check_mark(path: str):
    if os.path.exists(path):
        with open(path, "r") as f:
            return json.load(f)
    else:
        return {}


def _save_check_mark(path: str, marks: dict):
    with open(path, "w") as f:
        json.dump(marks, f)
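
# Round-trip sketch for the helpers above (the file name is only an example):
#
#     marks = _load_check_mark("info_DD_abc123.json")  # {} if file is missing
#     marks["degree_as_feat"] = True
#     _save_check_mark("info_DD_abc123.json", marks)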

def node_label_as_feature(dataset: LegacyTUDataset, mode="concat", save=True):
    """
    Description
    -----------
@@ -41,52 +44,63 @@ def node_label_as_feature(dataset: LegacyTUDataset, mode="concat", save=True):
        Default: :obj:`True`
    """
    # check if node label is not available
    if (
        not os.path.exists(dataset._file_path("node_labels"))
        or len(dataset) == 0
    ):
        logging.warning("No Node Label Data")
        return dataset

    # check if has cached value
    check_mark_name = "node_label_as_feature"
    check_mark_path = os.path.join(
        dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash)
    )
    check_mark = _load_check_mark(check_mark_path)
    if (
        check_mark_name in check_mark
        and check_mark[check_mark_name]
        and not dataset._force_reload
    ):
        logging.warning("Using cached value in node_label_as_feature")
        return dataset

    logging.warning(
        "Adding node labels into node features..., mode={}".format(mode)
    )

    # check if graph has "feat"
    if "feat" not in dataset[0][0].ndata:
        logging.warning("Dataset has no node feature 'feat'")
        if mode.lower() == "concat":
            mode = "replace"

    # first read node labels
    DS_node_labels = dataset._idx_from_zero(
        np.loadtxt(dataset._file_path("node_labels"), dtype=int)
    )
    one_hot_node_labels = dataset._to_onehot(DS_node_labels)

    # read graph idx
    DS_indicator = dataset._idx_from_zero(
        np.genfromtxt(dataset._file_path("graph_indicator"), dtype=int)
    )
    node_idx_list = []
    for idx in range(np.max(DS_indicator) + 1):
        node_idx = np.where(DS_indicator == idx)
        node_idx_list.append(node_idx[0])

    # add to node feature dict
    for idx, g in zip(node_idx_list, dataset.graph_lists):
        node_labels_tensor = torch.tensor(one_hot_node_labels[idx, :])
        if mode.lower() == "concat":
            g.ndata["feat"] = torch.cat(
                (g.ndata["feat"], node_labels_tensor), dim=1
            )
        elif mode.lower() == "add":
            g.ndata["node_label"] = node_labels_tensor
        else:  # replace
            g.ndata["feat"] = node_labels_tensor

    if save:
        check_mark[check_mark_name] = True
        _save_check_mark(check_mark_path, check_mark)
@@ -94,7 +108,7 @@ def node_label_as_feature(dataset: LegacyTUDataset, mode="concat", save=True):
    return dataset
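
# Usage sketch (the dataset name is only an example; assumes the TU data is
# available locally):
#
#     dataset = LegacyTUDataset("ENZYMES")
#     dataset = node_label_as_feature(dataset, mode="concat")
#     # "concat" appends the one-hot node labels to ndata["feat"]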

def degree_as_feature(dataset: LegacyTUDataset, save=True):
    """
    Description
    -----------
@@ -113,12 +127,15 @@ def degree_as_feature(dataset: LegacyTUDataset, save=True):
    check_mark_name = "degree_as_feat"
    feat_name = "feat"
    check_mark_path = os.path.join(
        dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash)
    )
    check_mark = _load_check_mark(check_mark_path)
    if (
        check_mark_name in check_mark
        and check_mark[check_mark_name]
        and not dataset._force_reload
    ):
        logging.warning("Using cached value in 'degree_as_feature'")
        return dataset
@@ -129,13 +146,13 @@ def degree_as_feature(dataset: LegacyTUDataset, save=True):
        degrees = dataset.graph_lists[i].in_degrees()
        min_degree = min(min_degree, degrees.min().item())
        max_degree = max(max_degree, degrees.max().item())

    vec_len = max_degree - min_degree + 1
    for i in range(len(dataset)):
        num_nodes = dataset.graph_lists[i].num_nodes()
        node_feat = torch.zeros((num_nodes, vec_len))
        degrees = dataset.graph_lists[i].in_degrees()
        node_feat[torch.arange(num_nodes), degrees - min_degree] = 1.0
        dataset.graph_lists[i].ndata[feat_name] = node_feat

    if save:
...

from typing import Optional

import torch
import torch.nn
from torch import Tensor

import dgl
from dgl import DGLGraph
from dgl.nn import GraphConv


class GraphConvWithDropout(GraphConv):
    """
    A GraphConv followed by a Dropout.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        dropout=0.3,
        norm="both",
        weight=True,
        bias=True,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super(GraphConvWithDropout, self).__init__(
            in_feats,
            out_feats,
            norm,
            weight,
            bias,
            activation,
            allow_zero_in_degree,
        )
        self.dropout = torch.nn.Dropout(p=dropout)

    def call(self, graph, feat, weight=None):
@@ -30,7 +46,7 @@ class Discriminator(torch.nn.Module):
    Description
    -----------
    A discriminator that lets the network discriminate
    between positive (neighborhood of the center node) and
    negative (any neighborhood in the graph) samples.

    Parameters
@@ -38,18 +54,24 @@ class Discriminator(torch.nn.Module):
    feat_dim : int
        The number of channels of node features.
    """

    def __init__(self, feat_dim: int):
        super(Discriminator, self).__init__()
        self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.affine.weight)
        torch.nn.init.zeros_(self.affine.bias)

    def forward(
        self,
        h_x: Tensor,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Parameters
        ----------
@@ -79,9 +101,9 @@ class Discriminator(torch.nn.Module):
            score_pos = score_pos + bias_pos
        if bias_neg is not None:
            score_neg = score_neg + bias_neg

        logits = torch.cat((score_pos, score_neg), 0)

        return logits, score_pos
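
    # Shape sketch (sizes invented; assumes the elided lines score
    # (h_x, h_pos) and (h_x, h_neg) with the bilinear layer):
    #
    #     disc = Discriminator(feat_dim=64)
    #     h_x, h_pos, h_neg = (torch.randn(10, 64) for _ in range(3))
    #     logits, score_pos = disc(h_x, h_pos, h_neg)
    #     # logits: (20, 1) -- positive scores stacked over negative scores
    #     # score_pos: (10, 1)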
@@ -91,13 +113,15 @@ class DenseLayer(torch.nn.Module):
    -----------
    Dense layer with a linear layer and an activation function
    """

    def __init__(
        self, in_dim: int, out_dim: int, act: str = "prelu", bias=True
    ):
        super(DenseLayer, self).__init__()
        self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)
        self.act_type = act.lower()
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        if self.lin.bias is not None:
@@ -121,7 +145,7 @@ class IndexSelect(torch.nn.Module):
    Parameters
    ----------
    pool_ratio : float
        The pooling ratio (for keeping nodes). For example,
        if `pool_ratio=0.8`, 80% of the nodes will be preserved.
    hidden_dim : int
        The number of channels in node features.
@@ -131,8 +155,14 @@ class IndexSelect(torch.nn.Module):
    dist : int, optional
        DO NOT USE THIS PARAMETER
    """

    def __init__(
        self,
        pool_ratio: float,
        hidden_dim: int,
        act: str = "prelu",
        dist: int = 1,
    ):
        super(IndexSelect, self).__init__()
        self.pool_ratio = pool_ratio
        self.dist = dist
@@ -140,9 +170,14 @@ class IndexSelect(torch.nn.Module):
        self.discriminator = Discriminator(hidden_dim)
        self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(
        self,
        graph: DGLGraph,
        h_pos: Tensor,
        h_neg: Tensor,
        bias_pos: Optional[Tensor] = None,
        bias_neg: Optional[Tensor] = None,
    ):
        """
        Description
        -----------
@@ -171,11 +206,11 @@ class IndexSelect(torch.nn.Module):
        embed = self.gcn(graph, h_pos)
        h_center = torch.sigmoid(embed)

        logit, logit_pos = self.discriminator(
            h_center, h_pos, h_neg, bias_pos, bias_neg
        )
        scores = torch.sigmoid(logit_pos)

        # sort scores
        scores, idx = torch.sort(scores, descending=True)
@@ -203,15 +238,23 @@ class GraphPool(torch.nn.Module):
        Whether to use a GCN in the down-sampling process.
        default: :obj:`False`
    """

    def __init__(self, hidden_dim: int, use_gcn=False):
        super(GraphPool, self).__init__()
        self.use_gcn = use_gcn
        self.down_sample_gcn = (
            GraphConvWithDropout(hidden_dim, hidden_dim) if use_gcn else None
        )

    def forward(
        self,
        graph: DGLGraph,
        feat: Tensor,
        select_idx: Tensor,
        non_select_idx: Optional[Tensor] = None,
        scores: Optional[Tensor] = None,
        pool_graph=False,
    ):
        """
        Description
        -----------
@@ -226,7 +269,7 @@ class GraphPool(torch.nn.Module):
        select_idx : torch.Tensor
            The indices (in the fine graph) of the nodes kept in the
            coarse graph, obtained from previous graph pooling layers.
        non_select_idx : torch.Tensor, optional
            The indices not included in the output graph.
            default: :obj:`None`
@@ -239,7 +282,7 @@ class GraphPool(torch.nn.Module):
        """
        if self.use_gcn:
            feat = self.down_sample_gcn(graph, feat)

        feat = feat[select_idx]
        if scores is not None:
            feat = feat * scores.unsqueeze(-1)
@@ -264,12 +307,12 @@ class GraphUnpool(torch.nn.Module):
    hidden_dim : int
        The number of channels of node features.
    """

    def __init__(self, hidden_dim: int):
        super(GraphUnpool, self).__init__()
        self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)

    def forward(self, graph: DGLGraph, feat: Tensor, select_idx: Tensor):
        """
        Description
        -----------
@@ -286,8 +329,9 @@ class GraphUnpool(torch.nn.Module):
            coarse graph, this is obtained from
            previous graph pooling layers.
        """
        fine_feat = torch.zeros(
            (graph.num_nodes(), feat.size(-1)), device=feat.device
        )
        fine_feat[select_idx] = feat
        fine_feat = self.up_sample_gcn(graph, fine_feat)
        return fine_feat
@@ -3,22 +3,28 @@ import os
from datetime import datetime
from time import time

import torch
import torch.nn.functional as F
from data_preprocess import degree_as_feature, node_label_as_feature
from networks import GraphClassifier
from torch import Tensor
from torch.utils.data import random_split
from utils import get_stats, parse_args

import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader


def compute_loss(
    cls_logits: Tensor,
    labels: Tensor,
    logits_s1: Tensor,
    logits_s2: Tensor,
    epoch: int,
    total_epochs: int,
    device: torch.device,
):
    # classification loss
    classify_loss = F.nll_loss(cls_logits, labels.to(device))
@@ -32,17 +38,23 @@ def compute_loss(cls_logits:Tensor, labels:Tensor,
    pool_loss_s1 = F.binary_cross_entropy_with_logits(logits_s1, s1_label)
    pool_loss_s2 = F.binary_cross_entropy_with_logits(logits_s2, s2_label)
    pool_loss = (pool_loss_s1 + pool_loss_s2) / 2

    loss = classify_loss + (2 - epoch / total_epochs) * pool_loss
    return loss
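
# The pooling-loss weight (2 - epoch / total_epochs) anneals linearly from 2
# to 1 over training, e.g. with total_epochs = 100:
#     epoch 0 -> weight 2.0;  epoch 50 -> weight 1.5;  epoch 100 -> weight 1.0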

def train(
    model: torch.nn.Module,
    optimizer,
    trainloader,
    device,
    curr_epoch,
    total_epochs,
):
    model.train()
    total_loss = 0.0
    num_batches = len(trainloader)
    for batch in trainloader:
@@ -50,23 +62,23 @@ def train(model:torch.nn.Module, optimizer, trainloader,
        batch_graphs, batch_labels = batch
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
        loss = compute_loss(
            out, batch_labels, l1, l2, curr_epoch, total_epochs, device
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / num_batches


@torch.no_grad()
def test(model: torch.nn.Module, loader, device):
    model.eval()
    correct = 0.0
    num_graphs = 0
    for batch in loader:
@@ -90,7 +102,7 @@ def main(args):
    for i in range(len(dataset)):
        dataset.graph_lists[i] = dgl.remove_self_loop(dataset.graph_lists[i])
        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])

    # preprocess: use node degree/label as node feature
    if args.degree_as_feature:
        dataset = degree_as_feature(dataset)
@@ -103,21 +115,30 @@ def main(args):
    num_test = len(dataset) - num_training
    train_set, test_set = random_split(dataset, [num_training, num_test])

    train_loader = GraphDataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=1
    )
    test_loader = GraphDataLoader(
        test_set, batch_size=args.batch_size, num_workers=1
    )

    device = torch.device(args.device)

    # Step 2: Create model =================================================================== #
    num_feature, num_classes, _ = dataset.statistics()
    args.in_dim = int(num_feature)
    args.out_dim = int(num_classes)
    args.edge_feat_dim = 0  # No edge feature in datasets that we use.

    model = GraphClassifier(args).to(device)

    # Step 3: Create training components ===================================================== #
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        amsgrad=True,
        weight_decay=args.weight_decay,
    )
    # Step 4: training epochs ================================================================ #
    best_test_acc = 0.0
@@ -125,8 +146,9 @@ def main(args):
    train_times = []
    for e in range(args.epochs):
        s_time = time()
        train_loss = train(
            model, optimizer, train_loader, device, e, args.epochs
        )
        train_times.append(time() - s_time)
        test_acc = test(model, test_loader, device)
        if test_acc > best_test_acc:
@@ -134,9 +156,13 @@ def main(args):
            best_epoch = e + 1

        if (e + 1) % args.print_every == 0:
            log_format = (
                "Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
            )
            print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))

    print(
        "Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc)
    )
    return best_test_acc, sum(train_times) / len(train_times)
@@ -154,11 +180,15 @@ if __name__ == "__main__":
    mean, err_bd = get_stats(res, conf_interval=False)
    print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))

    out_dict = {
        "hyper-parameters": vars(args),
        "result_date": str(datetime.now()),
        "result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
        "train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
        "details": res,
    }

    with open(
        os.path.join(args.output_path, "{}.log".format(args.dataset)), "w"
    ) as f:
        json.dump(out_dict, f, sort_keys=True, indent=4)
import json
import os
from datetime import datetime
from time import time

import torch
import torch.nn.functional as F
from data_preprocess import degree_as_feature, node_label_as_feature
from networks import GraphClassifier
from torch import Tensor
from torch.utils.data import random_split
from utils import get_stats, parse_args

import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader


def compute_loss(
    cls_logits: Tensor,
    labels: Tensor,
    logits_s1: Tensor,
    logits_s2: Tensor,
    epoch: int,
    total_epochs: int,
    device: torch.device,
):
    # classification loss
    classify_loss = F.nll_loss(cls_logits, labels.to(device))
@@ -32,17 +38,23 @@ def compute_loss(cls_logits:Tensor, labels:Tensor,
    pool_loss_s1 = F.binary_cross_entropy_with_logits(logits_s1, s1_label)
    pool_loss_s2 = F.binary_cross_entropy_with_logits(logits_s2, s2_label)
    pool_loss = (pool_loss_s1 + pool_loss_s2) / 2

    loss = classify_loss + (2 - epoch / total_epochs) * pool_loss
    return loss


def train(
    model: torch.nn.Module,
    optimizer,
    trainloader,
    device,
    curr_epoch,
    total_epochs,
):
    model.train()
    total_loss = 0.0
    num_batches = len(trainloader)
    for batch in trainloader:
@@ -50,23 +62,23 @@ def train(model:torch.nn.Module, optimizer, trainloader,
        batch_graphs, batch_labels = batch
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
        loss = compute_loss(
            out, batch_labels, l1, l2, curr_epoch, total_epochs, device
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / num_batches


@torch.no_grad()
def test(model: torch.nn.Module, loader, device):
    model.eval()
    correct = 0.0
    num_graphs = 0
    for batch in loader:
@@ -82,12 +94,11 @@ def test(model:torch.nn.Module, loader, device):


@torch.no_grad()
def validate(model: torch.nn.Module, loader, device, curr_epoch, total_epochs):
    model.eval()
    tt_loss = 0.0
    correct = 0.0
    num_graphs = 0
    num_batchs = len(loader)
@@ -97,8 +108,9 @@ def validate(model:torch.nn.Module, loader, device,
        batch_graphs = batch_graphs.to(device)
        batch_labels = batch_labels.long().to(device)
        out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
        tt_loss += compute_loss(
            out, batch_labels, l1, l2, curr_epoch, total_epochs, device
        ).item()
        pred = out.argmax(dim=1)
        correct += pred.eq(batch_labels).sum().item()
@@ -114,7 +126,7 @@ def main(args):
    for i in range(len(dataset)):
        dataset.graph_lists[i] = dgl.remove_self_loop(dataset.graph_lists[i])
        dataset.graph_lists[i] = dgl.add_self_loop(dataset.graph_lists[i])

    # use degree as node feature
    if args.degree_as_feature:
        dataset = degree_as_feature(dataset)
@@ -126,24 +138,37 @@ def main(args):
    num_training = int(len(dataset) * 0.8)
    num_val = int(len(dataset) * 0.1)
    num_test = len(dataset) - num_training - num_val
    train_set, val_set, test_set = random_split(
        dataset, [num_training, num_val, num_test]
    )

    train_loader = GraphDataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, num_workers=1
    )
    val_loader = GraphDataLoader(
        val_set, batch_size=args.batch_size, num_workers=1
    )
    test_loader = GraphDataLoader(
        test_set, batch_size=args.batch_size, num_workers=1
    )

    device = torch.device(args.device)

    # Step 2: Create model =================================================================== #
    num_feature, num_classes, _ = dataset.statistics()
    args.in_dim = int(num_feature)
    args.out_dim = int(num_classes)
    args.edge_feat_dim = 0  # No edge feature in datasets that we use.

    model = GraphClassifier(args).to(device)

    # Step 3: Create training components ===================================================== #
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr,
        amsgrad=True,
        weight_decay=args.weight_decay,
    )
    # Step 4: training epochs ================================================================ #
    best_test_acc = 0.0
@@ -154,8 +179,9 @@ def main(args):
    best_val_loss = float("inf")
    for e in range(args.epochs):
        s_time = time()
        train_loss = train(
            model, optimizer, train_loader, device, e, args.epochs
        )
        train_times.append(time() - s_time)
        _, val_loss = validate(model, val_loader, device, e, args.epochs)
        test_acc = test(model, test_loader, device)
@@ -167,14 +193,18 @@ def main(args):
            best_test_acc = test_acc
        else:
            bad_count += 1

        if bad_count > args.patience:
            break

        if (e + 1) % args.print_every == 0:
            log_format = (
                "Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
            )
            print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))

    print(
        "Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc)
    )
    return best_test_acc, sum(train_times) / len(train_times)
@@ -192,11 +222,15 @@ if __name__ == "__main__":
    mean, err_bd = get_stats(res, conf_interval=False)
    print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))

    out_dict = {
        "hyper-parameters": vars(args),
        "result_date": str(datetime.now()),
        "result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
        "train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
        "details": res,
    }

    with open(
        os.path.join(args.output_path, "{}.log".format(args.dataset)), "w"
    ) as f:
        json.dump(out_dict, f, sort_keys=True, indent=4)
@@ -10,17 +10,19 @@ import torch.cuda
from scipy.stats import t

def get_stats(
    array, conf_interval=False, name=None, stdout=False, logout=False
):
    """Compute the mean and standard deviation of a numerical array.

    Args:
        array (array like obj): The numerical array; it must be
            convertible to :obj:`torch.Tensor`.
        conf_interval (bool, optional): If True, compute the confidence interval bound (95%)
            instead of the std value. (default: :obj:`False`)
        name (str, optional): The name of this numerical array, for log usage.
            (default: :obj:`None`)
        stdout (bool, optional): Whether to output the result to the terminal.
            (default: :obj:`False`)
        logout (bool, optional): Whether to output the result via the logging module.
            (default: :obj:`False`)
@@ -35,7 +37,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
    if conf_interval:
        n = array.size(0)
        se = std / (math.sqrt(n) + eps)
        t_value = t.ppf(0.975, df=n - 1)
        err_bound = t_value * se
    else:
        err_bound = std
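
    # This is the usual 95% confidence half-width: t_{0.975, n-1} * s / sqrt(n).
    # For example, with n = 10 trials, t.ppf(0.975, df=9) is about 2.262, so
    # the reported bound is roughly 2.262 * std / sqrt(10).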
@@ -54,71 +56,138 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)

def parse_args():
    parser = argparse.ArgumentParser("Graph Cross Network")
    parser.add_argument(
        "--pool_ratios",
        nargs="+",
        type=float,
        help="The pooling ratios used in graph cross layers",
    )
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=96,
        help="The number of hidden channels in GXN",
    )
    parser.add_argument(
        "--cross_weight",
        type=float,
        default=1.0,
        help="Weight parameter used in graph cross layer",
    )
    parser.add_argument(
        "--fuse_weight",
        type=float,
        default=1.0,
        help="Weight parameter for feature fusion",
    )
    parser.add_argument(
        "--num_cross_layers",
        type=int,
        default=2,
        help="The number of graph cross layers",
    )
    parser.add_argument(
        "--readout_nodes",
        type=int,
        default=30,
        help="Number of nodes for each graph after final graph pooling",
    )
    parser.add_argument(
        "--conv1d_dims",
        nargs="+",
        type=int,
        help="Number of channels in conv operations at the end of the graph cross net",
    )
    parser.add_argument(
        "--conv1d_kws",
        nargs="+",
        type=int,
        help="Kernel sizes of conv1d operations",
    )
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="Dropout rate"
    )
    parser.add_argument(
        "--embed_dim",
        type=int,
        default=1024,
        help="Number of channels of graph embedding",
    )
    parser.add_argument(
        "--final_dense_hidden_dim",
        type=int,
        default=128,
        help="The number of hidden channels in final dense layers",
    )

    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument(
        "--weight_decay", type=float, default=0.0, help="Weight decay rate"
    )
    parser.add_argument(
        "--epochs", type=int, default=1000, help="Number of training epochs"
    )
    parser.add_argument(
        "--patience", type=int, default=20, help="Patience for early stopping"
    )
    parser.add_argument(
        "--num_trials", type=int, default=1, help="Number of trials"
    )

    parser.add_argument(
        "--device",
        type=int,
        default=0,
        help="Computation device id, -1 for cpu",
    )
    parser.add_argument(
        "--dataset", type=str, default="DD", help="Dataset used for training"
    )
    parser.add_argument(
        "--seed", type=int, default=-1, help="Random seed, -1 for unset"
    )
    parser.add_argument(
        "--print_every",
        type=int,
        default=10,
        help="Print train log every N epochs, -1 for silent training",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="./datasets",
        help="Path holding your dataset",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output",
        help="Path holding your result files",
    )

    args = parser.parse_args()
    # default values for list hyper-parameters
    if not args.pool_ratios or len(args.pool_ratios) < 2:
        args.pool_ratios = [0.8, 0.7]
        logging.warning(
            "No valid pool_ratios given, "
            "using default value '{}'".format(args.pool_ratios)
        )
    if not args.conv1d_dims or len(args.conv1d_dims) < 2:
        args.conv1d_dims = [16, 32]
        logging.warning(
            "No valid conv1d_dims given, "
            "using default value {}".format(args.conv1d_dims)
        )
    if not args.conv1d_kws or len(args.conv1d_kws) < 1:
        args.conv1d_kws = [5]
        logging.warning(
            "No valid conv1d_kws given, "
            "using default value '{}'".format(args.conv1d_kws)
        )
    # device
    args.device = "cpu" if args.device < 0 else "cuda:{}".format(args.device)
    if not torch.cuda.is_available():
@@ -136,11 +205,11 @@ def parse_args():
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # print every
    if args.print_every < 0:
        args.print_every = args.epochs + 1

    # path
    paths = [args.output_path, args.dataset_path]
    for p in paths:
@@ -148,7 +217,7 @@ def parse_args():
            os.makedirs(p)

    # datasets ad-hoc
    if args.dataset in ["COLLAB", "IMDB-BINARY", "IMDB-MULTI", "ENZYMES"]:
        args.degree_as_feature = True
    else:
        args.degree_as_feature = False
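
    # Invocation sketch using the flags above (entry-point file name assumed):
    #
    #     python main.py --dataset DD --device 0 --pool_ratios 0.8 0.7 \
    #         --conv1d_dims 16 32 --conv1d_kws 5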
...

import torch
from sklearn.metrics import f1_score
from utils import EarlyStopping, load_data


def score(logits, labels):
    _, indices = torch.max(logits, dim=1)
@@ -9,11 +9,12 @@ def score(logits, labels):
    labels = labels.cpu().numpy()
    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average="micro")
    macro_f1 = f1_score(labels, prediction, average="macro")

    return accuracy, micro_f1, macro_f1


def evaluate(model, g, features, labels, mask, loss_func):
    model.eval()
    with torch.no_grad():
@@ -23,48 +24,66 @@ def evaluate(model, g, features, labels, mask, loss_func):
    return loss, accuracy, micro_f1, macro_f1

def main(args):
    # If args["hetero"] is True, g is a heterogeneous graph;
    # otherwise, it is a list of homogeneous graphs.
    (
        g,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
        test_idx,
        train_mask,
        val_mask,
        test_mask,
    ) = load_data(args["dataset"])
if hasattr(torch, "BoolTensor"):
train_mask = train_mask.bool() train_mask = train_mask.bool()
val_mask = val_mask.bool() val_mask = val_mask.bool()
test_mask = test_mask.bool() test_mask = test_mask.bool()
features = features.to(args['device']) features = features.to(args["device"])
labels = labels.to(args['device']) labels = labels.to(args["device"])
train_mask = train_mask.to(args['device']) train_mask = train_mask.to(args["device"])
val_mask = val_mask.to(args['device']) val_mask = val_mask.to(args["device"])
test_mask = test_mask.to(args['device']) test_mask = test_mask.to(args["device"])
if args['hetero']: if args["hetero"]:
from model_hetero import HAN from model_hetero import HAN
model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']],
in_size=features.shape[1], model = HAN(
hidden_size=args['hidden_units'], meta_paths=[["pa", "ap"], ["pf", "fp"]],
out_size=num_classes, in_size=features.shape[1],
num_heads=args['num_heads'], hidden_size=args["hidden_units"],
dropout=args['dropout']).to(args['device']) out_size=num_classes,
g = g.to(args['device']) num_heads=args["num_heads"],
dropout=args["dropout"],
).to(args["device"])
g = g.to(args["device"])
else: else:
from model import HAN from model import HAN
model = HAN(num_meta_paths=len(g),
in_size=features.shape[1], model = HAN(
hidden_size=args['hidden_units'], num_meta_paths=len(g),
out_size=num_classes, in_size=features.shape[1],
num_heads=args['num_heads'], hidden_size=args["hidden_units"],
dropout=args['dropout']).to(args['device']) out_size=num_classes,
g = [graph.to(args['device']) for graph in g] num_heads=args["num_heads"],
dropout=args["dropout"],
stopper = EarlyStopping(patience=args['patience']) ).to(args["device"])
g = [graph.to(args["device"]) for graph in g]
stopper = EarlyStopping(patience=args["patience"])
loss_fcn = torch.nn.CrossEntropyLoss() loss_fcn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], optimizer = torch.optim.Adam(
weight_decay=args['weight_decay']) model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)
for epoch in range(args['num_epochs']): for epoch in range(args["num_epochs"]):
model.train() model.train()
logits = model(g, features) logits = model(g, features)
loss = loss_fcn(logits[train_mask], labels[train_mask]) loss = loss_fcn(logits[train_mask], labels[train_mask])
...@@ -73,34 +92,60 @@ def main(args): ...@@ -73,34 +92,60 @@ def main(args):
loss.backward() loss.backward()
optimizer.step() optimizer.step()
train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask]) train_acc, train_micro_f1, train_macro_f1 = score(
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn) logits[train_mask], labels[train_mask]
)
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
model, g, features, labels, val_mask, loss_fcn
)
early_stop = stopper.step(val_loss.data.item(), val_acc, model) early_stop = stopper.step(val_loss.data.item(), val_acc, model)
print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | ' print(
'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format( "Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | "
epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.item(), val_micro_f1, val_macro_f1)) "Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}".format(
epoch + 1,
loss.item(),
train_micro_f1,
train_macro_f1,
val_loss.item(),
val_micro_f1,
val_macro_f1,
)
)
if early_stop: if early_stop:
break break
stopper.load_checkpoint(model) stopper.load_checkpoint(model)
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn) test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(
print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format( model, g, features, labels, test_mask, loss_fcn
test_loss.item(), test_micro_f1, test_macro_f1)) )
print(
"Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
test_loss.item(), test_micro_f1, test_macro_f1
)
)
if __name__ == '__main__': if __name__ == "__main__":
import argparse import argparse
from utils import setup from utils import setup
parser = argparse.ArgumentParser('HAN') parser = argparse.ArgumentParser("HAN")
parser.add_argument('-s', '--seed', type=int, default=1, parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
help='Random seed') parser.add_argument(
parser.add_argument('-ld', '--log-dir', type=str, default='results', "-ld",
help='Dir for saving training results') "--log-dir",
parser.add_argument('--hetero', action='store_true', type=str,
help='Use metapath coalescing with DGL\'s own dataset') default="results",
help="Dir for saving training results",
)
parser.add_argument(
"--hetero",
action="store_true",
help="Use metapath coalescing with DGL's own dataset",
)
args = parser.parse_args().__dict__ args = parser.parse_args().__dict__
args = setup(args) args = setup(args)
......
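For reference, the pieces above compose into a short driver. A minimal sketch using the example's own helpers (`setup`, `load_data`, and `EarlyStopping` appear later in this diff; the literal args dict passed to `setup` here is hypothetical):

import torch
from model import HAN
from utils import EarlyStopping, load_data, setup

# setup() fills in the paper's default hyperparameters and picks a device.
args = setup({"seed": 1, "log_dir": "results", "hetero": False})
g, features, labels, n_classes, *_, train_mask, val_mask, test_mask = load_data(
    args["dataset"]
)
model = HAN(
    num_meta_paths=len(g),
    in_size=features.shape[1],
    hidden_size=args["hidden_units"],
    out_size=n_classes,
    num_heads=args["num_heads"],
    dropout=args["dropout"],
)
# One full-graph training step, computing the loss on training nodes only.
optimizer = torch.optim.Adam(model.parameters(), lr=args["lr"])
loss = torch.nn.CrossEntropyLoss()(
    model(g, features)[train_mask], labels[train_mask]
)
optimizer.zero_grad()
loss.backward()
optimizer.step()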
...@@ -4,6 +4,7 @@ import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

...@@ -11,15 +12,16 @@ class SemanticAttention(nn.Module):
        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)

        return (beta * z).sum(1)  # (N, D * K)


class HANLayer(nn.Module):
    """
...@@ -45,15 +47,28 @@ class HANLayer(nn.Module):
    tensor
        The output feature
    """

    def __init__(
        self, num_meta_paths, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_meta_paths):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.num_meta_paths = num_meta_paths

    def forward(self, gs, h):
...@@ -61,19 +76,35 @@ class HANLayer(nn.Module):
        for i, g in enumerate(gs):
            semantic_embeddings.append(self.gat_layers[i](g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


class HAN(nn.Module):
    def __init__(
        self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(
                num_meta_paths, in_size, hidden_size, num_heads[0], dropout
            )
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    num_meta_paths,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
...
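The `SemanticAttention` module above fuses the M metapath-specific embeddings into a single vector per node. A quick shape check with toy tensors (a sanity sketch, not part of the example itself):

import torch
from model import SemanticAttention  # the module defined in the diff above

N, M, DK = 5, 2, 64  # nodes, metapaths, per-metapath embedding size (D * K)
z = torch.randn(N, M, DK)
attn = SemanticAttention(in_size=DK)
out = attn(z)
assert out.shape == (N, DK)  # one fused embedding per node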
...@@ -14,6 +14,7 @@ import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv


class SemanticAttention(nn.Module):
    def __init__(self, in_size, hidden_size=128):
        super(SemanticAttention, self).__init__()

...@@ -21,15 +22,16 @@ class SemanticAttention(nn.Module):
        self.project = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        )

    def forward(self, z):
        w = self.project(z).mean(0)  # (M, 1)
        beta = torch.softmax(w, dim=0)  # (M, 1)
        beta = beta.expand((z.shape[0],) + beta.shape)  # (N, M, 1)

        return (beta * z).sum(1)  # (N, D * K)


class HANLayer(nn.Module):
    """
...@@ -55,16 +57,27 @@ class HANLayer(nn.Module):
    tensor
        The output feature
    """

    def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(len(meta_paths)):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)

        self._cached_graph = None
...@@ -77,25 +90,40 @@ class HANLayer(nn.Module):
            self._cached_graph = g
            self._cached_coalesced_graph.clear()
            for meta_path in self.meta_paths:
                self._cached_coalesced_graph[
                    meta_path
                ] = dgl.metapath_reachable_graph(g, meta_path)

        for i, meta_path in enumerate(self.meta_paths):
            new_g = self._cached_coalesced_graph[meta_path]
            semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


class HAN(nn.Module):
    def __init__(
        self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    meta_paths,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
...
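The per-metapath cache above relies on `dgl.metapath_reachable_graph`, which collapses a metapath into a single homogeneous graph. A toy illustration (a hypothetical two-paper graph, reusing the etype names from the ACM example):

import torch
import dgl

# Two papers sharing one author: the ("pa", "ap") metapath connects
# every pair of papers with a common author (including self-pairs).
hg = dgl.heterograph(
    {
        ("paper", "pa", "author"): (torch.tensor([0, 1]), torch.tensor([0, 0])),
        ("author", "ap", "paper"): (torch.tensor([0, 0]), torch.tensor([0, 1])),
    }
)
pp = dgl.metapath_reachable_graph(hg, ["pa", "ap"])
print(pp.num_edges())  # 4: both papers reach both papers via author 0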
...@@ -4,21 +4,21 @@ HAN mini-batch training by RandomWalkSampler.
note: This demo uses RandomWalkSampler to sample neighbors. It's hard to get all neighbors during validation/test,
so we sample twice as many neighbors during val/test as during training.
"""
import argparse

import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from model_hetero import SemanticAttention
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from utils import EarlyStopping, set_random_seed

import dgl
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler


class HANLayer(torch.nn.Module):
    """
...@@ -45,37 +45,64 @@ class HANLayer(torch.nn.Module):
        The output feature
    """

    def __init__(
        self, num_metapath, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList()
        for i in range(num_metapath):
            self.gat_layers.append(
                GATConv(
                    in_size,
                    out_size,
                    layer_num_heads,
                    dropout,
                    dropout,
                    activation=F.elu,
                    allow_zero_in_degree=True,
                )
            )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )
        self.num_metapath = num_metapath

    def forward(self, block_list, h_list):
        semantic_embeddings = []

        for i, block in enumerate(block_list):
            semantic_embeddings.append(
                self.gat_layers[i](block, h_list[i]).flatten(1)
            )
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


class HAN(nn.Module):
    def __init__(
        self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout
    ):
        super(HAN, self).__init__()

        self.layers = nn.ModuleList()
        self.layers.append(
            HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout)
        )
        for l in range(1, len(num_heads)):
            self.layers.append(
                HANLayer(
                    num_metapath,
                    hidden_size * num_heads[l - 1],
                    hidden_size,
                    num_heads[l],
                    dropout,
                )
            )
        self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)

    def forward(self, g, h):
...@@ -91,12 +118,16 @@ class HANSampler(object):
        for metapath in metapath_list:
            # note: a random walk may repeat a route (i.e. the same edge); duplicates are removed in the sampled
            # graph, so the sampled graph may have fewer edges than num_random_walks (num_neighbors).
            self.sampler_list.append(
                RandomWalkNeighborSampler(
                    G=g,
                    num_traversals=1,
                    termination_prob=0,
                    num_random_walks=num_neighbors,
                    num_neighbors=num_neighbors,
                    metapath=metapath,
                )
            )

    def sample_blocks(self, seeds):
        block_list = []
...@@ -117,34 +148,49 @@ def score(logits, labels):
    labels = labels.cpu().numpy()

    accuracy = (prediction == labels).sum() / len(prediction)
    micro_f1 = f1_score(labels, prediction, average="micro")
    macro_f1 = f1_score(labels, prediction, average="macro")

    return accuracy, micro_f1, macro_f1


def evaluate(
    model,
    g,
    metapath_list,
    num_neighbors,
    features,
    labels,
    val_nid,
    loss_fcn,
    batch_size,
):
    model.eval()

    han_valid_sampler = HANSampler(
        g, metapath_list, num_neighbors=num_neighbors * 2
    )
    dataloader = DataLoader(
        dataset=val_nid,
        batch_size=batch_size,
        collate_fn=han_valid_sampler.sample_blocks,
        shuffle=False,
        drop_last=False,
        num_workers=4,
    )
    correct = total = 0
    prediction_list = []
    labels_list = []
    with torch.no_grad():
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args["device"]) for block in blocks]
            hs = [h.to(args["device"]) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fcn(
                logits, labels[numpy.asarray(seeds)].to(args["device"])
            )
            # get each predicted label
            _, indices = torch.max(logits, dim=1)
            prediction = indices.long().cpu().numpy()
...@@ -158,8 +204,8 @@ def evaluate(model, g, metapath_list, num_neighbors, features, labels, val_nid,
    total_prediction = numpy.concatenate(prediction_list)
    total_labels = numpy.concatenate(labels_list)
    micro_f1 = f1_score(total_labels, total_prediction, average="micro")
    macro_f1 = f1_score(total_labels, total_prediction, average="macro")

    accuracy = correct / total

    return loss, accuracy, micro_f1, macro_f1
...@@ -175,95 +221,142 @@ def load_subtensors(blocks, features):


def main(args):
    # acm data
    if args["dataset"] == "ACMRaw":
        from utils import load_data

        (
            g,
            features,
            labels,
            n_classes,
            train_nid,
            val_nid,
            test_nid,
            train_mask,
            val_mask,
            test_mask,
        ) = load_data("ACMRaw")
        metapath_list = [["pa", "ap"], ["pf", "fp"]]
    else:
        raise NotImplementedError(
            "Unsupported dataset {}".format(args["dataset"])
        )

    # Do we need to set different numbers of neighbors for different meta-path based graphs?
    num_neighbors = args["num_neighbors"]
    han_sampler = HANSampler(g, metapath_list, num_neighbors)
    # Create PyTorch DataLoader for constructing blocks
    dataloader = DataLoader(
        dataset=train_nid,
        batch_size=args["batch_size"],
        collate_fn=han_sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=4,
    )

    model = HAN(
        num_metapath=len(metapath_list),
        in_size=features.shape[1],
        hidden_size=args["hidden_units"],
        out_size=n_classes,
        num_heads=args["num_heads"],
        dropout=args["dropout"],
    ).to(args["device"])

    total_params = sum(p.numel() for p in model.parameters())
    print("total_params: {:d}".format(total_params))
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print("total trainable params: {:d}".format(total_trainable_params))

    stopper = EarlyStopping(patience=args["patience"])
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
    )

    for epoch in range(args["num_epochs"]):
        model.train()
        for step, (seeds, blocks) in enumerate(dataloader):
            h_list = load_subtensors(blocks, features)
            blocks = [block.to(args["device"]) for block in blocks]
            hs = [h.to(args["device"]) for h in h_list]

            logits = model(blocks, hs)
            loss = loss_fn(
                logits, labels[numpy.asarray(seeds)].to(args["device"])
            )
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # print info in each batch
            train_acc, train_micro_f1, train_macro_f1 = score(
                logits, labels[numpy.asarray(seeds)]
            )
            print(
                "Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}".format(
                    epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1
                )
            )
        val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
            model,
            g,
            metapath_list,
            num_neighbors,
            features,
            labels,
            val_nid,
            loss_fn,
            args["batch_size"],
        )
        early_stop = stopper.step(val_loss.data.item(), val_acc, model)
        print(
            "Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}".format(
                epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1
            )
        )
        if early_stop:
            break

    stopper.load_checkpoint(model)
    test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(
        model,
        g,
        metapath_list,
        num_neighbors,
        features,
        labels,
        test_nid,
        loss_fn,
        args["batch_size"],
    )
    print(
        "Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
            test_loss.item(), test_acc, test_micro_f1, test_macro_f1
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser("mini-batch HAN")
    parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_neighbors", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--num_heads", type=list, default=[8])
    parser.add_argument("--hidden_units", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.6)
    parser.add_argument("--weight_decay", type=float, default=0.001)
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--dataset", type=str, default="ACMRaw")
    parser.add_argument("--device", type=str, default="cuda:0")
    args = parser.parse_args().__dict__
    # set_random_seed(args['seed'])
...
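For intuition, this is roughly what `HANSampler` builds per metapath. A toy sketch (hypothetical three-paper graph; the sampler is PinSAGE-style, merging duplicate walk destinations into weighted edges):

import torch
import dgl
from dgl.sampling import RandomWalkNeighborSampler

hg = dgl.heterograph(
    {
        ("paper", "pa", "author"): (torch.tensor([0, 1, 2]), torch.tensor([0, 0, 1])),
        ("author", "ap", "paper"): (torch.tensor([0, 0, 1]), torch.tensor([0, 1, 2])),
    }
)
sampler = RandomWalkNeighborSampler(
    G=hg,
    num_traversals=1,
    termination_prob=0,
    num_random_walks=4,
    num_neighbors=2,
    metapath=["pa", "ap"],
)
frontier = sampler(torch.LongTensor([0, 1]))  # paper-to-paper neighbor graph
print(frontier.num_edges())  # at most num_neighbors in-edges per seed node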
import datetime
import errno
import os
import pickle
import random
from pprint import pprint

import numpy as np
import torch
from scipy import io as sio
from scipy import sparse

import dgl
from dgl.data.utils import _get_dgl_url, download, get_download_dir


def set_random_seed(seed=0):
    """Set random seed.
...@@ -25,6 +27,7 @@ def set_random_seed(seed=0):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)


def mkdir_p(path, log=True):
    """Create a directory for the specified path.

    Parameters
...@@ -37,13 +40,14 @@ def mkdir_p(path, log=True):
    try:
        os.makedirs(path)
        if log:
            print("Created directory {}".format(path))
    except OSError as exc:
        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
            print("Directory {} already exists.".format(path))
        else:
            raise


def get_date_postfix():
    """Get a date based postfix for directory name.

    Returns
...@@ -51,11 +55,13 @@ def get_date_postfix():
    post_fix : str
    """
    dt = datetime.datetime.now()
    post_fix = "{}_{:02d}-{:02d}-{:02d}".format(
        dt.date(), dt.hour, dt.minute, dt.second
    )

    return post_fix


def setup_log_dir(args, sampling=False):
    """Name and create directory for logging.

    Parameters
...@@ -71,106 +77,124 @@ def setup_log_dir(args, sampling=False):
    """
    date_postfix = get_date_postfix()
    log_dir = os.path.join(
        args["log_dir"], "{}_{}".format(args["dataset"], date_postfix)
    )

    if sampling:
        log_dir = log_dir + "_sampling"

    mkdir_p(log_dir)
    return log_dir


# The configuration below is from the paper.
default_configure = {
    "lr": 0.005,  # Learning rate
    "num_heads": [8],  # Number of attention heads for node-level attention
    "hidden_units": 8,
    "dropout": 0.6,
    "weight_decay": 0.001,
    "num_epochs": 200,
    "patience": 100,
}

sampling_configure = {"batch_size": 20}


def setup(args):
    args.update(default_configure)
    set_random_seed(args["seed"])
    args["dataset"] = "ACMRaw" if args["hetero"] else "ACM"
    args["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
    args["log_dir"] = setup_log_dir(args)
    return args


def setup_for_sampling(args):
    args.update(default_configure)
    args.update(sampling_configure)
    set_random_seed()
    args["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
    args["log_dir"] = setup_log_dir(args, sampling=True)
    return args


def get_binary_mask(total_size, indices):
    mask = torch.zeros(total_size)
    mask[indices] = 1
    return mask.byte()


def load_acm(remove_self_loop):
    url = "dataset/ACM3025.pkl"
    data_path = get_download_dir() + "/ACM3025.pkl"
    download(_get_dgl_url(url), path=data_path)

    with open(data_path, "rb") as f:
        data = pickle.load(f)

    labels, features = (
        torch.from_numpy(data["label"].todense()).long(),
        torch.from_numpy(data["feature"].todense()).float(),
    )
    num_classes = labels.shape[1]
    labels = labels.nonzero()[:, 1]

    if remove_self_loop:
        num_nodes = data["label"].shape[0]
        data["PAP"] = sparse.csr_matrix(data["PAP"] - np.eye(num_nodes))
        data["PLP"] = sparse.csr_matrix(data["PLP"] - np.eye(num_nodes))

    # Adjacency matrices for meta path based neighbors
    # (Mufei): I verified both of them are binary adjacency matrices with self loops
    author_g = dgl.from_scipy(data["PAP"])
    subject_g = dgl.from_scipy(data["PLP"])
    gs = [author_g, subject_g]

    train_idx = torch.from_numpy(data["train_idx"]).long().squeeze(0)
    val_idx = torch.from_numpy(data["val_idx"]).long().squeeze(0)
    test_idx = torch.from_numpy(data["test_idx"]).long().squeeze(0)

    num_nodes = author_g.number_of_nodes()
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    print("dataset loaded")
    pprint(
        {
            "dataset": "ACM",
            "train": train_mask.sum().item() / num_nodes,
            "val": val_mask.sum().item() / num_nodes,
            "test": test_mask.sum().item() / num_nodes,
        }
    )

    return (
        gs,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
        test_idx,
        train_mask,
        val_mask,
        test_mask,
    )


def load_acm_raw(remove_self_loop):
    assert not remove_self_loop
    url = "dataset/ACM.mat"
    data_path = get_download_dir() + "/ACM.mat"
    download(_get_dgl_url(url), path=data_path)

    data = sio.loadmat(data_path)
    p_vs_l = data["PvsL"]  # paper-field?
    p_vs_a = data["PvsA"]  # paper-author
    p_vs_t = data["PvsT"]  # paper-term, bag of words
    p_vs_c = data["PvsC"]  # paper-conference, labels come from that

    # We assign
    # (1) KDD papers as class 0 (data mining),
...@@ -186,12 +210,14 @@ def load_acm_raw(remove_self_loop):
    p_vs_t = p_vs_t[p_selected]
    p_vs_c = p_vs_c[p_selected]

    hg = dgl.heterograph(
        {
            ("paper", "pa", "author"): p_vs_a.nonzero(),
            ("author", "ap", "paper"): p_vs_a.transpose().nonzero(),
            ("paper", "pf", "field"): p_vs_l.nonzero(),
            ("field", "fp", "paper"): p_vs_l.transpose().nonzero(),
        }
    )

    features = torch.FloatTensor(p_vs_t.toarray())
...@@ -205,33 +231,48 @@ def load_acm_raw(remove_self_loop):
    float_mask = np.zeros(len(pc_p))
    for conf_id in conf_ids:
        pc_c_mask = pc_c == conf_id
        float_mask[pc_c_mask] = np.random.permutation(
            np.linspace(0, 1, pc_c_mask.sum())
        )
    train_idx = np.where(float_mask <= 0.2)[0]
    val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
    test_idx = np.where(float_mask > 0.3)[0]

    num_nodes = hg.number_of_nodes("paper")
    train_mask = get_binary_mask(num_nodes, train_idx)
    val_mask = get_binary_mask(num_nodes, val_idx)
    test_mask = get_binary_mask(num_nodes, test_idx)

    return (
        hg,
        features,
        labels,
        num_classes,
        train_idx,
        val_idx,
        test_idx,
        train_mask,
        val_mask,
        test_mask,
    )


def load_data(dataset, remove_self_loop=False):
    if dataset == "ACM":
        return load_acm(remove_self_loop)
    elif dataset == "ACMRaw":
        return load_acm_raw(remove_self_loop)
    else:
        raise NotImplementedError("Unsupported dataset {}".format(dataset))


class EarlyStopping(object):
    def __init__(self, patience=10):
        dt = datetime.datetime.now()
        self.filename = "early_stop_{}_{:02d}-{:02d}-{:02d}.pth".format(
            dt.date(), dt.hour, dt.minute, dt.second
        )
        self.patience = patience
        self.counter = 0
        self.best_acc = None
...@@ -245,7 +286,9 @@ class EarlyStopping(object):
            self.save_checkpoint(model)
        elif (loss > self.best_loss) and (acc < self.best_acc):
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
...
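Driving the `EarlyStopping` helper above looks like the following; a schematic sketch with made-up validation numbers (the real loop feeds it the results of `evaluate()`):

import torch.nn as nn
from utils import EarlyStopping

model = nn.Linear(4, 2)  # stand-in for the HAN model
stopper = EarlyStopping(patience=2)
for val_loss, val_acc in [(0.9, 0.5), (0.8, 0.6), (0.85, 0.55), (0.9, 0.5)]:
    # step() checkpoints the model when loss and accuracy both improve,
    # and returns True once patience is exhausted.
    if stopper.step(val_loss, val_acc, model):
        break
stopper.load_checkpoint(model)  # restore the best weights seen so far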
...@@ -5,28 +5,33 @@ References
Paper: https://arxiv.org/abs/1907.04652
"""
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.nn.pytorch import edge_softmax
from dgl.nn.pytorch.utils import Identity
from dgl.sampling import select_topk


class HardGAO(nn.Module):
    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads=8,
        feat_drop=0.0,
        attn_drop=0.0,
        negative_slope=0.2,
        residual=True,
        activation=F.elu,
        k=8,
    ):
        super(HardGAO, self).__init__()
        self.num_heads = num_heads
        self.in_feats = in_feats
...@@ -35,11 +40,16 @@ class HardGAO(nn.Module):
        self.residual = residual
        # Initialize Parameters for Additive Attention
        self.fc = nn.Linear(
            self.in_feats, self.out_feats * self.num_heads, bias=False
        )
        self.attn_l = nn.Parameter(
            torch.FloatTensor(size=(1, self.num_heads, self.out_feats))
        )
        self.attn_r = nn.Parameter(
            torch.FloatTensor(size=(1, self.num_heads, self.out_feats))
        )
        # Initialize Parameters for Hard Projection
        self.p = nn.Parameter(torch.FloatTensor(size=(1, in_feats)))
        # Initialize Dropouts
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
...@@ -48,105 +58,142 @@ class HardGAO(nn.Module):
        if self.in_feats == self.out_feats:
            self.residual_module = Identity()
        else:
            self.residual_module = nn.Linear(
                self.in_feats, self.out_feats * num_heads, bias=False
            )
        self.reset_parameters()
        self.activation = activation

    def reset_parameters(self):
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.p, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        if self.residual:
            nn.init.xavier_normal_(self.residual_module.weight, gain=gain)

    def forward(self, graph, feat, get_attention=False):
        # Check in degree and raise an error if any node has none
        if (graph.in_degrees() == 0).any():
            raise DGLError(
                "There are 0-in-degree nodes in the graph, "
                "output for those nodes will be invalid. "
                "This is harmful for some applications, "
                "causing silent performance regression. "
                "Adding self-loop on the input graph by "
                "calling `g = dgl.add_self_loop(g)` will resolve "
                "the issue. Setting ``allow_zero_in_degree`` "
                "to be `True` when constructing this module will "
                "suppress the check and let the code run."
            )
        # projection process to get importance vector y
        graph.ndata["y"] = torch.abs(
            torch.matmul(self.p, feat.T).view(-1)
        ) / torch.norm(self.p, p=2)
        # Use edge message passing function to get the weight from src node
        graph.apply_edges(fn.copy_u("y", "y"))
        # Select top-k neighbors
        subgraph = select_topk(graph.cpu(), self.k, "y").to(graph.device)
        # Sigmoid as information threshold
        subgraph.ndata["y"] = torch.sigmoid(subgraph.ndata["y"])
        # Using vector-matrix elementwise mul for acceleration
        feat = subgraph.ndata["y"].view(-1, 1) * feat
        feat = self.feat_drop(feat)
        h = self.fc(feat).view(-1, self.num_heads, self.out_feats)
        el = (h * self.attn_l).sum(dim=-1).unsqueeze(-1)
        er = (h * self.attn_r).sum(dim=-1).unsqueeze(-1)
        # Assign the value on the subgraph
        subgraph.srcdata.update({"ft": h, "el": el})
        subgraph.dstdata.update({"er": er})
        # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
        subgraph.apply_edges(fn.u_add_v("el", "er", "e"))
        e = self.leaky_relu(subgraph.edata.pop("e"))
        # compute softmax
        subgraph.edata["a"] = self.attn_drop(edge_softmax(subgraph, e))
        # message passing
        subgraph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
        rst = subgraph.dstdata["ft"]
        # activation
        if self.activation:
            rst = self.activation(rst)
        # Residual
        if self.residual:
            rst = rst + self.residual_module(feat).view(
                feat.shape[0], -1, self.out_feats
            )

        if get_attention:
            return rst, subgraph.edata["a"]
        else:
            return rst


class HardGAT(nn.Module):
    def __init__(
        self,
        g,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
        k,
    ):
        super(HardGAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        gat_layer = partial(HardGAO, k=k)
        muls = heads
        # input projection (no residual)
        self.gat_layers.append(
            gat_layer(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                self.activation,
            )
        )
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(
                gat_layer(
                    num_hidden * muls[l - 1],
                    num_hidden,
                    heads[l],
                    feat_drop,
                    attn_drop,
                    negative_slope,
                    residual,
                    self.activation,
                )
            )
        # output projection
        self.gat_layers.append(
            gat_layer(
                num_hidden * muls[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                negative_slope,
                False,
                None,
            )
        )

    def forward(self, inputs):
        h = inputs
...
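The hard-attention step above keeps, for each node, only the k in-edges whose source node scores highest under the projection vector. A toy run of the same `select_topk` pattern (a small hypothetical graph):

import torch
import dgl
import dgl.function as fn
from dgl.sampling import select_topk

g = dgl.graph(([0, 1, 2, 3], [4, 4, 4, 4]))  # four in-edges into node 4
g.ndata["y"] = torch.tensor([0.1, 0.9, 0.5, 0.7, 0.0])
g.apply_edges(fn.copy_u("y", "y"))  # per-edge weight = source-node score
sub = select_topk(g, 2, "y")  # keep the 2 highest-weight in-edges per node
print(sub.num_edges())  # 2: the edges from nodes 1 and 3 survive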
...@@ -6,17 +6,22 @@ Paper: https://arxiv.org/abs/1907.04652
"""
import argparse
import time

import numpy as np
import torch
import torch.nn.functional as F
from hgao import HardGAT
from utils import EarlyStopping

import dgl
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
    PubmedGraphDataset,
    register_data_args,
)


def accuracy(logits, labels):
    _, indices = torch.max(logits, dim=1)
...@@ -35,16 +40,16 @@ def evaluate(model, features, labels, mask):

def main(args):
    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    if args.num_layers <= 0:
        raise ValueError("num_layers must be a positive int")
    g = data[0]
    if args.gpu < 0:
...@@ -53,24 +58,29 @@ def main(args):
        cuda = True
        g = g.to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    num_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )

    # add self loop
    g = dgl.remove_self_loop(g)
...@@ -78,18 +88,20 @@ def main(args):
    n_edges = g.number_of_edges()
    # create model
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = HardGAT(
        g,
        args.num_layers,
        num_feats,
        args.num_hidden,
        n_classes,
        heads,
        F.elu,
        args.in_drop,
        args.attn_drop,
        args.negative_slope,
        args.residual,
        args.k,
    )
    print(model)
    if args.early_stop:
        stopper = EarlyStopping(patience=100)
...@@ -99,7 +111,8 @@ def main(args):
    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # initialize graph
    dur = []
...@@ -128,52 +141,96 @@ def main(args):
            if stopper.step(val_acc, model):
                break

        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
            " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
                epoch,
                np.mean(dur),
                loss.item(),
                train_acc,
                val_acc,
                n_edges / np.mean(dur) / 1000,
            )
        )

    print()
    if args.early_stop:
        model.load_state_dict(torch.load("es_checkpoint.pt"))
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="GAT")
    register_data_args(parser)
    parser.add_argument(
        "--gpu",
        type=int,
        default=-1,
        help="which GPU to use. Set -1 to use CPU.",
    )
    parser.add_argument(
        "--epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        default=8,
        help="number of hidden attention heads",
    )
    parser.add_argument(
        "--num-out-heads",
        type=int,
        default=1,
        help="number of output attention heads",
    )
    parser.add_argument(
        "--num-layers", type=int, default=1, help="number of hidden layers"
    )
    parser.add_argument(
        "--num-hidden", type=int, default=8, help="number of hidden units"
    )
    parser.add_argument(
        "--residual",
        action="store_true",
        default=False,
        help="use residual connection",
    )
    parser.add_argument(
        "--in-drop", type=float, default=0.6, help="input feature dropout"
    )
    parser.add_argument(
        "--attn-drop", type=float, default=0.6, help="attention dropout"
    )
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="weight decay"
    )
    parser.add_argument(
        "--negative-slope",
        type=float,
        default=0.2,
        help="the negative slope of leaky relu",
    )
    parser.add_argument(
        "--early-stop",
        action="store_true",
        default=False,
        help="indicates whether to use early stopping or not",
    )
    parser.add_argument(
        "--fastmode",
        action="store_true",
        default=False,
        help="skip re-evaluating the validation set",
    )
    parser.add_argument(
        "--k",
        type=int,
        default=8,
        help="top k neighbors for attention calculation",
    )
    args = parser.parse_args()
    print(args)
...
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
import torch
import torch.nn as nn


class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience
...@@ -23,7 +24,9 @@ class EarlyStopping:
            self.save_checkpoint(model)
        elif score < self.best_score:
            self.counter += 1
            print(
                f"EarlyStopping counter: {self.counter} out of {self.patience}"
            )
            if self.counter >= self.patience:
                self.early_stop = True
        else:
...@@ -33,5 +36,5 @@ class EarlyStopping:
        return self.early_stop

    def save_checkpoint(self, model):
        """Saves model when validation loss decreases."""
        torch.save(model.state_dict(), "es_checkpoint.pt")
\ No newline at end of file
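A minimal usage sketch of the helper above. The class body is only partially shown here, so the step entry point and the nn.Linear stand-in are assumptions for illustration:

import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the real model being trained
stopper = EarlyStopping(patience=10)
for epoch in range(200):
    val_score = 1.0 - 1.0 / (epoch + 1)  # placeholder validation score
    if stopper.step(val_score, model):  # entry-point name assumed from context
        break  # patience exhausted; best weights were saved to es_checkpoint.pt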
...@@ -7,69 +7,102 @@ for detailed description.
Here we implement a graph-edge version of sparsemax where we perform sparsemax for all edges
with the same node as end-node in graphs.
"""
import dgl
import torch
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function
def _neighbor_sort(
    scores: Tensor,
    end_n_ids: Tensor,
    in_degrees: Tensor,
    cum_in_degrees: Tensor,
):
    """Sort edge scores for each node"""
    num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())

    # Compute the index for dense score matrix with size (N x D_{max}).
    # Note that the end_n_ids here is the end_node tensor in the dgl graph,
    # which is not grouped by its node id (i.e. in this form: 0,0,1,1,1,...,N,N).
    # Thus here we first sort the end_node tensor to make it easier to compute
    # indices in the dense edge score matrix. Since we will need the original order
    # for the following gspmm and gsddmm operations, we also keep the reverse mapping
    # (the reverse_perm) here.
    end_n_ids, perm = torch.sort(end_n_ids)
    scores = scores[perm]
    _, reverse_perm = torch.sort(perm)

    index = torch.arange(
        end_n_ids.size(0), dtype=torch.long, device=scores.device
    )
    index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
    index = index.long()

    dense_scores = scores.new_full(
        (num_nodes * max_in_degree,), torch.finfo(scores.dtype).min
    )
    dense_scores[index] = scores
    dense_scores = dense_scores.view(num_nodes, max_in_degree)

    sorted_dense_scores, dense_reverse_perm = dense_scores.sort(
        dim=-1, descending=True
    )
    _, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
    dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
    dense_reverse_perm = dense_reverse_perm.view(-1)

    cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
    sorted_dense_scores = sorted_dense_scores.view(-1)
    arange_vec = torch.arange(
        1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device
    )
    arange_vec = torch.repeat_interleave(
        arange_vec.view(1, -1), num_nodes, dim=0
    ).view(-1)

    valid_mask = sorted_dense_scores != torch.finfo(scores.dtype).min
    sorted_scores = sorted_dense_scores[valid_mask]
    cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
    arange_vec = arange_vec[valid_mask]
    dense_reverse_perm = dense_reverse_perm[valid_mask].long()

    return (
        sorted_scores,
        cumsum_sorted_scores,
        arange_vec,
        reverse_perm,
        dense_reverse_perm,
    )
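The index arithmetic above is the subtle part, so here is a self-contained toy reproduction of the same scatter step: per-edge scores land in an (N x D_max) matrix, one row per destination node, with unused slots padded by the dtype minimum so they sort last. All tensor values below are made up for illustration:

import torch

scores = torch.tensor([0.3, 0.9, 0.1, 0.5])  # one score per edge
end_n_ids = torch.tensor([0, 1, 1, 1])  # destination node of each edge (already sorted)
cum_in_degrees = torch.tensor([0, 1])  # exclusive prefix sum of in-degrees [1, 3]
num_nodes, max_in_degree = 2, 3

# Position of each edge inside its node's row, offset by the row start.
index = torch.arange(4) - cum_in_degrees[end_n_ids] + end_n_ids * max_in_degree
dense = scores.new_full((num_nodes * max_in_degree,), torch.finfo(scores.dtype).min)
dense[index] = scores
print(dense.view(num_nodes, max_in_degree))
# row 0: [0.3, pad, pad]; row 1: [0.9, 0.1, 0.5] -- ready for a row-wise sort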
def _threshold_and_support_graph(
    gidx: HeteroGraphIndex, scores: Tensor, end_n_ids: Tensor
):
    """Find the threshold for each node and its edges"""
    in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[
        0
    ]
    cum_in_degrees = torch.cat(
        [in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0
    )

    # perform sort on edges for each node
    (
        sorted_scores,
        cumsum_scores,
        rhos,
        reverse_perm,
        dense_reverse_perm,
    ) = _neighbor_sort(scores, end_n_ids, in_degrees, cum_in_degrees)
    cumsum_scores = cumsum_scores - 1.0
    support = rhos * sorted_scores > cumsum_scores
    support = support[dense_reverse_perm]  # from sorted order to unsorted order
    support = support[reverse_perm]  # from src-dst order to eid order

    support_size = _gspmm(gidx, "copy_rhs", "sum", None, support.float())[0]
    support_size = support_size.long()
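The support test `rhos * sorted_scores > cumsum_scores` is the sparsemax condition from Martins & Astudillo (2016), applied per destination node. A dense 1-D sketch of the same threshold computation may make it easier to follow; the function name dense_sparsemax_threshold is hypothetical and not part of this file:

import torch

def dense_sparsemax_threshold(z: torch.Tensor):
    """Sparsemax threshold tau(z) and support size k for a 1-D score vector."""
    z_sorted, _ = torch.sort(z, descending=True)
    cumsum = z_sorted.cumsum(dim=0) - 1.0
    rhos = torch.arange(1, z.numel() + 1, dtype=z.dtype)
    support = rhos * z_sorted > cumsum  # same test as in the graph version
    k = support.sum()  # number of nonzero outputs
    tau = cumsum[k - 1] / k
    return tau, k

z = torch.tensor([2.0, 1.0, 0.1])
tau, k = dense_sparsemax_threshold(z)
print(torch.clamp(z - tau, min=0.0))  # [1., 0., 0.] -- sparse, sums to 1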
...@@ -88,15 +121,22 @@ class EdgeSparsemaxFunction(Function):
    r"""
    Description
    -----------
    PyTorch autograd function for edge sparsemax.
    We define this autograd function here since
    sparsemax involves sort and select, which are
    not differentiable.
    """

    @staticmethod
    def forward(
        ctx,
        gidx: HeteroGraphIndex,
        scores: Tensor,
        eids: Tensor,
        end_n_ids: Tensor,
        norm_by: str,
    ):
        if not is_all(eids):
            gidx = gidx.edge_subgraph([eids], True).graph
        if norm_by == "src":
...@@ -125,24 +165,26 @@ class EdgeSparsemaxFunction(Function):
        grad_in[out == 0] = 0

        # dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
        v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[
            0
        ] / supp_size.to(out.dtype)
        grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
        grad_in = torch.where(out != 0, grad_in_modify, grad_in)
        del gidx
        torch.cuda.empty_cache()
        return None, grad_in, None, None, None
def edge_sparsemax(graph: dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
    r"""
    Description
    -----------
    Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes

    .. math::
      a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of sparsemax. :math:`\tau` is a function
    that can be found in `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`_
...@@ -176,19 +218,20 @@ def edge_sparsemax(graph: dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
    if not is_all(eids):
        eids = astype(eids, graph.idtype)
        end_n_ids = end_n_ids[eids]
    return EdgeSparsemaxFunction.apply(
        graph._graph, logits, eids, end_n_ids, norm_by
    )
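A minimal usage sketch of the functional form on a toy graph; the graph structure and logit values are made up for illustration, and shapes follow the (E, 1) one-score-per-edge convention:

import dgl
import torch

# Three edges: 0->1, 1->2, 2->2; node 2 has two incoming edges.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 2])))
logits = torch.tensor([[1.5], [2.0], [0.4]])  # one score per edge
attn = edge_sparsemax(g, logits, norm_by="dst")
print(attn)
# Each node's incoming weights sum to 1; for node 2 the weaker edge
# (logit 0.4) is pushed exactly to zero, unlike edge_softmax.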
class EdgeSparsemax(torch.nn.Module):
    r"""
    Description
    -----------
    Compute edge sparsemax. For a node :math:`i`, edge sparsemax is an operation that computes

    .. math::
      a_{ij} = \text{ReLU}(z_{ij} - \tau(z_{i,:}))

    where :math:`z_{ij}` is a signal of edge :math:`j\rightarrow i`, also
    called logits in the context of sparsemax. :math:`\tau` is a function
    that can be found in `From Softmax to Sparsemax <https://arxiv.org/pdf/1602.02068.pdf>`_
...@@ -213,8 +256,9 @@ class EdgeSparsemax(torch.nn.Module):
    Tensor
        Sparsemax value.
    """

    def __init__(self):
        super(EdgeSparsemax, self).__init__()

    def forward(self, graph, logits, eids=ALL, norm_by="dst"):
        return edge_sparsemax(graph, logits, eids, norm_by)
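The module form is a thin stateless wrapper around the function; a short sketch, reusing the toy g and logits from the snippet above:

sparse_attn = EdgeSparsemax()
attn = sparse_attn(g, logits)  # identical to edge_sparsemax(g, logits)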