Unverified Commit f19f05ce authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4651)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 977b1ba4
import os
import cv2 as cv
import matplotlib
import matplotlib.animation as manimation
import matplotlib.pyplot as plt
import numpy as np
matplotlib.use("agg")
# make_video can be used to visualize the test data
def make_video(xy, filename):
os.system("rm -rf pics/*")
FFMpegWriter = manimation.writers["ffmpeg"]
metadata = dict(
title="Movie Test", artist="Matplotlib", comment="Movie support!"
)
writer = FFMpegWriter(fps=15, metadata=metadata)
fig = plt.figure()
plt.xlim(-200, 200)
plt.ylim(-200, 200)
fig_num = len(xy)
color = ["ro", "bo", "go", "ko", "yo", "mo", "co"]
with writer.saving(fig, filename, len(xy)):
for i in range(len(xy)):
for j in range(len(xy[0])):
......
import torch
from modules import MSA, BiLSTM, GraphTrans
from utlis import *
from torch import nn
import dgl
......@@ -10,49 +11,74 @@ class GraphWriter(nn.Module):
super(GraphWriter, self).__init__()
self.args = args
if args.title:
self.title_emb = nn.Embedding(
len(args.title_vocab), args.nhid, padding_idx=0
)
self.title_enc = BiLSTM(args, enc_type="title")
self.title_attn = MSA(args)
self.ent_emb = nn.Embedding(
len(args.ent_text_vocab), args.nhid, padding_idx=0
)
self.tar_emb = nn.Embedding(
len(args.text_vocab), args.nhid, padding_idx=0
)
if args.title:
nn.init.xavier_normal_(self.title_emb.weight)
nn.init.xavier_normal_(self.ent_emb.weight)
self.rel_emb = nn.Embedding(
len(args.rel_vocab), args.nhid, padding_idx=0
)
nn.init.xavier_normal_(self.rel_emb.weight)
self.decode_lstm = nn.LSTMCell(args.dec_ninp, args.nhid)
self.ent_enc = BiLSTM(args, enc_type="entity")
self.graph_enc = GraphTrans(args)
self.ent_attn = MSA(args)
self.copy_attn = MSA(args, mode="copy")
self.copy_fc = nn.Linear(args.dec_ninp, 1)
self.pred_v_fc = nn.Linear(args.dec_ninp, len(args.text_vocab))
def enc_forward(
self, batch, ent_mask, ent_text_mask, ent_len, rel_mask, title_mask
):
title_enc = None
if self.args.title:
title_enc = self.title_enc(
self.title_emb(batch["title"]), title_mask
)
ent_enc = self.ent_enc(
self.ent_emb(batch["ent_text"]),
ent_text_mask,
ent_len=batch["ent_len"],
)
rel_emb = self.rel_emb(batch["rel"])
g_ent, g_root = self.graph_enc(
ent_enc, ent_mask, ent_len, rel_emb, rel_mask, batch["graph"]
)
return g_ent, g_root, title_enc, ent_enc
def forward(self, batch, beam_size=-1):
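# beam_size < 1: training with teacher forcing (returns per-step log-probabilities);
# beam_size == 1: greedy decoding; beam_size > 1: beam-search decoding.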
ent_mask = len2mask(batch["ent_len"], self.args.device)
ent_text_mask = batch["ent_text"] == 0
rel_mask = batch["rel"] == 0 # 0 means the <PAD>
title_mask = batch["title"] == 0
g_ent, g_root, title_enc, ent_enc = self.enc_forward(
batch,
ent_mask,
ent_text_mask,
batch["ent_len"],
rel_mask,
title_mask,
)
_h, _c = g_root, g_root.clone().detach()
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(_h, title_enc, mask=title_mask)
ctx = torch.cat([ctx, attn], 1)
if beam_size < 1:
# training
outs = []
tar_inp = self.tar_emb(batch["text"].transpose(0, 1))
for t, xt in enumerate(tar_inp):
_xt = torch.cat([ctx, xt], 1)
_h, _c = self.decode_lstm(_xt, (_h, _c))
......@@ -65,32 +91,54 @@ class GraphWriter(nn.Module):
copy_gate = torch.sigmoid(self.copy_fc(outs))
EPSI = 1e-6
# copy
pred_v = torch.log(copy_gate + EPSI) + torch.log_softmax(
self.pred_v_fc(outs), -1
)
pred_c = torch.log((1.0 - copy_gate) + EPSI) + torch.log_softmax(
self.copy_attn(outs, ent_enc, mask=ent_mask), -1
)
pred = torch.cat([pred_v, pred_c], -1)
return pred
else:
if beam_size == 1:
# greedy
device = g_ent.device
B = g_ent.shape[0]
ent_type = batch["ent_type"].view(B, -1)
seq = (
torch.ones(
B,
)
.long()
.to(device)
* self.args.text_vocab("<BOS>")
).unsqueeze(1)
for t in range(self.args.beam_max_len):
_inp = replace_ent(
seq[:, -1], ent_type, len(self.args.text_vocab)
)
xt = self.tar_emb(_inp)
_xt = torch.cat([ctx, xt], 1)
_h, _c = self.decode_lstm(_xt, (_h, _c))
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(
_h, title_enc, mask=title_mask
)
ctx = torch.cat([ctx, attn], 1)
_y = torch.cat([_h, ctx], 1)
copy_gate = torch.sigmoid(self.copy_fc(_y))
pred_v = torch.log(copy_gate) + torch.log_softmax(
self.pred_v_fc(_y), -1
)
pred_c = torch.log((1.0 - copy_gate)) + torch.log_softmax(
self.copy_attn(
_y.unsqueeze(1), ent_enc, mask=ent_mask
).squeeze(1),
-1,
)
pred = torch.cat([pred_v, pred_c], -1).view(B, -1)
for ban_item in ["<BOS>", "<PAD>", "<UNK>"]:
pred[:, self.args.text_vocab(ban_item)] = -1e8
_, word = pred.max(-1)
seq = torch.cat([seq, word.unsqueeze(1)], 1)
......@@ -102,47 +150,92 @@ class GraphWriter(nn.Module):
BSZ = B * beam_size
_h = _h.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
_c = _c.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
ent_mask = (
ent_mask.view(B, 1, -1)
.repeat(1, beam_size, 1)
.view(BSZ, -1)
)
if self.args.title:
title_mask = (
title_mask.view(B, 1, -1)
.repeat(1, beam_size, 1)
.view(BSZ, -1)
)
title_enc = (
title_enc.view(B, 1, title_enc.size(1), -1)
.repeat(1, beam_size, 1, 1)
.view(BSZ, title_enc.size(1), -1)
)
ctx = ctx.view(B, 1, -1).repeat(1, beam_size, 1).view(BSZ, -1)
ent_type = (
batch["ent_type"]
.view(B, 1, -1)
.repeat(1, beam_size, 1)
.view(BSZ, -1)
)
g_ent = (
g_ent.view(B, 1, g_ent.size(1), -1)
.repeat(1, beam_size, 1, 1)
.view(BSZ, g_ent.size(1), -1)
)
ent_enc = (
ent_enc.view(B, 1, ent_enc.size(1), -1)
.repeat(1, beam_size, 1, 1)
.view(BSZ, ent_enc.size(1), -1)
)
beam_best = torch.zeros(B).to(device) - 1e9
beam_best_seq = [None] * B
beam_seq = (
torch.ones(B, beam_size).long().to(device)
* self.args.text_vocab("<BOS>")
).unsqueeze(-1)
beam_score = torch.zeros(B, beam_size).to(device)
done_flag = torch.zeros(B, beam_size)
for t in range(self.args.beam_max_len):
_inp = replace_ent(
beam_seq[:, :, -1].view(-1),
ent_type,
len(self.args.text_vocab),
)
xt = self.tar_emb(_inp)
_xt = torch.cat([ctx, xt], 1)
_h, _c = self.decode_lstm(_xt, (_h, _c))
ctx = _h + self.ent_attn(_h, g_ent, mask=ent_mask)
if self.args.title:
attn = _h + self.title_attn(
_h, title_enc, mask=title_mask
)
ctx = torch.cat([ctx, attn], 1)
_y = torch.cat([_h, ctx], 1)
copy_gate = torch.sigmoid(self.copy_fc(_y))
pred_v = torch.log(copy_gate) + torch.log_softmax(
self.pred_v_fc(_y), -1
)
pred_c = torch.log((1.0 - copy_gate)) + torch.log_softmax(
self.copy_attn(
_y.unsqueeze(1), ent_enc, mask=ent_mask
).squeeze(1),
-1,
)
pred = torch.cat([pred_v, pred_c], -1).view(
B, beam_size, -1
)
for ban_item in ["<BOS>", "<PAD>", "<UNK>"]:
pred[:, :, self.args.text_vocab(ban_item)] = -1e8
if t == self.args.beam_max_len - 1: # force ending
tt = pred[:, :, self.args.text_vocab("<EOS>")]
pred = pred * 0 - 1e8
pred[:, :, self.args.text_vocab("<EOS>")] = tt
cum_score = beam_score.view(B, beam_size, 1) + pred
score, word = cum_score.topk(
dim=-1, k=beam_size
) # B, beam_size, beam_size
score, word = score.view(B, -1), word.view(B, -1)
eos_idx = self.args.text_vocab("<EOS>")
if beam_seq.size(2) == 1:
new_idx = torch.arange(beam_size).to(word)
new_idx = new_idx[None, :].repeat(B, 1)
else:
_, new_idx = score.topk(dim=-1, k=beam_size)
new_src, new_score, new_word, new_done = [], [], [], []
......@@ -151,7 +244,7 @@ class GraphWriter(nn.Module):
for j in range(beam_size):
tmp_score = score[i][new_idx[i][j]]
tmp_word = word[i][new_idx[i][j]]
src_idx = new_idx[i][j] // beam_size
new_src.append(src_idx)
if tmp_word == eos_idx:
new_score.append(-1e8)
......@@ -159,24 +252,45 @@ class GraphWriter(nn.Module):
new_score.append(tmp_score)
new_word.append(tmp_word)
if (
tmp_word == eos_idx
and done_flag[i][src_idx] == 0
and tmp_score / LP > beam_best[i]
):
beam_best[i] = tmp_score / LP
beam_best_seq[i] = beam_seq[i][src_idx]
if tmp_word == eos_idx:
new_done.append(1)
else:
new_done.append(done_flag[i][src_idx])
new_score = (
torch.Tensor(new_score)
.view(B, beam_size)
.to(beam_score)
)
new_word = (
torch.Tensor(new_word).view(B, beam_size).to(beam_seq)
)
new_src = (
torch.LongTensor(new_src).view(B, beam_size).to(device)
)
new_done = (
torch.Tensor(new_done).view(B, beam_size).to(done_flag)
)
beam_score = new_score
done_flag = new_done
beam_seq = beam_seq.view(B, beam_size, -1)[
torch.arange(B)[:, None].to(device), new_src
]
beam_seq = torch.cat([beam_seq, new_word.unsqueeze(2)], 2)
_h = _h.view(B, beam_size, -1)[
torch.arange(B)[:, None].to(device), new_src
].view(BSZ, -1)
_c = _c.view(B, beam_size, -1)[
torch.arange(B)[:, None].to(device), new_src
].view(BSZ, -1)
ctx = ctx.view(B, beam_size, -1)[
torch.arange(B)[:, None].to(device), new_src
].view(BSZ, -1)
return beam_best_seq
import torch
import math
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from utlis import *
import dgl.function as fn
from dgl.nn.functional import edge_softmax
class MSA(nn.Module):
......@@ -14,44 +16,56 @@ class MSA(nn.Module):
# the second is the normal attention with two sequence inputs
# the third is the attention but with one token and a sequence. (gather, attentive pooling)
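# mode="copy" uses a single head over the entity encodings and returns the raw attention
# scores (used as copy logits); mode="normal" is standard multi-head attention.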
def __init__(self, args, mode="normal"):
super(MSA, self).__init__()
if mode == "copy":
nhead, head_dim = 1, args.nhid
qninp, kninp = args.dec_ninp, args.nhid
if mode == "normal":
nhead, head_dim = args.nhead, args.head_dim
qninp, kninp = args.nhid, args.nhid
self.attn_drop = nn.Dropout(0.1)
self.WQ = nn.Linear(
qninp, nhead * head_dim, bias=True if mode == "copy" else False
)
if mode != "copy":
self.WK = nn.Linear(kninp, nhead * head_dim, bias=False)
self.WV = nn.Linear(kninp, nhead * head_dim, bias=False)
self.args, self.nhead, self.head_dim, self.mode = (
args,
nhead,
head_dim,
mode,
)
def forward(self, inp1, inp2, mask=None):
B, L2, H = inp2.shape
NH, HD = self.nhead, self.head_dim
if self.mode == "copy":
q, k, v = self.WQ(inp1), inp2, inp2
else:
q, k, v = self.WQ(inp1), self.WK(inp2), self.WV(inp2)
L1 = 1 if inp1.ndim == 2 else inp1.shape[1]
if self.mode != "copy":
q = q / math.sqrt(H)
q = q.view(B, L1, NH, HD).permute(0, 2, 1, 3)
k = k.view(B, L2, NH, HD).permute(0, 2, 3, 1)
v = v.view(B, L2, NH, HD).permute(0, 2, 1, 3)
pre_attn = torch.matmul(q, k)
if mask is not None:
pre_attn = pre_attn.masked_fill(mask[:, None, None, :], -1e8)
if self.mode == "copy":
return pre_attn.squeeze(1)
else:
alpha = self.attn_drop(torch.softmax(pre_attn, -1))
attn = (
torch.matmul(alpha, v)
.permute(0, 2, 1, 3)
.contiguous()
.view(B, L1, NH * HD)
)
ret = attn
if inp1.ndim == 2:
return ret.squeeze(1)
else:
return ret
......@@ -59,52 +73,63 @@ class MSA(nn.Module):
class BiLSTM(nn.Module):
# for entity encoding or the title encoding
def __init__(self, args, enc_type="title"):
super(BiLSTM, self).__init__()
self.enc_type = enc_type
self.drop = nn.Dropout(args.emb_drop)
self.bilstm = nn.LSTM(
args.nhid,
args.nhid // 2,
bidirectional=True,
num_layers=args.enc_lstm_layers,
batch_first=True,
)
def forward(self, inp, mask, ent_len=None):
inp = self.drop(inp)
lens = (mask == 0).sum(-1).long().tolist()
pad_seq = pack_padded_sequence(
inp, lens, batch_first=True, enforce_sorted=False
)
y, (_h, _c) = self.bilstm(pad_seq)
if self.enc_type == "title":
y = pad_packed_sequence(y, batch_first=True)[0]
return y
if self.enc_type == "entity":
_h = _h.transpose(0, 1).contiguous()
_h = _h[:, -2:].view(
_h.size(0), -1
) # two directions of the top-layer
ret = pad(_h.split(ent_len), out_type="tensor")
return ret
class GAT(nn.Module):
# a graph attention network with dot-product attention
def __init__(
self,
in_feats,
out_feats,
num_heads,
ffn_drop=0.0,
attn_drop=0.0,
trans=True,
):
super(GAT, self).__init__()
self._num_heads = num_heads
self._in_feats = in_feats
self._out_feats = out_feats
self.q_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
self.k_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
self.v_proj = nn.Linear(in_feats, num_heads * out_feats, bias=False)
self.attn_drop = nn.Dropout(0.1)
self.ln1 = nn.LayerNorm(in_feats)
self.ln2 = nn.LayerNorm(in_feats)
if trans:
self.FFN = nn.Sequential(
nn.Linear(in_feats, 4 * in_feats),
nn.PReLU(4 * in_feats),
nn.Linear(4 * in_feats, in_feats),
nn.Dropout(0.1),
)
# a strange FFN, see the author's code
......@@ -117,40 +142,64 @@ class GAT(nn.Module):
q = q.view(-1, self._num_heads, self._out_feats)
k = k.view(-1, self._num_heads, self._out_feats)
v = v.view(-1, self._num_heads, self._out_feats)
graph.ndata.update(
{"ft": v, "el": k, "er": q}
) # k,q instead of q,k, the edge_softmax is applied on incoming edges
# compute edge attention
graph.apply_edges(fn.u_dot_v("el", "er", "e"))
e = graph.edata.pop("e") / math.sqrt(self._out_feats * self._num_heads)
graph.edata["a"] = edge_softmax(graph, e)
# message passing
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft2"))
rst = graph.ndata["ft2"]
# residual
rst = rst.view(feat.shape) + feat
if self._trans:
rst = self.ln1(rst)
rst = self.ln1(rst + self.FFN(rst))
# use the same layer norm, see the author's code
return rst
class GraphTrans(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
if args.graph_enc == "gat":
# we only support gtrans, don't use this one
self.gat = nn.ModuleList(
[
GAT(
args.nhid,
args.nhid // 4,
4,
attn_drop=args.attn_drop,
trans=False,
)
for _ in range(args.prop)
]
) # untested
else:
self.gat = nn.ModuleList(
[
GAT(
args.nhid,
args.nhid // 4,
4,
attn_drop=args.attn_drop,
ffn_drop=args.drop,
trans=True,
)
for _ in range(args.prop)
]
)
self.prop = args.prop
def forward(self, ent, ent_mask, ent_len, rel, rel_mask, graphs):
device = ent.device
graphs = graphs.to(device)
ent_mask = ent_mask == 0 # reverse mask
rel_mask = rel_mask == 0
init_h = []
for i in range(graphs.batch_size):
init_h.append(ent[i][ent_mask[i]])
......@@ -159,7 +208,19 @@ class GraphTrans(nn.Module):
feats = init_h
for i in range(self.prop):
feats = self.gat[i](graphs, feats)
g_root = feats.index_select(
0,
graphs.filter_nodes(
lambda x: x.data["type"] == NODE_TYPE["root"]
).to(device),
)
g_ent = pad(
feats.index_select(
0,
graphs.filter_nodes(
lambda x: x.data["type"] == NODE_TYPE["entity"]
).to(device),
).split(ent_len),
out_type="tensor",
)
return g_ent, g_root
import argparse
import torch
def fill_config(args):
# dirty work
......@@ -10,7 +11,9 @@ def fill_config(args):
return args
def vocab_config(
args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
):
# dirty work
args.ent_vocab = ent_vocab
args.rel_vocab = rel_vocab
......@@ -21,35 +24,84 @@ def vocab_config(args, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_v
def get_args():
args = argparse.ArgumentParser(description="Graph Writer in DGL")
args.add_argument("--nhid", default=500, type=int, help="hidden size")
args.add_argument("--nhead", default=4, type=int, help="number of heads")
args.add_argument("--head_dim", default=125, type=int, help="head dim")
args.add_argument(
"--weight_decay", default=0.0, type=float, help="weight decay"
)
args.add_argument(
"--prop", default=6, type=int, help="number of layers of gnn"
)
args.add_argument("--title", action="store_true", help="use title input")
args.add_argument("--test", action="store_true", help="inference mode")
args.add_argument("--batch_size", default=32, type=int, help="batch_size")
args.add_argument(
"--beam_size", default=4, type=int, help="beam size, 1 for greedy"
)
args.add_argument("--epoch", default=20, type=int, help="training epoch")
args.add_argument(
"--beam_max_len",
default=200,
type=int,
help="max length of the generated text",
)
args.add_argument(
"--enc_lstm_layers",
default=2,
type=int,
help="number of layers of lstm",
)
args.add_argument("--lr", default=1e-1, type=float, help="learning rate")
# args.add_argument('--lr_decay', default=1e-8, type=float, help='')
args.add_argument("--clip", default=1, type=float, help="gradient clip")
args.add_argument(
"--emb_drop", default=0.0, type=float, help="embedding dropout"
)
args.add_argument(
"--attn_drop", default=0.1, type=float, help="attention dropout"
)
args.add_argument("--drop", default=0.1, type=float, help="dropout")
args.add_argument("--lp", default=1.0, type=float, help="length penalty")
args.add_argument(
"--graph_enc",
default="gtrans",
type=str,
help="gnn mode, we only support the graph transformer now",
)
args.add_argument(
"--train_file",
default="data/unprocessed.train.json",
type=str,
help="training file",
)
args.add_argument(
"--valid_file",
default="data/unprocessed.val.json",
type=str,
help="validation file",
)
args.add_argument(
"--test_file",
default="data/unprocessed.test.json",
type=str,
help="test file",
)
args.add_argument(
"--save_dataset",
default="data.pickle",
type=str,
help="save path of dataset",
)
args.add_argument(
"--save_model",
default="saved_model.pt",
type=str,
help="save path of model",
)
args.add_argument("--gpu", default=0, type=int, help="gpu mode")
args = args.parse_args()
args = fill_config(args)
return args
import os
import sys
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from tqdm import tqdm
from graphwriter import *
from utlis import *
from opts import *
sys.path.append("./pycocoevalcap")
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.rouge.rouge import Rouge
from pycocoevalcap.meteor.meteor import Meteor
def train_one_epoch(model, dataloader, optimizer, args, epoch):
model.train()
tloss = 0.0
tcnt = 0.0
st_time = time.time()
with tqdm(dataloader, desc="Train Ep " + str(epoch), mininterval=60) as tq:
for batch in tq:
pred = model(batch)
nll_loss = F.nll_loss(
pred.view(-1, pred.shape[-1]),
batch["tgt_text"].view(-1),
ignore_index=0,
)
loss = nll_loss
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
loss = loss.item()
if loss != loss:
raise ValueError("NaN appear")
tloss += loss * len(batch["tgt_text"])
tcnt += len(batch["tgt_text"])
tq.set_postfix({"loss": tloss / tcnt}, refresh=False)
print(
"Train Ep ",
str(epoch),
"AVG Loss ",
tloss / tcnt,
"Steps ",
tcnt,
"Time ",
time.time() - st_time,
"GPU",
torch.cuda.max_memory_cached() / 1024.0 / 1024.0 / 1024.0,
)
torch.save(model, args.save_model + str(epoch % 100))
val_loss = 2**31
def eval_it(model, dataloader, args, epoch):
global val_loss
model.eval()
tloss = 0.0
tcnt = 0.0
st_time = time.time()
with tqdm(dataloader, desc="Eval Ep " + str(epoch), mininterval=60) as tq:
for batch in tq:
with torch.no_grad():
pred = model(batch)
nll_loss = F.nll_loss(
pred.view(-1, pred.shape[-1]),
batch["tgt_text"].view(-1),
ignore_index=0,
)
loss = nll_loss
loss = loss.item()
tloss += loss * len(batch["tgt_text"])
tcnt += len(batch["tgt_text"])
tq.set_postfix({"loss": tloss / tcnt}, refresh=False)
print(
"Eval Ep ",
str(epoch),
"AVG Loss ",
tloss / tcnt,
"Steps ",
tcnt,
"Time ",
time.time() - st_time,
)
if tloss / tcnt < val_loss:
print("Saving best model ", "Ep ", epoch, " loss ", tloss / tcnt)
torch.save(model, args.save_model + "best")
val_loss = tloss / tcnt
def test(model, dataloader, args):
......@@ -71,39 +102,61 @@ def test(model, dataloader, args):
hyp = []
ref = []
model.eval()
gold_file = open("tmp_gold.txt", "w")
pred_file = open("tmp_pred.txt", "w")
with tqdm(dataloader, desc="Test ", mininterval=1) as tq:
for batch in tq:
with torch.no_grad():
seq = model(batch, beam_size=args.beam_size)
r = write_txt(batch, batch["tgt_text"], gold_file, args)
h = write_txt(batch, seq, pred_file, args)
hyp.extend(h)
ref.extend(r)
hyp = dict(zip(range(len(hyp)), hyp))
ref = dict(zip(range(len(ref)), ref))
print(hyp[0], ref[0])
print("BLEU INP", len(hyp), len(ref))
print("BLEU", scorer.compute_score(ref, hyp)[0])
print("METEOR", m_scorer.compute_score(ref, hyp)[0])
print("ROUGE_L", r_scorer.compute_score(ref, hyp)[0])
gold_file.close()
pred_file.close()
def main(args):
if os.path.exists(args.save_dataset):
train_dataset, valid_dataset, test_dataset = pickle.load(
open(args.save_dataset, "rb")
)
else:
train_dataset, valid_dataset, test_dataset = get_datasets(
args.fnames, device=args.device, save=args.save_dataset
)
args = vocab_config(
args,
train_dataset.ent_vocab,
train_dataset.rel_vocab,
train_dataset.text_vocab,
train_dataset.ent_text_vocab,
train_dataset.title_vocab,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_sampler=BucketSampler(train_dataset, batch_size=args.batch_size),
collate_fn=train_dataset.batch_fn,
)
valid_dataloader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=train_dataset.batch_fn,
)
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=train_dataset.batch_fn,
)
model = GraphWriter(args)
model.to(args.device)
......@@ -113,13 +166,18 @@ def main(args):
print(model)
test(model, test_dataloader, args)
else:
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay,
momentum=0.9,
)
print(model)
for epoch in range(args.epoch):
train_one_epoch(model, train_dataloader, optimizer, args, epoch)
eval_it(model, valid_dataloader, args, epoch)
if __name__ == "__main__":
args = get_args()
main(args)
import torch
import dgl
import numpy as np
import json
import pickle
import random
NODE_TYPE = {"entity": 0, "root": 1, "relation": 2}
def write_txt(batch, seqs, w_file, args):
......@@ -16,59 +17,103 @@ def write_txt(batch, seqs, w_file, args):
txt = []
for token in seq:
# copy the entity
if token >= len(args.text_vocab):
ent_text = batch["raw_ent_text"][b][
token - len(args.text_vocab)
]
ent_text = filter(lambda x: x != "<PAD>", ent_text)
txt.extend(ent_text)
else:
if int(token) not in [
args.text_vocab(x) for x in ["<PAD>", "<BOS>", "<EOS>"]
]:
txt.append(args.text_vocab(int(token)))
if int(token) == args.text_vocab("<EOS>"):
break
w_file.write(" ".join([str(x) for x in txt]) + "\n")
ret.append([" ".join([str(x) for x in txt])])
return ret
def replace_ent(x, ent, V):
# replace the entity
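# token ids >= V (the size of the text vocabulary) denote copied entities; map each such
# id back to the corresponding entity's type token so it can go through the embedding layer.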
mask = x >= V
if mask.sum() == 0:
return x
nz = mask.nonzero()
fill_ent = ent[nz, x[mask] - V]
x = x.masked_scatter(mask, fill_ent)
return x
def len2mask(lens, device):
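# build a boolean padding mask of shape (batch, max_len); True marks padded positions.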
max_len = max(lens)
mask = (
torch.arange(max_len, device=device)
.unsqueeze(0)
.expand(len(lens), max_len)
)
mask = mask >= torch.LongTensor(lens).to(mask).unsqueeze(1)
return mask
def pad(var_len_list, out_type="list", flatten=False):
if flatten:
lens = [len(x) for x in var_len_list]
var_len_list = sum(var_len_list, [])
max_len = max([len(x) for x in var_len_list])
if out_type == "list":
if flatten:
return [
x + ["<PAD>"] * (max_len - len(x)) for x in var_len_list
], lens
else:
return [x + ["<PAD>"] * (max_len - len(x)) for x in var_len_list]
if out_type == "tensor":
if flatten:
return (
torch.stack(
[
torch.cat(
[
x,
torch.zeros(
[max_len - len(x)] + list(x.shape[1:])
).type_as(x),
],
0,
)
for x in var_len_list
],
0,
),
lens,
)
else:
return torch.stack(
[
torch.cat(
[
x,
torch.zeros(
[max_len - len(x)] + list(x.shape[1:])
).type_as(x),
],
0,
)
for x in var_len_list
],
0,
)
class Vocab(object):
def __init__(
self,
max_vocab=2**31,
min_freq=-1,
sp=["<PAD>", "<BOS>", "<EOS>", "<UNK>"],
):
self.i2s = []
self.s2i = {}
self.wf = {}
......@@ -78,7 +123,7 @@ class Vocab(object):
return len(self.i2s)
def __str__(self):
return "Total " + str(len(self.i2s)) + str(self.i2s[:10])
def update(self, token):
if isinstance(token, list):
......@@ -89,9 +134,13 @@ class Vocab(object):
def build(self):
self.i2s.extend(self.sp)
sort_kv = sorted(self.wf.items(), key=lambda x: x[1], reverse=True)
for k, v in sort_kv:
if (
len(self.i2s) < self.max_vocab
and v >= self.min_freq
and k not in self.sp
):
self.i2s.append(k)
self.s2i.update(list(zip(self.i2s, range(len(self.i2s)))))
......@@ -99,7 +148,7 @@ class Vocab(object):
if isinstance(x, int):
return self.i2s[x]
else:
return self.s2i.get(x, self.s2i["<UNK>"])
def save(self, fname):
pass
......@@ -107,101 +156,159 @@ class Vocab(object):
def load(self, fname):
pass
def at_least(x):
# handling the illegal data
if len(x) == 0:
return ["<UNK>"]
else:
return x
class Example(object):
def __init__(self, title, ent_text, ent_type, rel, text):
# one object corresponds to a data sample
self.raw_title = title.split()
self.raw_ent_text = [at_least(x.split()) for x in ent_text]
assert min([len(x) for x in self.raw_ent_text]) > 0, str(
self.raw_ent_text
)
self.raw_ent_type = ent_type.split() # <method> .. <>
self.raw_rel = []
for r in rel:
rel_list = r.split()
for i in range(len(rel_list)):
if (
i > 0
and i < len(rel_list) - 1
and rel_list[i - 1] == "--"
and rel_list[i] != rel_list[i].lower()
and rel_list[i + 1] == "--"
):
self.raw_rel.append(
[
rel_list[: i - 1],
rel_list[i - 1] + rel_list[i] + rel_list[i + 1],
rel_list[i + 2 :],
]
)
break
self.raw_text = text.split()
self.graph = self.build_graph()
def __str__(self):
return "\n".join(
[str(k) + ":\t" + str(v) for k, v in self.__dict__.items()]
)
def __len__(self):
return len(self.raw_text)
@staticmethod
def from_json(json_data):
return Example(
json_data["title"],
json_data["entities"],
json_data["types"],
json_data["relations"],
json_data["abstract"],
)
def build_graph(self):
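# the graph has one node per entity, a single root node, and two relation nodes
# (forward and inverse) per relation triple; every node gets a self-loop, and relation
# edges are added in reversed direction because edge_softmax normalizes over incoming edges.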
graph = dgl.DGLGraph()
ent_len = len(self.raw_ent_text)
rel_len = len(
self.raw_rel
) # treat the repeated relation as different nodes, refer to the author's code
graph.add_nodes(
ent_len, {"type": torch.ones(ent_len) * NODE_TYPE["entity"]}
)
graph.add_nodes(1, {"type": torch.ones(1) * NODE_TYPE["root"]})
graph.add_nodes(
rel_len * 2,
{"type": torch.ones(rel_len * 2) * NODE_TYPE["relation"]},
)
graph.add_edges(ent_len, torch.arange(ent_len))
graph.add_edges(torch.arange(ent_len), ent_len)
graph.add_edges(
torch.arange(ent_len + 1 + rel_len * 2),
torch.arange(ent_len + 1 + rel_len * 2),
)
adj_edges = []
for i, r in enumerate(self.raw_rel):
assert len(r) == 3, str(r)
st, rt, ed = r
st_ent, ed_ent = self.raw_ent_text.index(
st
), self.raw_ent_text.index(ed)
# according to the edge_softmax operator, we need to reverse the graph
adj_edges.append([ent_len + 1 + 2 * i, st_ent])
adj_edges.append([ed_ent, ent_len + 1 + 2 * i])
adj_edges.append([ent_len + 1 + 2 * i + 1, ed_ent])
adj_edges.append([st_ent, ent_len + 1 + 2 * i + 1])
if len(adj_edges) > 0:
graph.add_edges(*list(map(list, zip(*adj_edges))))
return graph
def get_tensor(
self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
):
if hasattr(self, "_cached_tensor"):
return self._cached_tensor
else:
title_data = ["<BOS>"] + self.raw_title + ["<EOS>"]
title = [title_vocab(x) for x in title_data]
ent_text = [
[ent_text_vocab(y) for y in x] for x in self.raw_ent_text
]
ent_type = [
text_vocab(x) for x in self.raw_ent_type
] # for inference
rel_data = ["--root--"] + sum(
[[x[1], x[1] + "_INV"] for x in self.raw_rel], []
)
rel = [rel_vocab(x) for x in rel_data]
text_data = ["<BOS>"] + self.raw_text + ["<EOS>"]
text = [text_vocab(x) for x in text_data]
tgt_text = []
# the input text and the decoding target differ because of the copy mechanism.
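# e.g. an entity placeholder like "<method_0>" is split: the type part ("<method>") becomes
# the decoder input token, while the target becomes len(text_vocab) + 0, i.e. a copy index.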
for i, str1 in enumerate(text_data):
if str1[0] == "<" and str1[-1] == ">" and "_" in str1:
a, b = str1[1:-1].split("_")
text[i] = text_vocab("<" + a + ">")
tgt_text.append(len(text_vocab) + int(b))
else:
tgt_text.append(text[i])
self._cached_tensor = {
"title": torch.LongTensor(title),
"ent_text": [torch.LongTensor(x) for x in ent_text],
"ent_type": torch.LongTensor(ent_type),
"rel": torch.LongTensor(rel),
"text": torch.LongTensor(text[:-1]),
"tgt_text": torch.LongTensor(tgt_text[1:]),
"graph": self.graph,
"raw_ent_text": self.raw_ent_text,
}
return self._cached_tensor
def update_vocab(
self, ent_vocab, rel_vocab, text_vocab, ent_text_vocab, title_vocab
):
ent_vocab.update(self.raw_ent_type)
ent_text_vocab.update(self.raw_ent_text)
title_vocab.update(self.raw_title)
rel_vocab.update(
["--root--"]
+ [x[1] for x in self.raw_rel]
+ [x[1] + "_INV" for x in self.raw_rel]
)
text_vocab.update(self.raw_ent_type)
text_vocab.update(self.raw_text)
class BucketSampler(torch.utils.data.Sampler):
def __init__(self, data_source, batch_size=32, bucket=3):
self.data_source = data_source
......@@ -217,13 +324,13 @@ class BucketSampler(torch.utils.data.Sampler):
t2 = []
t3 = []
for i, l in enumerate(lens):
if l < 100:
t1.append(perm[i])
elif l > 100 and l < 220:
t2.append(perm[i])
else:
t3.append(perm[i])
datas = [t1, t2, t3]
random.shuffle(datas)
idxs = sum(datas, [])
batch = []
......@@ -231,23 +338,50 @@ class BucketSampler(torch.utils.data.Sampler):
lens = torch.Tensor([len(x) for x in self.data_source])
for idx in idxs:
batch.append(idx)
mlen = max([0] + [lens[x] for x in batch])
if (
(mlen < 100 and len(batch) == 32)
or (mlen > 100 and mlen < 220 and len(batch) >= 24)
or (mlen > 220 and len(batch) >= 8)
or len(batch) == 32
):
yield batch
batch = []
if len(batch) > 0:
yield batch
def __len__(self):
return (len(self.data_source) + self.batch_size - 1) // self.batch_size
class GWdataset(torch.utils.data.Dataset):
def __init__(
self,
exs,
ent_vocab=None,
rel_vocab=None,
text_vocab=None,
ent_text_vocab=None,
title_vocab=None,
device=None,
):
super(GWdataset, self).__init__()
self.exs = exs
(
self.ent_vocab,
self.rel_vocab,
self.text_vocab,
self.ent_text_vocab,
self.title_vocab,
self.device,
) = (
ent_vocab,
rel_vocab,
text_vocab,
ent_text_vocab,
title_vocab,
device,
)
def __iter__(self):
return iter(self.exs)
......@@ -259,41 +393,71 @@ class GWdataset(torch.utils.data.Dataset):
return len(self.exs)
def batch_fn(self, batch_ex):
(
batch_title,
batch_ent_text,
batch_ent_type,
batch_rel,
batch_text,
batch_tgt_text,
batch_graph,
) = ([], [], [], [], [], [], [])
batch_raw_ent_text = []
for ex in batch_ex:
ex_data = ex.get_tensor(
self.ent_vocab,
self.rel_vocab,
self.text_vocab,
self.ent_text_vocab,
self.title_vocab,
)
batch_title.append(ex_data["title"])
batch_ent_text.append(ex_data["ent_text"])
batch_ent_type.append(ex_data["ent_type"])
batch_rel.append(ex_data["rel"])
batch_text.append(ex_data["text"])
batch_tgt_text.append(ex_data["tgt_text"])
batch_graph.append(ex_data["graph"])
batch_raw_ent_text.append(ex_data["raw_ent_text"])
batch_title = pad(batch_title, out_type="tensor")
batch_ent_text, ent_len = pad(
batch_ent_text, out_type="tensor", flatten=True
)
batch_ent_type = pad(batch_ent_type, out_type="tensor")
batch_rel = pad(batch_rel, out_type="tensor")
batch_text = pad(batch_text, out_type="tensor")
batch_tgt_text = pad(batch_tgt_text, out_type="tensor")
batch_graph = dgl.batch(batch_graph)
batch_graph.to(self.device)
return {
"title": batch_title.to(self.device),
"ent_text": batch_ent_text.to(self.device),
"ent_len": ent_len,
"ent_type": batch_ent_type.to(self.device),
"rel": batch_rel.to(self.device),
"text": batch_text.to(self.device),
"tgt_text": batch_tgt_text.to(self.device),
"graph": batch_graph,
"raw_ent_text": batch_raw_ent_text,
}
def get_datasets(
fnames,
min_freq=-1,
sep=";",
joint_vocab=True,
device=None,
save="tmp.pickle",
):
# min_freq : not supported now, since it is very sensitive to the final results; you can set it by passing min_freq to the Vocab class.
# sep : not supported now
# joint_vocab : not supported now
ent_vocab = Vocab(sp=["<PAD>", "<UNK>"])
title_vocab = Vocab(min_freq=5)
rel_vocab = Vocab(sp=["<PAD>", "<UNK>"])
text_vocab = Vocab(min_freq=5)
ent_text_vocab = Vocab(sp=["<PAD>", "<UNK>"])
datasets = []
for fname in fnames:
exs = []
......@@ -302,7 +466,13 @@ def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, sa
# construct one data example
ex = Example.from_json(json_data)
if fname == fnames[0]: # only training set
ex.update_vocab(
ent_vocab,
rel_vocab,
text_vocab,
ent_text_vocab,
title_vocab,
)
exs.append(ex)
datasets.append(exs)
ent_vocab.build()
......@@ -310,14 +480,40 @@ def get_datasets(fnames, min_freq=-1, sep=';', joint_vocab=True, device=None, sa
text_vocab.build()
ent_text_vocab.build()
title_vocab.build()
datasets = [
GWdataset(
exs,
ent_vocab,
rel_vocab,
text_vocab,
ent_text_vocab,
title_vocab,
device,
)
for exs in datasets
]
with open(save, "wb") as f:
pickle.dump(datasets, f)
return datasets
if __name__ == "__main__":
ds = get_datasets(
[
"data/unprocessed.val.json",
"data/unprocessed.val.json",
"data/unprocessed.test.json",
]
)
print(ds[0].exs[0])
print(
ds[0]
.exs[0]
.get_tensor(
ds[0].ent_vocab,
ds[0].rel_vocab,
ds[0].text_vocab,
ds[0].ent_text_vocab,
ds[0].title_vocab,
)
)
import json
import logging
import os
import sys
import torch
import numpy as np
from dgl.data import LegacyTUDataset
def _load_check_mark(path: str):
if os.path.exists(path):
with open(path, "r") as f:
return json.load(f)
else:
return {}
def _save_check_mark(path: str, marks: dict):
with open(path, "w") as f:
json.dump(marks, f)
def node_label_as_feature(dataset: LegacyTUDataset, mode="concat", save=True):
"""
Description
-----------
......@@ -41,21 +44,29 @@ def node_label_as_feature(dataset:LegacyTUDataset, mode="concat", save=True):
Default: :obj:`True`
"""
# check if node label is not available
if (
not os.path.exists(dataset._file_path("node_labels"))
or len(dataset) == 0
):
logging.warning("No Node Label Data")
return dataset
# check if has cached value
check_mark_name = "node_label_as_feature"
check_mark_path = os.path.join(
dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash))
dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash)
)
check_mark = _load_check_mark(check_mark_path)
if (
check_mark_name in check_mark
and check_mark[check_mark_name]
and not dataset._force_reload
):
logging.warning("Using cached value in node_label_as_feature")
return dataset
logging.warning("Adding node labels into node features..., mode={}".format(mode))
logging.warning(
"Adding node labels into node features..., mode={}".format(mode)
)
# check if graph has "feat"
if "feat" not in dataset[0][0].ndata:
......@@ -65,12 +76,14 @@ def node_label_as_feature(dataset:LegacyTUDataset, mode="concat", save=True):
# first read node labels
DS_node_labels = dataset._idx_from_zero(
np.loadtxt(dataset._file_path("node_labels"), dtype=int)
)
one_hot_node_labels = dataset._to_onehot(DS_node_labels)
# read graph idx
DS_indicator = dataset._idx_from_zero(
np.genfromtxt(dataset._file_path("graph_indicator"), dtype=int)
)
node_idx_list = []
for idx in range(np.max(DS_indicator) + 1):
node_idx = np.where(DS_indicator == idx)
......@@ -81,7 +94,8 @@ def node_label_as_feature(dataset:LegacyTUDataset, mode="concat", save=True):
node_labels_tensor = torch.tensor(one_hot_node_labels[idx, :])
if mode.lower() == "concat":
g.ndata["feat"] = torch.cat(
(g.ndata["feat"], node_labels_tensor), dim=1)
(g.ndata["feat"], node_labels_tensor), dim=1
)
elif mode.lower() == "add":
g.ndata["node_label"] = node_labels_tensor
else: # replace
......@@ -94,7 +108,7 @@ def node_label_as_feature(dataset:LegacyTUDataset, mode="concat", save=True):
return dataset
def degree_as_feature(dataset: LegacyTUDataset, save=True):
"""
Description
-----------
......@@ -113,12 +127,15 @@ def degree_as_feature(dataset:LegacyTUDataset, save=True):
check_mark_name = "degree_as_feat"
feat_name = "feat"
check_mark_path = os.path.join(
dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash))
dataset.save_path, "info_{}_{}.json".format(dataset.name, dataset.hash)
)
check_mark = _load_check_mark(check_mark_path)
if (
check_mark_name in check_mark
and check_mark[check_mark_name]
and not dataset._force_reload
):
logging.warning("Using cached value in 'degree_as_feature'")
return dataset
......@@ -135,7 +152,7 @@ def degree_as_feature(dataset:LegacyTUDataset, save=True):
num_nodes = dataset.graph_lists[i].num_nodes()
node_feat = torch.zeros((num_nodes, vec_len))
degrees = dataset.graph_lists[i].in_degrees()
node_feat[torch.arange(num_nodes), degrees - min_degree] = 1.0
dataset.graph_lists[i].ndata[feat_name] = node_feat
if save:
......
from typing import Optional
import dgl
import torch
import torch.nn
from torch import Tensor
from dgl import DGLGraph
from dgl.nn import GraphConv
class GraphConvWithDropout(GraphConv):
"""
A GraphConv followed by a Dropout.
"""
def __init__(
self,
in_feats,
out_feats,
dropout=0.3,
norm="both",
weight=True,
bias=True,
activation=None,
allow_zero_in_degree=False,
):
super(GraphConvWithDropout, self).__init__(
in_feats,
out_feats,
norm,
weight,
bias,
activation,
allow_zero_in_degree,
)
self.dropout = torch.nn.Dropout(p=dropout)
def call(self, graph, feat, weight=None):
......@@ -38,7 +54,8 @@ class Discriminator(torch.nn.Module):
feat_dim : int
The number of channels of node features.
"""
def __init__(self, feat_dim: int):
super(Discriminator, self).__init__()
self.affine = torch.nn.Bilinear(feat_dim, feat_dim, 1)
self.reset_parameters()
......@@ -47,9 +64,14 @@ class Discriminator(torch.nn.Module):
torch.nn.init.xavier_uniform_(self.affine.weight)
torch.nn.init.zeros_(self.affine.bias)
def forward(
self,
h_x: Tensor,
h_pos: Tensor,
h_neg: Tensor,
bias_pos: Optional[Tensor] = None,
bias_neg: Optional[Tensor] = None,
):
"""
Parameters
----------
......@@ -91,8 +113,10 @@ class DenseLayer(torch.nn.Module):
-----------
Dense layer with a linear layer and an activation function
"""
def __init__(
self, in_dim: int, out_dim: int, act: str = "prelu", bias=True
):
super(DenseLayer, self).__init__()
self.lin = torch.nn.Linear(in_dim, out_dim, bias=bias)
self.act_type = act.lower()
......@@ -131,8 +155,14 @@ class IndexSelect(torch.nn.Module):
dist : int, optional
DO NOT USE THIS PARAMETER
"""
def __init__(self, pool_ratio:float, hidden_dim:int,
act:str="prelu", dist:int=1):
def __init__(
self,
pool_ratio: float,
hidden_dim: int,
act: str = "prelu",
dist: int = 1,
):
super(IndexSelect, self).__init__()
self.pool_ratio = pool_ratio
self.dist = dist
......@@ -140,9 +170,14 @@ class IndexSelect(torch.nn.Module):
self.discriminator = Discriminator(hidden_dim)
self.gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
def forward(self, graph:DGLGraph, h_pos:Tensor,
h_neg:Tensor, bias_pos:Optional[Tensor]=None,
bias_neg:Optional[Tensor]=None):
def forward(
self,
graph: DGLGraph,
h_pos: Tensor,
h_neg: Tensor,
bias_pos: Optional[Tensor] = None,
bias_neg: Optional[Tensor] = None,
):
"""
Description
-----------
......@@ -171,9 +206,9 @@ class IndexSelect(torch.nn.Module):
embed = self.gcn(graph, h_pos)
h_center = torch.sigmoid(embed)
logit, logit_pos = self.discriminator(h_center, h_pos,
h_neg, bias_pos,
bias_neg)
logit, logit_pos = self.discriminator(
h_center, h_pos, h_neg, bias_pos, bias_neg
)
scores = torch.sigmoid(logit_pos)
# sort scores
......@@ -203,15 +238,23 @@ class GraphPool(torch.nn.Module):
Whether to use GCN in the down-sampling process.
default: :obj:`False`
"""
def __init__(self, hidden_dim:int, use_gcn=False):
def __init__(self, hidden_dim: int, use_gcn=False):
super(GraphPool, self).__init__()
self.use_gcn = use_gcn
self.down_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim) \
if use_gcn else None
def forward(self, graph:DGLGraph, feat:Tensor,
select_idx:Tensor, non_select_idx:Optional[Tensor]=None,
scores:Optional[Tensor]=None, pool_graph=False):
self.down_sample_gcn = (
GraphConvWithDropout(hidden_dim, hidden_dim) if use_gcn else None
)
def forward(
self,
graph: DGLGraph,
feat: Tensor,
select_idx: Tensor,
non_select_idx: Optional[Tensor] = None,
scores: Optional[Tensor] = None,
pool_graph=False,
):
"""
Description
-----------
......@@ -264,12 +307,12 @@ class GraphUnpool(torch.nn.Module):
hidden_dim : int
The number of channels of node features.
"""
def __init__(self, hidden_dim:int):
def __init__(self, hidden_dim: int):
super(GraphUnpool, self).__init__()
self.up_sample_gcn = GraphConvWithDropout(hidden_dim, hidden_dim)
def forward(self, graph:DGLGraph,
feat:Tensor, select_idx:Tensor):
def forward(self, graph: DGLGraph, feat: Tensor, select_idx: Tensor):
"""
Description
-----------
......@@ -286,8 +329,9 @@ class GraphUnpool(torch.nn.Module):
coarse graph, this is obtained from
previous graph pooling layers.
"""
fine_feat = torch.zeros((graph.num_nodes(), feat.size(-1)),
device=feat.device)
fine_feat = torch.zeros(
(graph.num_nodes(), feat.size(-1)), device=feat.device
)
fine_feat[select_idx] = feat
fine_feat = self.up_sample_gcn(graph, fine_feat)
return fine_feat
......@@ -3,22 +3,28 @@ import os
from datetime import datetime
from time import time
import dgl
import torch
import torch.nn.functional as F
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch import Tensor
from torch.utils.data import random_split
from data_preprocess import degree_as_feature, node_label_as_feature
from networks import GraphClassifier
from torch import Tensor
from torch.utils.data import random_split
from utils import get_stats, parse_args
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def compute_loss(cls_logits:Tensor, labels:Tensor,
logits_s1:Tensor, logits_s2:Tensor,
epoch:int, total_epochs:int, device:torch.device):
def compute_loss(
cls_logits: Tensor,
labels: Tensor,
logits_s1: Tensor,
logits_s2: Tensor,
epoch: int,
total_epochs: int,
device: torch.device,
):
# classification loss
classify_loss = F.nll_loss(cls_logits, labels.to(device))
......@@ -38,11 +44,17 @@ def compute_loss(cls_logits:Tensor, labels:Tensor,
return loss
def train(model:torch.nn.Module, optimizer, trainloader,
device, curr_epoch, total_epochs):
def train(
model: torch.nn.Module,
optimizer,
trainloader,
device,
curr_epoch,
total_epochs,
):
model.train()
total_loss = 0.
total_loss = 0.0
num_batches = len(trainloader)
for batch in trainloader:
......@@ -50,10 +62,10 @@ def train(model:torch.nn.Module, optimizer, trainloader,
batch_graphs, batch_labels = batch
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, l1, l2 = model(batch_graphs,
batch_graphs.ndata["feat"])
loss = compute_loss(out, batch_labels, l1, l2,
curr_epoch, total_epochs, device)
out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
loss = compute_loss(
out, batch_labels, l1, l2, curr_epoch, total_epochs, device
)
loss.backward()
optimizer.step()
......@@ -63,10 +75,10 @@ def train(model:torch.nn.Module, optimizer, trainloader,
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
def test(model: torch.nn.Module, loader, device):
model.eval()
correct = 0.
correct = 0.0
num_graphs = 0
for batch in loader:
......@@ -103,8 +115,12 @@ def main(args):
num_test = len(dataset) - num_training
train_set, test_set = random_split(dataset, [num_training, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=1)
train_loader = GraphDataLoader(
train_set, batch_size=args.batch_size, shuffle=True, num_workers=1
)
test_loader = GraphDataLoader(
test_set, batch_size=args.batch_size, num_workers=1
)
device = torch.device(args.device)
......@@ -117,7 +133,12 @@ def main(args):
model = GraphClassifier(args).to(device)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
amsgrad=True,
weight_decay=args.weight_decay,
)
# Step 4: training epochs =============================================================== #
best_test_acc = 0.0
......@@ -125,8 +146,9 @@ def main(args):
train_times = []
for e in range(args.epochs):
s_time = time()
train_loss = train(model, optimizer, train_loader, device,
e, args.epochs)
train_loss = train(
model, optimizer, train_loader, device, e, args.epochs
)
train_times.append(time() - s_time)
test_acc = test(model, test_loader, device)
if test_acc > best_test_acc:
......@@ -134,9 +156,13 @@ def main(args):
best_epoch = e + 1
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
log_format = (
"Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
)
print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc))
print(
"Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc)
)
return best_test_acc, sum(train_times) / len(train_times)
......@@ -154,11 +180,15 @@ if __name__ == "__main__":
mean, err_bd = get_stats(res, conf_interval=False)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
out_dict = {
"hyper-parameters": vars(args),
"result_date": str(datetime.now()),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
"details": res}
"details": res,
}
with open(os.path.join(args.output_path, "{}.log".format(args.dataset)), "w") as f:
with open(
os.path.join(args.output_path, "{}.log".format(args.dataset)), "w"
) as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
import json
import os
from time import time
from datetime import datetime
from time import time
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import random_split
import torch
from torch import Tensor
import torch.nn.functional as F
from data_preprocess import degree_as_feature, node_label_as_feature
from networks import GraphClassifier
from torch import Tensor
from torch.utils.data import random_split
from utils import get_stats, parse_args
from data_preprocess import degree_as_feature, node_label_as_feature
import dgl
from dgl.data import LegacyTUDataset
from dgl.dataloading import GraphDataLoader
def compute_loss(cls_logits:Tensor, labels:Tensor,
logits_s1:Tensor, logits_s2:Tensor,
epoch:int, total_epochs:int, device:torch.device):
def compute_loss(
cls_logits: Tensor,
labels: Tensor,
logits_s1: Tensor,
logits_s2: Tensor,
epoch: int,
total_epochs: int,
device: torch.device,
):
# classification loss
classify_loss = F.nll_loss(cls_logits, labels.to(device))
......@@ -38,11 +44,17 @@ def compute_loss(cls_logits:Tensor, labels:Tensor,
return loss
def train(model:torch.nn.Module, optimizer, trainloader,
device, curr_epoch, total_epochs):
def train(
model: torch.nn.Module,
optimizer,
trainloader,
device,
curr_epoch,
total_epochs,
):
model.train()
total_loss = 0.
total_loss = 0.0
num_batches = len(trainloader)
for batch in trainloader:
......@@ -50,10 +62,10 @@ def train(model:torch.nn.Module, optimizer, trainloader,
batch_graphs, batch_labels = batch
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, l1, l2 = model(batch_graphs,
batch_graphs.ndata["feat"])
loss = compute_loss(out, batch_labels, l1, l2,
curr_epoch, total_epochs, device)
out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
loss = compute_loss(
out, batch_labels, l1, l2, curr_epoch, total_epochs, device
)
loss.backward()
optimizer.step()
......@@ -63,10 +75,10 @@ def train(model:torch.nn.Module, optimizer, trainloader,
@torch.no_grad()
def test(model:torch.nn.Module, loader, device):
def test(model: torch.nn.Module, loader, device):
model.eval()
correct = 0.
correct = 0.0
num_graphs = 0
for batch in loader:
......@@ -82,12 +94,11 @@ def test(model:torch.nn.Module, loader, device):
@torch.no_grad()
def validate(model:torch.nn.Module, loader, device,
curr_epoch, total_epochs):
def validate(model: torch.nn.Module, loader, device, curr_epoch, total_epochs):
model.eval()
tt_loss = 0.
correct = 0.
tt_loss = 0.0
correct = 0.0
num_graphs = 0
num_batchs = len(loader)
......@@ -97,8 +108,9 @@ def validate(model:torch.nn.Module, loader, device,
batch_graphs = batch_graphs.to(device)
batch_labels = batch_labels.long().to(device)
out, l1, l2 = model(batch_graphs, batch_graphs.ndata["feat"])
tt_loss += compute_loss(out, batch_labels, l1, l2,
curr_epoch, total_epochs, device).item()
tt_loss += compute_loss(
out, batch_labels, l1, l2, curr_epoch, total_epochs, device
).item()
pred = out.argmax(dim=1)
correct += pred.eq(batch_labels).sum().item()
......@@ -126,11 +138,19 @@ def main(args):
num_training = int(len(dataset) * 0.8)
num_val = int(len(dataset) * 0.1)
num_test = len(dataset) - num_training - num_val
train_set, val_set, test_set = random_split(dataset, [num_training, num_val, num_test])
train_loader = GraphDataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=1)
val_loader = GraphDataLoader(val_set, batch_size=args.batch_size, num_workers=1)
test_loader = GraphDataLoader(test_set, batch_size=args.batch_size, num_workers=1)
train_set, val_set, test_set = random_split(
dataset, [num_training, num_val, num_test]
)
train_loader = GraphDataLoader(
train_set, batch_size=args.batch_size, shuffle=True, num_workers=1
)
val_loader = GraphDataLoader(
val_set, batch_size=args.batch_size, num_workers=1
)
test_loader = GraphDataLoader(
test_set, batch_size=args.batch_size, num_workers=1
)
device = torch.device(args.device)
......@@ -143,7 +163,12 @@ def main(args):
model = GraphClassifier(args).to(device)
# Step 3: Create training components ===================================================== #
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, amsgrad=True, weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(),
lr=args.lr,
amsgrad=True,
weight_decay=args.weight_decay,
)
# Step 4: training epochs =============================================================== #
best_test_acc = 0.0
......@@ -154,8 +179,9 @@ def main(args):
best_val_loss = float("inf")
for e in range(args.epochs):
s_time = time()
train_loss = train(model, optimizer, train_loader, device,
e, args.epochs)
train_loss = train(
model, optimizer, train_loader, device, e, args.epochs
)
train_times.append(time() - s_time)
_, val_loss = validate(model, val_loader, device, e, args.epochs)
test_acc = test(model, test_loader, device)
......@@ -172,9 +198,13 @@ def main(args):
break
if (e + 1) % args.print_every == 0:
log_format = "Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
log_format = (
"Epoch {}: loss={:.4f}, test_acc={:.4f}, best_test_acc={:.4f}"
)
print(log_format.format(e + 1, train_loss, test_acc, best_test_acc))
print("Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc))
print(
"Best Epoch {}, final test acc {:.4f}".format(best_epoch, best_test_acc)
)
return best_test_acc, sum(train_times) / len(train_times)
......@@ -192,11 +222,15 @@ if __name__ == "__main__":
mean, err_bd = get_stats(res, conf_interval=False)
print("mean acc: {:.4f}, error bound: {:.4f}".format(mean, err_bd))
out_dict = {"hyper-parameters": vars(args),
out_dict = {
"hyper-parameters": vars(args),
"result_date": str(datetime.now()),
"result": "{:.4f}(+-{:.4f})".format(mean, err_bd),
"train_time": "{:.4f}".format(sum(train_times) / len(train_times)),
"details": res}
"details": res,
}
with open(os.path.join(args.output_path, "{}.log".format(args.dataset)), "w") as f:
with open(
os.path.join(args.output_path, "{}.log".format(args.dataset)), "w"
) as f:
json.dump(out_dict, f, sort_keys=True, indent=4)
......@@ -10,7 +10,9 @@ import torch.cuda
from scipy.stats import t
def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False):
def get_stats(
array, conf_interval=False, name=None, stdout=False, logout=False
):
"""Compute mean and standard deviation from an numerical array
Args:
......@@ -35,7 +37,7 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
if conf_interval:
n = array.size(0)
se = std / (math.sqrt(n) + eps)
t_value = t.ppf(0.975, df=n-1)
t_value = t.ppf(0.975, df=n - 1)
err_bound = t_value * se
else:
err_bound = std
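For reference, when conf_interval is set the half-width that get_stats reports above is the usual Student-t 95% interval (eps is the small constant added in the hunk to avoid division by zero):
$\mathrm{err\_bound} = t_{0.975,\,n-1} \cdot \dfrac{s}{\sqrt{n} + \epsilon}$
otherwise it simply falls back to the sample standard deviation, as the else branch shows.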
......@@ -54,70 +56,137 @@ def get_stats(array, conf_interval=False, name=None, stdout=False, logout=False)
def parse_args():
parser = argparse.ArgumentParser("Graph Cross Network")
parser.add_argument("--pool_ratios", nargs="+", type=float,
help="The pooling ratios used in graph cross layers")
parser.add_argument("--hidden_dim", type=int, default=96,
help="The number of hidden channels in GXN")
parser.add_argument("--cross_weight", type=float, default=1.,
help="Weight parameter used in graph cross layer")
parser.add_argument("--fuse_weight", type=float, default=1.,
help="Weight parameter for feature fusion")
parser.add_argument("--num_cross_layers", type=int, default=2,
help="The number of graph corss layers")
parser.add_argument("--readout_nodes", type=int, default=30,
help="Number of nodes for each graph after final graph pooling")
parser.add_argument("--conv1d_dims", nargs="+", type=int,
help="Number of channels in conv operations in the end of graph cross net")
parser.add_argument("--conv1d_kws", nargs="+", type=int,
help="Kernel sizes of conv1d operations")
parser.add_argument("--dropout", type=float, default=0.,
help="Dropout rate")
parser.add_argument("--embed_dim", type=int, default=1024,
help="Number of channels of graph embedding")
parser.add_argument("--final_dense_hidden_dim", type=int, default=128,
help="The number of hidden channels in final dense layers")
parser.add_argument("--batch_size", type=int, default=64,
help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4,
help="Learning rate")
parser.add_argument("--weight_decay", type=float, default=0.,
help="Weight decay rate")
parser.add_argument("--epochs", type=int, default=1000,
help="Number of training epochs")
parser.add_argument("--patience", type=int, default=20,
help="Patience for early stopping")
parser.add_argument("--num_trials", type=int, default=1,
help="Number of trials")
parser.add_argument("--device", type=int, default=0,
help="Computation device id, -1 for cpu")
parser.add_argument("--dataset", type=str, default="DD",
help="Dataset used for training")
parser.add_argument("--seed", type=int, default=-1,
help="Random seed, -1 for unset")
parser.add_argument("--print_every", type=int, default=10,
help="Print train log every ? epochs, -1 for silence training")
parser.add_argument("--dataset_path", type=str, default="./datasets",
help="Path holding your dataset")
parser.add_argument("--output_path", type=str, default="./output",
help="Path holding your result files")
parser.add_argument(
"--pool_ratios",
nargs="+",
type=float,
help="The pooling ratios used in graph cross layers",
)
parser.add_argument(
"--hidden_dim",
type=int,
default=96,
help="The number of hidden channels in GXN",
)
parser.add_argument(
"--cross_weight",
type=float,
default=1.0,
help="Weight parameter used in graph cross layer",
)
parser.add_argument(
"--fuse_weight",
type=float,
default=1.0,
help="Weight parameter for feature fusion",
)
parser.add_argument(
"--num_cross_layers",
type=int,
default=2,
help="The number of graph corss layers",
)
parser.add_argument(
"--readout_nodes",
type=int,
default=30,
help="Number of nodes for each graph after final graph pooling",
)
parser.add_argument(
"--conv1d_dims",
nargs="+",
type=int,
help="Number of channels in conv operations in the end of graph cross net",
)
parser.add_argument(
"--conv1d_kws",
nargs="+",
type=int,
help="Kernel sizes of conv1d operations",
)
parser.add_argument(
"--dropout", type=float, default=0.0, help="Dropout rate"
)
parser.add_argument(
"--embed_dim",
type=int,
default=1024,
help="Number of channels of graph embedding",
)
parser.add_argument(
"--final_dense_hidden_dim",
type=int,
default=128,
help="The number of hidden channels in final dense layers",
)
parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
parser.add_argument(
"--weight_decay", type=float, default=0.0, help="Weight decay rate"
)
parser.add_argument(
"--epochs", type=int, default=1000, help="Number of training epochs"
)
parser.add_argument(
"--patience", type=int, default=20, help="Patience for early stopping"
)
parser.add_argument(
"--num_trials", type=int, default=1, help="Number of trials"
)
parser.add_argument(
"--device",
type=int,
default=0,
help="Computation device id, -1 for cpu",
)
parser.add_argument(
"--dataset", type=str, default="DD", help="Dataset used for training"
)
parser.add_argument(
"--seed", type=int, default=-1, help="Random seed, -1 for unset"
)
parser.add_argument(
"--print_every",
type=int,
default=10,
help="Print train log every ? epochs, -1 for silence training",
)
parser.add_argument(
"--dataset_path",
type=str,
default="./datasets",
help="Path holding your dataset",
)
parser.add_argument(
"--output_path",
type=str,
default="./output",
help="Path holding your result files",
)
args = parser.parse_args()
# default value for list hyper-parameters
if not args.pool_ratios or len(args.pool_ratios) < 2:
args.pool_ratios = [0.8, 0.7]
logging.warning("No valid pool_ratios is given, "
"using default value '{}'".format(args.pool_ratios))
logging.warning(
"No valid pool_ratios is given, "
"using default value '{}'".format(args.pool_ratios)
)
if not args.conv1d_dims or len(args.conv1d_dims) < 2:
args.conv1d_dims = [16, 32]
logging.warning("No valid conv1d_dims is give, "
"using default value {}".format(args.conv1d_dims))
logging.warning(
"No valid conv1d_dims is give, "
"using default value {}".format(args.conv1d_dims)
)
if not args.conv1d_kws or len(args.conv1d_kws) < 1:
args.conv1d_kws = [5]
logging.warning("No valid conv1d_kws is given, "
"using default value '{}'".format(args.conv1d_kws))
logging.warning(
"No valid conv1d_kws is given, "
"using default value '{}'".format(args.conv1d_kws)
)
# device
args.device = "cpu" if args.device < 0 else "cuda:{}".format(args.device)
......@@ -148,7 +217,7 @@ def parse_args():
os.makedirs(p)
# datasets ad-hoc
if args.dataset in ['COLLAB', 'IMDB-BINARY', 'IMDB-MULTI', 'ENZYMES']:
if args.dataset in ["COLLAB", "IMDB-BINARY", "IMDB-MULTI", "ENZYMES"]:
args.degree_as_feature = True
else:
args.degree_as_feature = False
......
import torch
from sklearn.metrics import f1_score
from utils import EarlyStopping, load_data
from utils import load_data, EarlyStopping
def score(logits, labels):
_, indices = torch.max(logits, dim=1)
......@@ -9,11 +9,12 @@ def score(logits, labels):
labels = labels.cpu().numpy()
accuracy = (prediction == labels).sum() / len(prediction)
micro_f1 = f1_score(labels, prediction, average='micro')
macro_f1 = f1_score(labels, prediction, average='macro')
micro_f1 = f1_score(labels, prediction, average="micro")
macro_f1 = f1_score(labels, prediction, average="macro")
return accuracy, micro_f1, macro_f1
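A note on the two F1 variants returned by score above: micro-F1 pools all predictions before computing precision and recall, so for single-label multiclass classification it coincides with accuracy, while macro-F1 averages the per-class scores,
$\text{macro-F1} = \frac{1}{C}\sum_{c=1}^{C} \text{F1}_c,$
which weights rare and frequent classes equally.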
def evaluate(model, g, features, labels, mask, loss_func):
model.eval()
with torch.no_grad():
......@@ -23,48 +24,66 @@ def evaluate(model, g, features, labels, mask, loss_func):
return loss, accuracy, micro_f1, macro_f1
def main(args):
# If args['hetero'] is True, g will be a heterogeneous graph.
# Otherwise, it will be a list of homogeneous graphs.
g, features, labels, num_classes, train_idx, val_idx, test_idx, train_mask, \
val_mask, test_mask = load_data(args['dataset'])
if hasattr(torch, 'BoolTensor'):
(
g,
features,
labels,
num_classes,
train_idx,
val_idx,
test_idx,
train_mask,
val_mask,
test_mask,
) = load_data(args["dataset"])
if hasattr(torch, "BoolTensor"):
train_mask = train_mask.bool()
val_mask = val_mask.bool()
test_mask = test_mask.bool()
features = features.to(args['device'])
labels = labels.to(args['device'])
train_mask = train_mask.to(args['device'])
val_mask = val_mask.to(args['device'])
test_mask = test_mask.to(args['device'])
features = features.to(args["device"])
labels = labels.to(args["device"])
train_mask = train_mask.to(args["device"])
val_mask = val_mask.to(args["device"])
test_mask = test_mask.to(args["device"])
if args['hetero']:
if args["hetero"]:
from model_hetero import HAN
model = HAN(meta_paths=[['pa', 'ap'], ['pf', 'fp']],
model = HAN(
meta_paths=[["pa", "ap"], ["pf", "fp"]],
in_size=features.shape[1],
hidden_size=args['hidden_units'],
hidden_size=args["hidden_units"],
out_size=num_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
g = g.to(args['device'])
num_heads=args["num_heads"],
dropout=args["dropout"],
).to(args["device"])
g = g.to(args["device"])
else:
from model import HAN
model = HAN(num_meta_paths=len(g),
model = HAN(
num_meta_paths=len(g),
in_size=features.shape[1],
hidden_size=args['hidden_units'],
hidden_size=args["hidden_units"],
out_size=num_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
g = [graph.to(args['device']) for graph in g]
num_heads=args["num_heads"],
dropout=args["dropout"],
).to(args["device"])
g = [graph.to(args["device"]) for graph in g]
stopper = EarlyStopping(patience=args['patience'])
stopper = EarlyStopping(patience=args["patience"])
loss_fcn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
weight_decay=args['weight_decay'])
optimizer = torch.optim.Adam(
model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)
for epoch in range(args['num_epochs']):
for epoch in range(args["num_epochs"]):
model.train()
logits = model(g, features)
loss = loss_fcn(logits[train_mask], labels[train_mask])
......@@ -73,34 +92,60 @@ def main(args):
loss.backward()
optimizer.step()
train_acc, train_micro_f1, train_macro_f1 = score(logits[train_mask], labels[train_mask])
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, features, labels, val_mask, loss_fcn)
train_acc, train_micro_f1, train_macro_f1 = score(
logits[train_mask], labels[train_mask]
)
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
model, g, features, labels, val_mask, loss_fcn
)
early_stop = stopper.step(val_loss.data.item(), val_acc, model)
print('Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | '
'Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
epoch + 1, loss.item(), train_micro_f1, train_macro_f1, val_loss.item(), val_micro_f1, val_macro_f1))
print(
"Epoch {:d} | Train Loss {:.4f} | Train Micro f1 {:.4f} | Train Macro f1 {:.4f} | "
"Val Loss {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}".format(
epoch + 1,
loss.item(),
train_micro_f1,
train_macro_f1,
val_loss.item(),
val_micro_f1,
val_macro_f1,
)
)
if early_stop:
break
stopper.load_checkpoint(model)
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, features, labels, test_mask, loss_fcn)
print('Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
test_loss.item(), test_micro_f1, test_macro_f1))
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(
model, g, features, labels, test_mask, loss_fcn
)
print(
"Test loss {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
test_loss.item(), test_micro_f1, test_macro_f1
)
)
if __name__ == '__main__':
if __name__ == "__main__":
import argparse
from utils import setup
parser = argparse.ArgumentParser('HAN')
parser.add_argument('-s', '--seed', type=int, default=1,
help='Random seed')
parser.add_argument('-ld', '--log-dir', type=str, default='results',
help='Dir for saving training results')
parser.add_argument('--hetero', action='store_true',
help='Use metapath coalescing with DGL\'s own dataset')
parser = argparse.ArgumentParser("HAN")
parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
parser.add_argument(
"-ld",
"--log-dir",
type=str,
default="results",
help="Dir for saving training results",
)
parser.add_argument(
"--hetero",
action="store_true",
help="Use metapath coalescing with DGL's own dataset",
)
args = parser.parse_args().__dict__
args = setup(args)
......
......@@ -4,6 +4,7 @@ import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
class SemanticAttention(nn.Module):
def __init__(self, in_size, hidden_size=128):
super(SemanticAttention, self).__init__()
......@@ -11,7 +12,7 @@ class SemanticAttention(nn.Module):
self.project = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, 1, bias=False)
nn.Linear(hidden_size, 1, bias=False),
)
def forward(self, z):
......@@ -21,6 +22,7 @@ class SemanticAttention(nn.Module):
return (beta * z).sum(1) # (N, D * K)
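The project stack defined above (Linear → Tanh → Linear(·, 1)) scores each meta-path embedding; the body of forward is elided in this hunk, so the following is only the usual HAN semantic-attention formulation, with the node-averaging step an assumption here:
$w_m = \frac{1}{N}\sum_{i} \mathbf{q}^{\top}\tanh(W z_{m,i} + b), \qquad \beta = \mathrm{softmax}(w), \qquad \mathrm{out} = \sum_{m} \beta_m z_m$
which matches the (N, D * K) output returned above.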
class HANLayer(nn.Module):
"""
HAN layer.
......@@ -45,15 +47,28 @@ class HANLayer(nn.Module):
tensor
The output feature
"""
def __init__(self, num_meta_paths, in_size, out_size, layer_num_heads, dropout):
def __init__(
self, num_meta_paths, in_size, out_size, layer_num_heads, dropout
):
super(HANLayer, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(num_meta_paths):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.gat_layers.append(
GATConv(
in_size,
out_size,
layer_num_heads,
dropout,
dropout,
activation=F.elu,
)
)
self.semantic_attention = SemanticAttention(
in_size=out_size * layer_num_heads
)
self.num_meta_paths = num_meta_paths
def forward(self, gs, h):
......@@ -61,19 +76,35 @@ class HANLayer(nn.Module):
for i, g in enumerate(gs):
semantic_embeddings.append(self.gat_layers[i](g, h).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
semantic_embeddings = torch.stack(
semantic_embeddings, dim=1
) # (N, M, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
class HAN(nn.Module):
def __init__(self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
def __init__(
self, num_meta_paths, in_size, hidden_size, out_size, num_heads, dropout
):
super(HAN, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(HANLayer(num_meta_paths, in_size, hidden_size, num_heads[0], dropout))
self.layers.append(
HANLayer(
num_meta_paths, in_size, hidden_size, num_heads[0], dropout
)
)
for l in range(1, len(num_heads)):
self.layers.append(HANLayer(num_meta_paths, hidden_size * num_heads[l-1],
hidden_size, num_heads[l], dropout))
self.layers.append(
HANLayer(
num_meta_paths,
hidden_size * num_heads[l - 1],
hidden_size,
num_heads[l],
dropout,
)
)
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)
def forward(self, g, h):
......
......@@ -14,6 +14,7 @@ import torch.nn.functional as F
import dgl
from dgl.nn.pytorch import GATConv
class SemanticAttention(nn.Module):
def __init__(self, in_size, hidden_size=128):
super(SemanticAttention, self).__init__()
......@@ -21,7 +22,7 @@ class SemanticAttention(nn.Module):
self.project = nn.Sequential(
nn.Linear(in_size, hidden_size),
nn.Tanh(),
nn.Linear(hidden_size, 1, bias=False)
nn.Linear(hidden_size, 1, bias=False),
)
def forward(self, z):
......@@ -31,6 +32,7 @@ class SemanticAttention(nn.Module):
return (beta * z).sum(1) # (N, D * K)
class HANLayer(nn.Module):
"""
HAN layer.
......@@ -55,16 +57,27 @@ class HANLayer(nn.Module):
tensor
The output feature
"""
def __init__(self, meta_paths, in_size, out_size, layer_num_heads, dropout):
super(HANLayer, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(len(meta_paths)):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu,
allow_zero_in_degree=True))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.gat_layers.append(
GATConv(
in_size,
out_size,
layer_num_heads,
dropout,
dropout,
activation=F.elu,
allow_zero_in_degree=True,
)
)
self.semantic_attention = SemanticAttention(
in_size=out_size * layer_num_heads
)
self.meta_paths = list(tuple(meta_path) for meta_path in meta_paths)
self._cached_graph = None
......@@ -77,25 +90,40 @@ class HANLayer(nn.Module):
self._cached_graph = g
self._cached_coalesced_graph.clear()
for meta_path in self.meta_paths:
self._cached_coalesced_graph[meta_path] = dgl.metapath_reachable_graph(
g, meta_path)
self._cached_coalesced_graph[
meta_path
] = dgl.metapath_reachable_graph(g, meta_path)
for i, meta_path in enumerate(self.meta_paths):
new_g = self._cached_coalesced_graph[meta_path]
semantic_embeddings.append(self.gat_layers[i](new_g, h).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
semantic_embeddings = torch.stack(
semantic_embeddings, dim=1
) # (N, M, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
class HAN(nn.Module):
def __init__(self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout):
def __init__(
self, meta_paths, in_size, hidden_size, out_size, num_heads, dropout
):
super(HAN, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout))
self.layers.append(
HANLayer(meta_paths, in_size, hidden_size, num_heads[0], dropout)
)
for l in range(1, len(num_heads)):
self.layers.append(HANLayer(meta_paths, hidden_size * num_heads[l-1],
hidden_size, num_heads[l], dropout))
self.layers.append(
HANLayer(
meta_paths,
hidden_size * num_heads[l - 1],
hidden_size,
num_heads[l],
dropout,
)
)
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)
def forward(self, g, h):
......
......@@ -4,21 +4,21 @@ HAN mini-batch training by RandomWalkSampler.
note: This demo uses RandomWalkSampler to sample neighbors; it is hard to get all neighbors during validation or testing,
so we sample twice as many neighbors during val/test as during training.
"""
import dgl
import numpy
import argparse
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler
from model_hetero import SemanticAttention
from sklearn.metrics import f1_score
from torch.utils.data import DataLoader
from model_hetero import SemanticAttention
from utils import EarlyStopping, set_random_seed
import dgl
from dgl.nn.pytorch import GATConv
from dgl.sampling import RandomWalkNeighborSampler
class HANLayer(torch.nn.Module):
"""
......@@ -45,37 +45,64 @@ class HANLayer(torch.nn.Module):
The output feature
"""
def __init__(self, num_metapath, in_size, out_size, layer_num_heads, dropout):
def __init__(
self, num_metapath, in_size, out_size, layer_num_heads, dropout
):
super(HANLayer, self).__init__()
# One GAT layer for each meta path based adjacency matrix
self.gat_layers = nn.ModuleList()
for i in range(num_metapath):
self.gat_layers.append(GATConv(in_size, out_size, layer_num_heads,
dropout, dropout, activation=F.elu,
allow_zero_in_degree=True))
self.semantic_attention = SemanticAttention(in_size=out_size * layer_num_heads)
self.gat_layers.append(
GATConv(
in_size,
out_size,
layer_num_heads,
dropout,
dropout,
activation=F.elu,
allow_zero_in_degree=True,
)
)
self.semantic_attention = SemanticAttention(
in_size=out_size * layer_num_heads
)
self.num_metapath = num_metapath
def forward(self, block_list, h_list):
semantic_embeddings = []
for i, block in enumerate(block_list):
semantic_embeddings.append(self.gat_layers[i](block, h_list[i]).flatten(1))
semantic_embeddings = torch.stack(semantic_embeddings, dim=1) # (N, M, D * K)
semantic_embeddings.append(
self.gat_layers[i](block, h_list[i]).flatten(1)
)
semantic_embeddings = torch.stack(
semantic_embeddings, dim=1
) # (N, M, D * K)
return self.semantic_attention(semantic_embeddings) # (N, D * K)
class HAN(nn.Module):
def __init__(self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout):
def __init__(
self, num_metapath, in_size, hidden_size, out_size, num_heads, dropout
):
super(HAN, self).__init__()
self.layers = nn.ModuleList()
self.layers.append(HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout))
self.layers.append(
HANLayer(num_metapath, in_size, hidden_size, num_heads[0], dropout)
)
for l in range(1, len(num_heads)):
self.layers.append(HANLayer(num_metapath, hidden_size * num_heads[l - 1],
hidden_size, num_heads[l], dropout))
self.layers.append(
HANLayer(
num_metapath,
hidden_size * num_heads[l - 1],
hidden_size,
num_heads[l],
dropout,
)
)
self.predict = nn.Linear(hidden_size * num_heads[-1], out_size)
def forward(self, g, h):
......@@ -91,12 +118,16 @@ class HANSampler(object):
for metapath in metapath_list:
# note: a random walk may produce the same route (same edge) more than once; duplicates are removed in the sampled graph.
# So the sampled graph may have fewer edges than num_random_walks (num_neighbors).
self.sampler_list.append(RandomWalkNeighborSampler(G=g,
self.sampler_list.append(
RandomWalkNeighborSampler(
G=g,
num_traversals=1,
termination_prob=0,
num_random_walks=num_neighbors,
num_neighbors=num_neighbors,
metapath=metapath))
metapath=metapath,
)
)
def sample_blocks(self, seeds):
block_list = []
......@@ -117,34 +148,49 @@ def score(logits, labels):
labels = labels.cpu().numpy()
accuracy = (prediction == labels).sum() / len(prediction)
micro_f1 = f1_score(labels, prediction, average='micro')
macro_f1 = f1_score(labels, prediction, average='macro')
micro_f1 = f1_score(labels, prediction, average="micro")
macro_f1 = f1_score(labels, prediction, average="macro")
return accuracy, micro_f1, macro_f1
def evaluate(model, g, metapath_list, num_neighbors, features, labels, val_nid, loss_fcn, batch_size):
def evaluate(
model,
g,
metapath_list,
num_neighbors,
features,
labels,
val_nid,
loss_fcn,
batch_size,
):
model.eval()
han_valid_sampler = HANSampler(g, metapath_list, num_neighbors=num_neighbors * 2)
han_valid_sampler = HANSampler(
g, metapath_list, num_neighbors=num_neighbors * 2
)
dataloader = DataLoader(
dataset=val_nid,
batch_size=batch_size,
collate_fn=han_valid_sampler.sample_blocks,
shuffle=False,
drop_last=False,
num_workers=4)
num_workers=4,
)
correct = total = 0
prediction_list = []
labels_list = []
with torch.no_grad():
for step, (seeds, blocks) in enumerate(dataloader):
h_list = load_subtensors(blocks, features)
blocks = [block.to(args['device']) for block in blocks]
hs = [h.to(args['device']) for h in h_list]
blocks = [block.to(args["device"]) for block in blocks]
hs = [h.to(args["device"]) for h in h_list]
logits = model(blocks, hs)
loss = loss_fcn(logits, labels[numpy.asarray(seeds)].to(args['device']))
loss = loss_fcn(
logits, labels[numpy.asarray(seeds)].to(args["device"])
)
# get each predict label
_, indices = torch.max(logits, dim=1)
prediction = indices.long().cpu().numpy()
......@@ -158,8 +204,8 @@ def evaluate(model, g, metapath_list, num_neighbors, features, labels, val_nid,
total_prediction = numpy.concatenate(prediction_list)
total_labels = numpy.concatenate(labels_list)
micro_f1 = f1_score(total_labels, total_prediction, average='micro')
macro_f1 = f1_score(total_labels, total_prediction, average='macro')
micro_f1 = f1_score(total_labels, total_prediction, average="micro")
macro_f1 = f1_score(total_labels, total_prediction, average="macro")
accuracy = correct / total
return loss, accuracy, micro_f1, macro_f1
......@@ -175,95 +221,142 @@ def load_subtensors(blocks, features):
def main(args):
# acm data
if args['dataset'] == 'ACMRaw':
if args["dataset"] == "ACMRaw":
from utils import load_data
g, features, labels, n_classes, train_nid, val_nid, test_nid, train_mask, \
val_mask, test_mask = load_data('ACMRaw')
metapath_list = [['pa', 'ap'], ['pf', 'fp']]
(
g,
features,
labels,
n_classes,
train_nid,
val_nid,
test_nid,
train_mask,
val_mask,
test_mask,
) = load_data("ACMRaw")
metapath_list = [["pa", "ap"], ["pf", "fp"]]
else:
raise NotImplementedError('Unsupported dataset {}'.format(args['dataset']))
raise NotImplementedError(
"Unsupported dataset {}".format(args["dataset"])
)
# Do we need to set different numbers of neighbors for different meta-path-based graphs?
num_neighbors = args['num_neighbors']
num_neighbors = args["num_neighbors"]
han_sampler = HANSampler(g, metapath_list, num_neighbors)
# Create PyTorch DataLoader for constructing blocks
dataloader = DataLoader(
dataset=train_nid,
batch_size=args['batch_size'],
batch_size=args["batch_size"],
collate_fn=han_sampler.sample_blocks,
shuffle=True,
drop_last=False,
num_workers=4)
num_workers=4,
)
model = HAN(num_metapath=len(metapath_list),
model = HAN(
num_metapath=len(metapath_list),
in_size=features.shape[1],
hidden_size=args['hidden_units'],
hidden_size=args["hidden_units"],
out_size=n_classes,
num_heads=args['num_heads'],
dropout=args['dropout']).to(args['device'])
num_heads=args["num_heads"],
dropout=args["dropout"],
).to(args["device"])
total_params = sum(p.numel() for p in model.parameters())
print("total_params: {:d}".format(total_params))
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_trainable_params = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
print("total trainable params: {:d}".format(total_trainable_params))
stopper = EarlyStopping(patience=args['patience'])
stopper = EarlyStopping(patience=args["patience"])
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'],
weight_decay=args['weight_decay'])
optimizer = torch.optim.Adam(
model.parameters(), lr=args["lr"], weight_decay=args["weight_decay"]
)
for epoch in range(args['num_epochs']):
for epoch in range(args["num_epochs"]):
model.train()
for step, (seeds, blocks) in enumerate(dataloader):
h_list = load_subtensors(blocks, features)
blocks = [block.to(args['device']) for block in blocks]
hs = [h.to(args['device']) for h in h_list]
blocks = [block.to(args["device"]) for block in blocks]
hs = [h.to(args["device"]) for h in h_list]
logits = model(blocks, hs)
loss = loss_fn(logits, labels[numpy.asarray(seeds)].to(args['device']))
loss = loss_fn(
logits, labels[numpy.asarray(seeds)].to(args["device"])
)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# print info in each batch
train_acc, train_micro_f1, train_macro_f1 = score(logits, labels[numpy.asarray(seeds)])
train_acc, train_micro_f1, train_macro_f1 = score(
logits, labels[numpy.asarray(seeds)]
)
print(
"Epoch {:d} | loss: {:.4f} | train_acc: {:.4f} | train_micro_f1: {:.4f} | train_macro_f1: {:.4f}".format(
epoch + 1, loss, train_acc, train_micro_f1, train_macro_f1
))
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(model, g, metapath_list, num_neighbors, features,
labels, val_nid, loss_fn, args['batch_size'])
)
)
val_loss, val_acc, val_micro_f1, val_macro_f1 = evaluate(
model,
g,
metapath_list,
num_neighbors,
features,
labels,
val_nid,
loss_fn,
args["batch_size"],
)
early_stop = stopper.step(val_loss.data.item(), val_acc, model)
print('Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}'.format(
epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1))
print(
"Epoch {:d} | Val loss {:.4f} | Val Accuracy {:.4f} | Val Micro f1 {:.4f} | Val Macro f1 {:.4f}".format(
epoch + 1, val_loss.item(), val_acc, val_micro_f1, val_macro_f1
)
)
if early_stop:
break
stopper.load_checkpoint(model)
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(model, g, metapath_list, num_neighbors, features,
labels, test_nid, loss_fn, args['batch_size'])
print('Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}'.format(
test_loss.item(), test_acc, test_micro_f1, test_macro_f1))
if __name__ == '__main__':
parser = argparse.ArgumentParser('mini-batch HAN')
parser.add_argument('-s', '--seed', type=int, default=1,
help='Random seed')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_neighbors', type=int, default=20)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--num_heads', type=list, default=[8])
parser.add_argument('--hidden_units', type=int, default=8)
parser.add_argument('--dropout', type=float, default=0.6)
parser.add_argument('--weight_decay', type=float, default=0.001)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--patience', type=int, default=10)
parser.add_argument('--dataset', type=str, default='ACMRaw')
parser.add_argument('--device', type=str, default='cuda:0')
test_loss, test_acc, test_micro_f1, test_macro_f1 = evaluate(
model,
g,
metapath_list,
num_neighbors,
features,
labels,
test_nid,
loss_fn,
args["batch_size"],
)
print(
"Test loss {:.4f} | Test Accuracy {:.4f} | Test Micro f1 {:.4f} | Test Macro f1 {:.4f}".format(
test_loss.item(), test_acc, test_micro_f1, test_macro_f1
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser("mini-batch HAN")
parser.add_argument("-s", "--seed", type=int, default=1, help="Random seed")
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_neighbors", type=int, default=20)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--num_heads", type=list, default=[8])
parser.add_argument("--hidden_units", type=int, default=8)
parser.add_argument("--dropout", type=float, default=0.6)
parser.add_argument("--weight_decay", type=float, default=0.001)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--patience", type=int, default=10)
parser.add_argument("--dataset", type=str, default="ACMRaw")
parser.add_argument("--device", type=str, default="cuda:0")
args = parser.parse_args().__dict__
# set_random_seed(args['seed'])
......
import datetime
import dgl
import errno
import numpy as np
import os
import pickle
import random
import torch
from dgl.data.utils import download, get_download_dir, _get_dgl_url
from pprint import pprint
from scipy import sparse
import numpy as np
import torch
from scipy import io as sio
from scipy import sparse
import dgl
from dgl.data.utils import _get_dgl_url, download, get_download_dir
def set_random_seed(seed=0):
"""Set random seed.
......@@ -25,6 +27,7 @@ def set_random_seed(seed=0):
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
def mkdir_p(path, log=True):
"""Create a directory for the specified path.
Parameters
......@@ -37,13 +40,14 @@ def mkdir_p(path, log=True):
try:
os.makedirs(path)
if log:
print('Created directory {}'.format(path))
print("Created directory {}".format(path))
except OSError as exc:
if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
print('Directory {} already exists.'.format(path))
print("Directory {} already exists.".format(path))
else:
raise
def get_date_postfix():
"""Get a date based postfix for directory name.
Returns
......@@ -51,11 +55,13 @@ def get_date_postfix():
post_fix : str
"""
dt = datetime.datetime.now()
post_fix = '{}_{:02d}-{:02d}-{:02d}'.format(
dt.date(), dt.hour, dt.minute, dt.second)
post_fix = "{}_{:02d}-{:02d}-{:02d}".format(
dt.date(), dt.hour, dt.minute, dt.second
)
return post_fix
def setup_log_dir(args, sampling=False):
"""Name and create directory for logging.
Parameters
......@@ -71,106 +77,124 @@ def setup_log_dir(args, sampling=False):
"""
date_postfix = get_date_postfix()
log_dir = os.path.join(
args['log_dir'],
'{}_{}'.format(args['dataset'], date_postfix))
args["log_dir"], "{}_{}".format(args["dataset"], date_postfix)
)
if sampling:
log_dir = log_dir + '_sampling'
log_dir = log_dir + "_sampling"
mkdir_p(log_dir)
return log_dir
# The configuration below is from the paper.
default_configure = {
'lr': 0.005, # Learning rate
'num_heads': [8], # Number of attention heads for node-level attention
'hidden_units': 8,
'dropout': 0.6,
'weight_decay': 0.001,
'num_epochs': 200,
'patience': 100
"lr": 0.005, # Learning rate
"num_heads": [8], # Number of attention heads for node-level attention
"hidden_units": 8,
"dropout": 0.6,
"weight_decay": 0.001,
"num_epochs": 200,
"patience": 100,
}
sampling_configure = {
'batch_size': 20
}
sampling_configure = {"batch_size": 20}
def setup(args):
args.update(default_configure)
set_random_seed(args['seed'])
args['dataset'] = 'ACMRaw' if args['hetero'] else 'ACM'
args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['log_dir'] = setup_log_dir(args)
set_random_seed(args["seed"])
args["dataset"] = "ACMRaw" if args["hetero"] else "ACM"
args["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
args["log_dir"] = setup_log_dir(args)
return args
def setup_for_sampling(args):
args.update(default_configure)
args.update(sampling_configure)
set_random_seed()
args['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
args['log_dir'] = setup_log_dir(args, sampling=True)
args["device"] = "cuda:0" if torch.cuda.is_available() else "cpu"
args["log_dir"] = setup_log_dir(args, sampling=True)
return args
def get_binary_mask(total_size, indices):
mask = torch.zeros(total_size)
mask[indices] = 1
return mask.byte()
def load_acm(remove_self_loop):
url = 'dataset/ACM3025.pkl'
data_path = get_download_dir() + '/ACM3025.pkl'
url = "dataset/ACM3025.pkl"
data_path = get_download_dir() + "/ACM3025.pkl"
download(_get_dgl_url(url), path=data_path)
with open(data_path, 'rb') as f:
with open(data_path, "rb") as f:
data = pickle.load(f)
labels, features = torch.from_numpy(data['label'].todense()).long(), \
torch.from_numpy(data['feature'].todense()).float()
labels, features = (
torch.from_numpy(data["label"].todense()).long(),
torch.from_numpy(data["feature"].todense()).float(),
)
num_classes = labels.shape[1]
labels = labels.nonzero()[:, 1]
if remove_self_loop:
num_nodes = data['label'].shape[0]
data['PAP'] = sparse.csr_matrix(data['PAP'] - np.eye(num_nodes))
data['PLP'] = sparse.csr_matrix(data['PLP'] - np.eye(num_nodes))
num_nodes = data["label"].shape[0]
data["PAP"] = sparse.csr_matrix(data["PAP"] - np.eye(num_nodes))
data["PLP"] = sparse.csr_matrix(data["PLP"] - np.eye(num_nodes))
# Adjacency matrices for meta path based neighbors
# (Mufei): I verified both of them are binary adjacency matrices with self loops
author_g = dgl.from_scipy(data['PAP'])
subject_g = dgl.from_scipy(data['PLP'])
author_g = dgl.from_scipy(data["PAP"])
subject_g = dgl.from_scipy(data["PLP"])
gs = [author_g, subject_g]
train_idx = torch.from_numpy(data['train_idx']).long().squeeze(0)
val_idx = torch.from_numpy(data['val_idx']).long().squeeze(0)
test_idx = torch.from_numpy(data['test_idx']).long().squeeze(0)
train_idx = torch.from_numpy(data["train_idx"]).long().squeeze(0)
val_idx = torch.from_numpy(data["val_idx"]).long().squeeze(0)
test_idx = torch.from_numpy(data["test_idx"]).long().squeeze(0)
num_nodes = author_g.number_of_nodes()
train_mask = get_binary_mask(num_nodes, train_idx)
val_mask = get_binary_mask(num_nodes, val_idx)
test_mask = get_binary_mask(num_nodes, test_idx)
print('dataset loaded')
pprint({
'dataset': 'ACM',
'train': train_mask.sum().item() / num_nodes,
'val': val_mask.sum().item() / num_nodes,
'test': test_mask.sum().item() / num_nodes
})
print("dataset loaded")
pprint(
{
"dataset": "ACM",
"train": train_mask.sum().item() / num_nodes,
"val": val_mask.sum().item() / num_nodes,
"test": test_mask.sum().item() / num_nodes,
}
)
return (
gs,
features,
labels,
num_classes,
train_idx,
val_idx,
test_idx,
train_mask,
val_mask,
test_mask,
)
return gs, features, labels, num_classes, train_idx, val_idx, test_idx, \
train_mask, val_mask, test_mask
def load_acm_raw(remove_self_loop):
assert not remove_self_loop
url = 'dataset/ACM.mat'
data_path = get_download_dir() + '/ACM.mat'
url = "dataset/ACM.mat"
data_path = get_download_dir() + "/ACM.mat"
download(_get_dgl_url(url), path=data_path)
data = sio.loadmat(data_path)
p_vs_l = data['PvsL'] # paper-field?
p_vs_a = data['PvsA'] # paper-author
p_vs_t = data['PvsT'] # paper-term, bag of words
p_vs_c = data['PvsC'] # paper-conference, labels come from that
p_vs_l = data["PvsL"] # paper-field?
p_vs_a = data["PvsA"] # paper-author
p_vs_t = data["PvsT"] # paper-term, bag of words
p_vs_c = data["PvsC"] # paper-conference, labels come from that
# We assign
# (1) KDD papers as class 0 (data mining),
......@@ -186,12 +210,14 @@ def load_acm_raw(remove_self_loop):
p_vs_t = p_vs_t[p_selected]
p_vs_c = p_vs_c[p_selected]
hg = dgl.heterograph({
('paper', 'pa', 'author'): p_vs_a.nonzero(),
('author', 'ap', 'paper'): p_vs_a.transpose().nonzero(),
('paper', 'pf', 'field'): p_vs_l.nonzero(),
('field', 'fp', 'paper'): p_vs_l.transpose().nonzero()
})
hg = dgl.heterograph(
{
("paper", "pa", "author"): p_vs_a.nonzero(),
("author", "ap", "paper"): p_vs_a.transpose().nonzero(),
("paper", "pf", "field"): p_vs_l.nonzero(),
("field", "fp", "paper"): p_vs_l.transpose().nonzero(),
}
)
features = torch.FloatTensor(p_vs_t.toarray())
......@@ -205,33 +231,48 @@ def load_acm_raw(remove_self_loop):
float_mask = np.zeros(len(pc_p))
for conf_id in conf_ids:
pc_c_mask = (pc_c == conf_id)
float_mask[pc_c_mask] = np.random.permutation(np.linspace(0, 1, pc_c_mask.sum()))
pc_c_mask = pc_c == conf_id
float_mask[pc_c_mask] = np.random.permutation(
np.linspace(0, 1, pc_c_mask.sum())
)
train_idx = np.where(float_mask <= 0.2)[0]
val_idx = np.where((float_mask > 0.2) & (float_mask <= 0.3))[0]
test_idx = np.where(float_mask > 0.3)[0]
num_nodes = hg.number_of_nodes('paper')
num_nodes = hg.number_of_nodes("paper")
train_mask = get_binary_mask(num_nodes, train_idx)
val_mask = get_binary_mask(num_nodes, val_idx)
test_mask = get_binary_mask(num_nodes, test_idx)
return hg, features, labels, num_classes, train_idx, val_idx, test_idx, \
train_mask, val_mask, test_mask
return (
hg,
features,
labels,
num_classes,
train_idx,
val_idx,
test_idx,
train_mask,
val_mask,
test_mask,
)
def load_data(dataset, remove_self_loop=False):
if dataset == 'ACM':
if dataset == "ACM":
return load_acm(remove_self_loop)
elif dataset == 'ACMRaw':
elif dataset == "ACMRaw":
return load_acm_raw(remove_self_loop)
else:
raise NotImplementedError('Unsupported dataset {}'.format(dataset))
raise NotImplementedError("Unsupported dataset {}".format(dataset))
class EarlyStopping(object):
def __init__(self, patience=10):
dt = datetime.datetime.now()
self.filename = 'early_stop_{}_{:02d}-{:02d}-{:02d}.pth'.format(
dt.date(), dt.hour, dt.minute, dt.second)
self.filename = "early_stop_{}_{:02d}-{:02d}-{:02d}.pth".format(
dt.date(), dt.hour, dt.minute, dt.second
)
self.patience = patience
self.counter = 0
self.best_acc = None
......@@ -245,7 +286,9 @@ class EarlyStopping(object):
self.save_checkpoint(model)
elif (loss > self.best_loss) and (acc < self.best_acc):
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
print(
f"EarlyStopping counter: {self.counter} out of {self.patience}"
)
if self.counter >= self.patience:
self.early_stop = True
else:
......
......@@ -5,28 +5,33 @@ References
Paper: https://arxiv.org/abs/1907.04652
"""
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.nn.pytorch import edge_softmax
from dgl.sampling import select_topk
from functools import partial
from dgl.nn.pytorch.utils import Identity
import torch.nn.functional as F
from dgl.base import DGLError
import dgl
from dgl.sampling import select_topk
class HardGAO(nn.Module):
def __init__(self,
def __init__(
self,
in_feats,
out_feats,
num_heads=8,
feat_drop=0.,
attn_drop=0.,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=True,
activation=F.elu,
k=8,):
k=8,
):
super(HardGAO, self).__init__()
self.num_heads = num_heads
self.in_feats = in_feats
......@@ -35,11 +40,16 @@ class HardGAO(nn.Module):
self.residual = residual
# Initialize Parameters for Additive Attention
self.fc = nn.Linear(
self.in_feats, self.out_feats * self.num_heads, bias=False)
self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, self.num_heads, self.out_feats)))
self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, self.num_heads, self.out_feats)))
self.in_feats, self.out_feats * self.num_heads, bias=False
)
self.attn_l = nn.Parameter(
torch.FloatTensor(size=(1, self.num_heads, self.out_feats))
)
self.attn_r = nn.Parameter(
torch.FloatTensor(size=(1, self.num_heads, self.out_feats))
)
# Initialize Parameters for Hard Projection
self.p = nn.Parameter(torch.FloatTensor(size=(1,in_feats)))
self.p = nn.Parameter(torch.FloatTensor(size=(1, in_feats)))
# Initialize Dropouts
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
......@@ -48,72 +58,81 @@ class HardGAO(nn.Module):
if self.in_feats == self.out_feats:
self.residual_module = Identity()
else:
self.residual_module = nn.Linear(self.in_feats,self.out_feats*num_heads,bias=False)
self.residual_module = nn.Linear(
self.in_feats, self.out_feats * num_heads, bias=False
)
self.reset_parameters()
self.activation = activation
def reset_parameters(self):
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc.weight, gain=gain)
nn.init.xavier_normal_(self.p,gain=gain)
nn.init.xavier_normal_(self.p, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
if self.residual:
nn.init.xavier_normal_(self.residual_module.weight,gain=gain)
nn.init.xavier_normal_(self.residual_module.weight, gain=gain)
def forward(self, graph, feat, get_attention=False):
# Check in degree and generate error
if (graph.in_degrees()==0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
if (graph.in_degrees() == 0).any():
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
# projection process to get importance vector y
graph.ndata['y'] = torch.abs(torch.matmul(self.p,feat.T).view(-1))/torch.norm(self.p,p=2)
graph.ndata["y"] = torch.abs(
torch.matmul(self.p, feat.T).view(-1)
) / torch.norm(self.p, p=2)
# Use edge message passing function to get the weight from src node
graph.apply_edges(fn.copy_u('y','y'))
graph.apply_edges(fn.copy_u("y", "y"))
# Select Top k neighbors
subgraph = select_topk(graph.cpu(),self.k,'y').to(graph.device)
subgraph = select_topk(graph.cpu(), self.k, "y").to(graph.device)
# Sigmoid as information threshold
subgraph.ndata['y'] = torch.sigmoid(subgraph.ndata['y'])
subgraph.ndata["y"] = torch.sigmoid(subgraph.ndata["y"])
# Using vector matrix elementwise mul for acceleration
feat = subgraph.ndata['y'].view(-1,1)*feat
feat = subgraph.ndata["y"].view(-1, 1) * feat
feat = self.feat_drop(feat)
h = self.fc(feat).view(-1, self.num_heads, self.out_feats)
el = (h * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (h * self.attn_r).sum(dim=-1).unsqueeze(-1)
# Assign the value on the subgraph
subgraph.srcdata.update({'ft': h, 'el': el})
subgraph.dstdata.update({'er': er})
subgraph.srcdata.update({"ft": h, "el": el})
subgraph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
subgraph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(subgraph.edata.pop('e'))
subgraph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(subgraph.edata.pop("e"))
# compute softmax
subgraph.edata['a'] = self.attn_drop(edge_softmax(subgraph, e))
subgraph.edata["a"] = self.attn_drop(edge_softmax(subgraph, e))
# message passing
subgraph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = subgraph.dstdata['ft']
subgraph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = subgraph.dstdata["ft"]
# activation
if self.activation:
rst = self.activation(rst)
# Residual
if self.residual:
rst = rst + self.residual_module(feat).view(feat.shape[0],-1,self.out_feats)
rst = rst + self.residual_module(feat).view(
feat.shape[0], -1, self.out_feats
)
if get_attention:
return rst, subgraph.edata['a']
return rst, subgraph.edata["a"]
else:
return rst
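For intuition, a minimal pure-PyTorch sketch of the hard projection used above: each node's importance is y_i = |p . h_i| / ||p||_2, and only the k highest-scoring incoming neighbors of a node are kept (the toy tensors and names below are illustrative, not part of the module):

import torch

torch.manual_seed(0)
num_nodes, in_feats, k = 5, 4, 2
feat = torch.randn(num_nodes, in_feats)   # node features h_i
p = torch.randn(1, in_feats)              # learnable projection vector

# Importance score per node, matching graph.ndata["y"] above: y_i = |p . h_i| / ||p||_2
y = torch.abs(feat @ p.T).view(-1) / torch.norm(p, p=2)

# For a destination node whose incoming neighbors are {0, 1, 3}, keep the k highest-scoring ones
neighbors = torch.tensor([0, 1, 3])
kept = neighbors[y[neighbors].topk(min(k, neighbors.numel())).indices]
print(y)      # importance of every node
print(kept)   # ids of the k neighbors that survive the hard selection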
class HardGAT(nn.Module):
def __init__(self,
def __init__(
self,
g,
num_layers,
in_dim,
......@@ -125,28 +144,56 @@ class HardGAT(nn.Module):
attn_drop,
negative_slope,
residual,
k):
k,
):
super(HardGAT, self).__init__()
self.g = g
self.num_layers = num_layers
self.gat_layers = nn.ModuleList()
self.activation = activation
gat_layer = partial(HardGAO,k=k)
gat_layer = partial(HardGAO, k=k)
muls = heads
# input projection (no residual)
self.gat_layers.append(gat_layer(
in_dim, num_hidden, heads[0],
feat_drop, attn_drop, negative_slope, False, self.activation))
self.gat_layers.append(
gat_layer(
in_dim,
num_hidden,
heads[0],
feat_drop,
attn_drop,
negative_slope,
False,
self.activation,
)
)
# hidden layers
for l in range(1, num_layers):
# due to multi-head, the in_dim = num_hidden * num_heads
self.gat_layers.append(gat_layer(
num_hidden*muls[l-1] , num_hidden, heads[l],
feat_drop, attn_drop, negative_slope, residual, self.activation))
self.gat_layers.append(
gat_layer(
num_hidden * muls[l - 1],
num_hidden,
heads[l],
feat_drop,
attn_drop,
negative_slope,
residual,
self.activation,
)
)
# output projection
self.gat_layers.append(gat_layer(
num_hidden*muls[-2] , num_classes, heads[-1],
feat_drop, attn_drop, negative_slope, False, None))
self.gat_layers.append(
gat_layer(
num_hidden * muls[-2],
num_classes,
heads[-1],
feat_drop,
attn_drop,
negative_slope,
False,
None,
)
)
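Because each layer concatenates its heads, layer l consumes num_hidden * heads[l-1] input features. A quick dimension check, assuming for illustration the defaults used by the training script below (1 hidden layer, 8 hidden heads, 1 output head) and Cora-sized inputs:

# Illustrative dimension bookkeeping for the stack above (assumed defaults: 1 hidden layer,
# 8 hidden heads, 1 output head, 8 hidden units; Cora has 1433 input features, 7 classes).
num_layers, num_hidden, n_classes = 1, 8, 7
heads = [8] * num_layers + [1]
in_dims = [1433] + [num_hidden * h for h in heads[:-1]]
out_dims = [num_hidden] * num_layers + [n_classes]
for l, (i, o, h) in enumerate(zip(in_dims, out_dims, heads)):
    print(f"layer {l}: in={i}, out={o} x {h} heads")
# layer 0: in=1433, out=8 x 8 heads
# layer 1: in=64,   out=7 x 1 heads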
def forward(self, inputs):
h = inputs
......
......@@ -6,17 +6,22 @@ Paper: https://arxiv.org/abs/1907.04652
"""
import argparse
import numpy as np
import time
import numpy as np
import torch
import torch.nn.functional as F
import dgl
from dgl.data import register_data_args
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
from hgao import HardGAT
from utils import EarlyStopping
import dgl
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
PubmedGraphDataset,
register_data_args,
)
def accuracy(logits, labels):
_, indices = torch.max(logits, dim=1)
......@@ -35,16 +40,16 @@ def evaluate(model, features, labels, mask):
def main(args):
# load and preprocess dataset
if args.dataset == 'cora':
if args.dataset == "cora":
data = CoraGraphDataset()
elif args.dataset == 'citeseer':
elif args.dataset == "citeseer":
data = CiteseerGraphDataset()
elif args.dataset == 'pubmed':
elif args.dataset == "pubmed":
data = PubmedGraphDataset()
else:
raise ValueError('Unknown dataset: {}'.format(args.dataset))
raise ValueError("Unknown dataset: {}".format(args.dataset))
if args.num_layers <=0:
if args.num_layers <= 0:
raise ValueError("num layer must be positive int")
g = data[0]
if args.gpu < 0:
......@@ -53,24 +58,29 @@ def main(args):
cuda = True
g = g.to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
features = g.ndata["feat"]
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
num_feats = features.shape[1]
n_classes = data.num_labels
n_edges = g.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
#Test samples %d"""
% (
n_edges,
n_classes,
train_mask.int().sum().item(),
val_mask.int().sum().item(),
test_mask.int().sum().item()))
test_mask.int().sum().item(),
)
)
# add self loop
g = dgl.remove_self_loop(g)
......@@ -78,7 +88,8 @@ def main(args):
n_edges = g.number_of_edges()
# create model
heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
model = HardGAT(g,
model = HardGAT(
g,
args.num_layers,
num_feats,
args.num_hidden,
......@@ -89,7 +100,8 @@ def main(args):
args.attn_drop,
args.negative_slope,
args.residual,
args.k)
args.k,
)
print(model)
if args.early_stop:
stopper = EarlyStopping(patience=100)
......@@ -99,7 +111,8 @@ def main(args):
# use optimizer
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# initialize graph
dur = []
......@@ -128,52 +141,96 @@ def main(args):
if stopper.step(val_acc, model):
break
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
format(epoch, np.mean(dur), loss.item(), train_acc,
val_acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
" ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
train_acc,
val_acc,
n_edges / np.mean(dur) / 1000,
)
)
print()
if args.early_stop:
model.load_state_dict(torch.load('es_checkpoint.pt'))
model.load_state_dict(torch.load("es_checkpoint.pt"))
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='GAT')
parser = argparse.ArgumentParser(description="GAT")
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="which GPU to use. Set -1 to use CPU.")
parser.add_argument("--epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--num-heads", type=int, default=8,
help="number of hidden attention heads")
parser.add_argument("--num-out-heads", type=int, default=1,
help="number of output attention heads")
parser.add_argument("--num-layers", type=int, default=1,
help="number of hidden layers")
parser.add_argument("--num-hidden", type=int, default=8,
help="number of hidden units")
parser.add_argument("--residual", action="store_true", default=False,
help="use residual connection")
parser.add_argument("--in-drop", type=float, default=.6,
help="input feature dropout")
parser.add_argument("--attn-drop", type=float, default=.6,
help="attention dropout")
parser.add_argument("--lr", type=float, default=0.01,
help="learning rate")
parser.add_argument('--weight-decay', type=float, default=5e-4,
help="weight decay")
parser.add_argument('--negative-slope', type=float, default=0.2,
help="the negative slope of leaky relu")
parser.add_argument('--early-stop', action='store_true', default=False,
help="indicates whether to use early stop or not")
parser.add_argument('--fastmode', action="store_true", default=False,
help="skip re-evaluate the validation set")
parser.add_argument('--k',type=int,default=8,
help='top k neighbor for attention calculation')
parser.add_argument(
"--gpu",
type=int,
default=-1,
help="which GPU to use. Set -1 to use CPU.",
)
parser.add_argument(
"--epochs", type=int, default=200, help="number of training epochs"
)
parser.add_argument(
"--num-heads",
type=int,
default=8,
help="number of hidden attention heads",
)
parser.add_argument(
"--num-out-heads",
type=int,
default=1,
help="number of output attention heads",
)
parser.add_argument(
"--num-layers", type=int, default=1, help="number of hidden layers"
)
parser.add_argument(
"--num-hidden", type=int, default=8, help="number of hidden units"
)
parser.add_argument(
"--residual",
action="store_true",
default=False,
help="use residual connection",
)
parser.add_argument(
"--in-drop", type=float, default=0.6, help="input feature dropout"
)
parser.add_argument(
"--attn-drop", type=float, default=0.6, help="attention dropout"
)
parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
parser.add_argument(
"--weight-decay", type=float, default=5e-4, help="weight decay"
)
parser.add_argument(
"--negative-slope",
type=float,
default=0.2,
help="the negative slope of leaky relu",
)
parser.add_argument(
"--early-stop",
action="store_true",
default=False,
help="indicates whether to use early stop or not",
)
parser.add_argument(
"--fastmode",
action="store_true",
default=False,
help="skip re-evaluate the validation set",
)
parser.add_argument(
"--k",
type=int,
default=8,
help="top k neighor for attention calculation",
)
args = parser.parse_args()
print(args)
......
......@@ -9,6 +9,7 @@ import numpy as np
import torch
import torch.nn as nn
class EarlyStopping:
def __init__(self, patience=10):
self.patience = patience
......@@ -23,7 +24,9 @@ class EarlyStopping:
self.save_checkpoint(model)
elif score < self.best_score:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
print(
f"EarlyStopping counter: {self.counter} out of {self.patience}"
)
if self.counter >= self.patience:
self.early_stop = True
else:
......@@ -33,5 +36,5 @@ class EarlyStopping:
return self.early_stop
def save_checkpoint(self, model):
'''Saves model when validation loss decreases.'''
torch.save(model.state_dict(), 'es_checkpoint.pt')
\ No newline at end of file
"""Saves model when validation loss decrease."""
torch.save(model.state_dict(), "es_checkpoint.pt")
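A small self-contained usage sketch of the helper above; the stand-in model and the validation curve are fabricated for illustration, and the EarlyStopping class is assumed to be in scope:

import torch.nn as nn

model = nn.Linear(4, 2)                  # stand-in model so save_checkpoint has something to save
stopper = EarlyStopping(patience=3)
val_accs = [0.60, 0.62, 0.61, 0.60, 0.59, 0.58]   # improves once, then degrades
for epoch, acc in enumerate(val_accs):
    if stopper.step(acc, model):         # True once the score has not improved for `patience` epochs
        print(f"early stop at epoch {epoch}")
        break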
......@@ -7,17 +7,23 @@ for detailed description.
Here we implement a graph-edge version of sparsemax, where sparsemax is applied jointly
to all edges that share the same end (destination) node of a graph.
"""
import dgl
import torch
from torch import Tensor
from torch.autograd import Function
import dgl
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function
def _neighbor_sort(scores:Tensor, end_n_ids:Tensor, in_degrees:Tensor, cum_in_degrees:Tensor):
def _neighbor_sort(
scores: Tensor,
end_n_ids: Tensor,
in_degrees: Tensor,
cum_in_degrees: Tensor,
):
"""Sort edge scores for each node"""
num_nodes, max_in_degree = in_degrees.size(0), int(in_degrees.max().item())
......@@ -32,41 +38,68 @@ def _neighbor_sort(scores:Tensor, end_n_ids:Tensor, in_degrees:Tensor, cum_in_de
scores = scores[perm]
_, reverse_perm = torch.sort(perm)
index = torch.arange(end_n_ids.size(0), dtype=torch.long, device=scores.device)
index = torch.arange(
end_n_ids.size(0), dtype=torch.long, device=scores.device
)
index = (index - cum_in_degrees[end_n_ids]) + (end_n_ids * max_in_degree)
index = index.long()
dense_scores = scores.new_full((num_nodes * max_in_degree, ), torch.finfo(scores.dtype).min)
dense_scores = scores.new_full(
(num_nodes * max_in_degree,), torch.finfo(scores.dtype).min
)
dense_scores[index] = scores
dense_scores = dense_scores.view(num_nodes, max_in_degree)
sorted_dense_scores, dense_reverse_perm = dense_scores.sort(dim=-1, descending=True)
sorted_dense_scores, dense_reverse_perm = dense_scores.sort(
dim=-1, descending=True
)
_, dense_reverse_perm = torch.sort(dense_reverse_perm, dim=-1)
dense_reverse_perm = dense_reverse_perm + cum_in_degrees.view(-1, 1)
dense_reverse_perm = dense_reverse_perm.view(-1)
cumsum_sorted_dense_scores = sorted_dense_scores.cumsum(dim=-1).view(-1)
sorted_dense_scores = sorted_dense_scores.view(-1)
arange_vec = torch.arange(1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device)
arange_vec = torch.repeat_interleave(arange_vec.view(1, -1), num_nodes, dim=0).view(-1)
valid_mask = (sorted_dense_scores != torch.finfo(scores.dtype).min)
arange_vec = torch.arange(
1, max_in_degree + 1, dtype=torch.long, device=end_n_ids.device
)
arange_vec = torch.repeat_interleave(
arange_vec.view(1, -1), num_nodes, dim=0
).view(-1)
valid_mask = sorted_dense_scores != torch.finfo(scores.dtype).min
sorted_scores = sorted_dense_scores[valid_mask]
cumsum_sorted_scores = cumsum_sorted_dense_scores[valid_mask]
arange_vec = arange_vec[valid_mask]
dense_reverse_perm = dense_reverse_perm[valid_mask].long()
return sorted_scores, cumsum_sorted_scores, arange_vec, reverse_perm, dense_reverse_perm
return (
sorted_scores,
cumsum_sorted_scores,
arange_vec,
reverse_perm,
dense_reverse_perm,
)
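The padding trick above can be illustrated on a toy edge list: edge scores are scattered into a (num_nodes, max_in_degree) buffer filled with the dtype minimum, so a single batched sort orders every node's incoming edges at once. A sketch with made-up tensors:

import torch

scores = torch.tensor([0.3, 1.2, -0.5, 0.7, 0.1])   # one score per edge, grouped by destination
end_n_ids = torch.tensor([0, 0, 1, 1, 1])            # destination node of each edge
in_degrees = torch.bincount(end_n_ids)               # tensor([2, 3])
cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(0)[:-1]])
num_nodes, max_in_degree = in_degrees.numel(), int(in_degrees.max())

# Scatter each edge score into its node's row of the padded dense buffer
index = torch.arange(end_n_ids.numel()) - cum_in_degrees[end_n_ids] + end_n_ids * max_in_degree
dense = scores.new_full((num_nodes * max_in_degree,), torch.finfo(scores.dtype).min)
dense[index] = scores
print(dense.view(num_nodes, max_in_degree).sort(dim=-1, descending=True).values)
# row 0: [1.2, 0.3, -3.4e38]; row 1: [0.7, 0.1, -0.5] -- the padding never outranks a real score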
def _threshold_and_support_graph(gidx:HeteroGraphIndex, scores:Tensor, end_n_ids:Tensor):
def _threshold_and_support_graph(
gidx: HeteroGraphIndex, scores: Tensor, end_n_ids: Tensor
):
"""Find the threshold for each node and its edges"""
in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[0]
cum_in_degrees = torch.cat([in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0)
in_degrees = _gspmm(gidx, "copy_rhs", "sum", None, torch.ones_like(scores))[
0
]
cum_in_degrees = torch.cat(
[in_degrees.new_zeros(1), in_degrees.cumsum(dim=0)[:-1]], dim=0
)
# perform sort on edges for each node
sorted_scores, cumsum_scores, rhos, reverse_perm, dense_reverse_perm = _neighbor_sort(scores, end_n_ids,
in_degrees, cum_in_degrees)
cumsum_scores = cumsum_scores - 1.
(
sorted_scores,
cumsum_scores,
rhos,
reverse_perm,
dense_reverse_perm,
) = _neighbor_sort(scores, end_n_ids, in_degrees, cum_in_degrees)
cumsum_scores = cumsum_scores - 1.0
support = rhos * sorted_scores > cumsum_scores
support = support[dense_reverse_perm] # from sorted order to unsorted order
support = support[reverse_perm] # from src-dst order to eid order
......@@ -94,9 +127,16 @@ class EdgeSparsemaxFunction(Function):
sparsemax involves sort and select, which are
not differentiable.
"""
@staticmethod
def forward(ctx, gidx:HeteroGraphIndex, scores:Tensor,
eids:Tensor, end_n_ids:Tensor, norm_by:str):
def forward(
ctx,
gidx: HeteroGraphIndex,
scores: Tensor,
eids: Tensor,
end_n_ids: Tensor,
norm_by: str,
):
if not is_all(eids):
gidx = gidx.edge_subgraph([eids], True).graph
if norm_by == "src":
......@@ -125,7 +165,9 @@ class EdgeSparsemaxFunction(Function):
grad_in[out == 0] = 0
# dL/dv_i = dL/do_i - 1/k \sum_{j=1}^k dL/do_j
v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[0] / supp_size.to(out.dtype)
v_hat = _gspmm(gidx, "copy_rhs", "sum", None, grad_in)[
0
] / supp_size.to(out.dtype)
grad_in_modify = _gsddmm(gidx, "sub", grad_in, v_hat, "e", "v")
grad_in = torch.where(out != 0, grad_in_modify, grad_in)
del gidx
......@@ -134,7 +176,7 @@ class EdgeSparsemaxFunction(Function):
return None, grad_in, None, None, None
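The backward rule in the comment above (dL/dv_i = dL/do_i - (1/k) * sum of dL/do_j over the support, and zero outside the support) can be checked on a dense vector; a small illustrative sketch:

import torch

out = torch.tensor([0.7, 0.3, 0.0, 0.0])         # a sparsemax output: sums to 1, support of size 2
grad_out = torch.tensor([0.5, -1.0, 2.0, 0.3])   # upstream gradient dL/do

supp = out > 0
grad_in = grad_out.clone()
grad_in[~supp] = 0.0                              # entries outside the support get zero gradient
v_hat = grad_in[supp].sum() / supp.sum()          # (1/k) * sum over the support
grad_in[supp] = grad_in[supp] - v_hat
print(grad_in)                                    # tensor([ 0.75, -0.75, 0.00, 0.00])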
def edge_sparsemax(graph:dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
def edge_sparsemax(graph: dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
r"""
Description
-----------
......@@ -176,8 +218,9 @@ def edge_sparsemax(graph:dgl.DGLGraph, logits, eids=ALL, norm_by="dst"):
if not is_all(eids):
eids = astype(eids, graph.idtype)
end_n_ids = end_n_ids[eids]
return EdgeSparsemaxFunction.apply(graph._graph, logits,
eids, end_n_ids, norm_by)
return EdgeSparsemaxFunction.apply(
graph._graph, logits, eids, end_n_ids, norm_by
)
class EdgeSparsemax(torch.nn.Module):
......@@ -213,6 +256,7 @@ class EdgeSparsemax(torch.nn.Module):
Tensor
Sparsemax value.
"""
def __init__(self):
super(EdgeSparsemax, self).__init__()
......