Unverified Commit 704bcaf6 authored by Hongzhi (Steve), Chen, committed by GitHub
parent 6bc82161
import dgl
import torch
import torch.nn
import torch.nn.functional as F
from layer import ConvPoolBlock, SAGPool
import dgl
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from layer import ConvPoolBlock, SAGPool
class SAGNetworkHierarchical(torch.nn.Module):
......
......@@ -20,7 +20,6 @@ def _transform_log_level(str_level):
class LightLogging(object):
def __init__(self, log_path=None, log_name="lightlog", log_level="debug"):
log_level = _transform_log_level(log_level)
if log_path:
......
......@@ -3,6 +3,9 @@ import time
import numpy as np
import torch
import torch.multiprocessing
from dgl import EID, NID
from dgl.dataloading import GraphDataLoader
from logger import LightLogging
from model import DGCNN, GCN
from sampler import SEALData
......@@ -10,9 +13,6 @@ from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
from utils import evaluate_hits, load_ogb_dataset, parse_arguments
from dgl import EID, NID
from dgl.dataloading import GraphDataLoader
torch.multiprocessing.set_sharing_strategy("file_system")
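# NOTE: the default "file_descriptor" sharing strategy can exhaust the
# process's open-file limit when many DataLoader workers exchange tensors;
# "file_system" sidesteps this by sharing through filesystem-backed storage.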
"""
......
import os.path as osp
from copy import deepcopy
import dgl
import torch
from dgl import add_self_loop, DGLGraph, NID
from dgl.dataloading.negative_sampler import Uniform
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from utils import drnl_node_labeling
import dgl
from dgl import NID, DGLGraph, add_self_loop
from dgl.dataloading.negative_sampler import Uniform
class GraphDataSet(Dataset):
"""
......@@ -48,7 +48,6 @@ class PosNegEdgesGenerator(object):
self.shuffle = shuffle
def __call__(self, split_type):
if split_type == "train":
subsample_ratio = self.subsample_ratio
else:
......@@ -177,7 +176,6 @@ class SEALSampler(object):
return subgraph
def _collate(self, batch):
batch_graphs, batch_labels = map(list, zip(*batch))
batch_graphs = dgl.batch(batch_graphs)
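# dgl.batch merges the list of subgraphs into a single graph with several
# disconnected components, so one message-passing pass serves the whole
# batch; node and edge features are concatenated in order.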
......@@ -272,7 +270,6 @@ class SEALData(object):
)
def __call__(self, split_type):
if split_type == "train":
subsample_ratio = self.subsample_ratio
else:
......
import argparse
import dgl
import numpy as np
import pandas as pd
import torch
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path
import dgl
def parse_arguments():
"""
......
......@@ -9,13 +9,13 @@ import argparse
import math
import time
import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.data import (
CiteseerGraphDataset,
CoraGraphDataset,
......
......@@ -9,12 +9,12 @@ import argparse
import math
import time
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.nn.pytorch.conv import SGConv
......
import dgl
import numpy as np
import torch
import dgl
def load_dataset(name):
dataset = name.lower()
......
......@@ -2,14 +2,14 @@ import argparse
import os
import time
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import load_dataset
import dgl
import dgl.function as fn
class FeedForwardNet(nn.Module):
def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
......
......@@ -6,10 +6,10 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tagcn import TAGCN
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from tagcn import TAGCN
def evaluate(model, features, labels, mask):
......
......@@ -2,32 +2,37 @@ from .attention import *
from .layers import *
from .functions import *
from .embedding import *
import torch as th
import dgl.function as fn
import torch as th
import torch.nn.init as INIT
class UEncoder(nn.Module):
def __init__(self, layer):
super(UEncoder, self).__init__()
self.layer = layer
self.norm = LayerNorm(layer.size)
def pre_func(self, fields='qkv'):
def pre_func(self, fields="qkv"):
layer = self.layer
def func(nodes):
x = nodes.data['x']
x = nodes.data["x"]
norm_x = layer.sublayer[0].norm(x)
return layer.self_attn.get(norm_x, fields=fields)
return func
def post_func(self):
layer = self.layer
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[0].dropout(o)
x = layer.sublayer[1](x, layer.feed_forward)
return {'x': x}
return {"x": x}
return func
......@@ -37,31 +42,36 @@ class UDecoder(nn.Module):
self.layer = layer
self.norm = LayerNorm(layer.size)
def pre_func(self, fields='qkv', l=0):
def pre_func(self, fields="qkv", l=0):
layer = self.layer
def func(nodes):
x = nodes.data['x']
if fields == 'kv':
x = nodes.data["x"]
if fields == "kv":
norm_x = x
else:
norm_x = layer.sublayer[l].norm(x)
return layer.self_attn.get(norm_x, fields)
return func
def post_func(self, l=0):
layer = self.layer
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[l].dropout(o)
if l == 1:
x = layer.sublayer[2](x, layer.feed_forward)
return {'x': x}
return {"x": x}
return func
class HaltingUnit(nn.Module):
halting_bias_init = 1.0
def __init__(self, dim_model):
super(HaltingUnit, self).__init__()
self.linear = nn.Linear(dim_model, 1)
......@@ -71,12 +81,25 @@ class HaltingUnit(nn.Module):
def forward(self, x):
return th.sigmoid(self.linear(self.norm(x)))
class UTransformer(nn.Module):
"Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
MAX_DEPTH = 8
thres = 0.99
act_loss_weight = 0.01
def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, d_k):
def __init__(
self,
encoder,
decoder,
src_embed,
tgt_embed,
pos_enc,
time_enc,
generator,
h,
d_k,
):
super(UTransformer, self).__init__()
self.encoder, self.decoder = encoder, decoder
self.src_embed, self.tgt_embed = src_embed, tgt_embed
......@@ -91,34 +114,45 @@ class UTransformer(nn.Module):
self.stat = [0] * (self.MAX_DEPTH + 1)
def step_forward(self, nodes):
x = nodes.data['x']
step = nodes.data['step']
pos = nodes.data['pos']
return {'x': self.pos_enc.dropout(x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))),
'step': step + 1}
x = nodes.data["x"]
step = nodes.data["step"]
pos = nodes.data["pos"]
return {
"x": self.pos_enc.dropout(
x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))
),
"step": step + 1,
}
def halt_and_accum(self, name, end=False):
"field: 'enc' or 'dec'"
halt = self.halt_enc if name == 'enc' else self.halt_dec
halt = self.halt_enc if name == "enc" else self.halt_dec
thres = self.thres
def func(nodes):
p = halt(nodes.data['x'])
sum_p = nodes.data['sum_p'] + p
p = halt(nodes.data["x"])
sum_p = nodes.data["sum_p"] + p
active = (sum_p < thres) & (1 - end)
_continue = active.float()
r = nodes.data['r'] * (1 - _continue) + (1 - sum_p) * _continue
s = nodes.data['s'] + ((1 - _continue) * r + _continue * p) * nodes.data['x']
return {'p': p, 'sum_p': sum_p, 'r': r, 's': s, 'active': active}
r = nodes.data["r"] * (1 - _continue) + (1 - sum_p) * _continue
s = (
nodes.data["s"]
+ ((1 - _continue) * r + _continue * p) * nodes.data["x"]
)
return {"p": p, "sum_p": sum_p, "r": r, "s": s, "active": active}
return func
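# ACT bookkeeping in `func` above: sum_p accumulates halting probabilities
# per node, and a node stays active while sum_p < thres. While active, the
# state is accumulated with weight p (s += p * x) and the remainder is
# updated to r = 1 - sum_p; on the halting step the frozen remainder weights
# the final state (s += r * x), so the step weights sum to 1. The mean of
# the remainders later serves as the ACT ponder loss. Worked example with
# thres = 0.99 and p = (0.4, 0.5, 0.3): the node halts at step 3, and
# s = 0.4 * x1 + 0.5 * x2 + 0.1 * x3 with r = 1 - (0.4 + 0.5) = 0.1.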
def propagate_attention(self, g, eids):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
g.apply_edges(src_dot_dst("k", "q", "score"), eids)
g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
# Send weighted values to target nodes
g.send_and_recv(eids,
[fn.u_mul_e('v', 'score', 'v'), fn.copy_e('score', 'score')],
[fn.sum('v', 'wv'), fn.sum('score', 'z')])
g.send_and_recv(
eids,
[fn.u_mul_e("v", "score", "v"), fn.copy_e("score", "score")],
[fn.sum("v", "wv"), fn.sum("score", "z")],
)
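# The edge softmax is computed in two pieces: src_dot_dst + scaled_exp give
# each edge score = exp(k . q / sqrt(d_k)), and each destination node
# accumulates wv = sum(score * v) and z = sum(score), so post_func recovers
# the attention-weighted output as wv / z without materializing the softmax.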
def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph."
......@@ -136,79 +170,128 @@ class UTransformer(nn.Module):
nids, eids = graph.nids, graph.eids
# embed & pos
g.nodes[nids['enc']].data['x'] = self.src_embed(graph.src[0])
g.nodes[nids['dec']].data['x'] = self.tgt_embed(graph.tgt[0])
g.nodes[nids['enc']].data['pos'] = graph.src[1]
g.nodes[nids['dec']].data['pos'] = graph.tgt[1]
g.nodes[nids["enc"]].data["x"] = self.src_embed(graph.src[0])
g.nodes[nids["dec"]].data["x"] = self.tgt_embed(graph.tgt[0])
g.nodes[nids["enc"]].data["pos"] = graph.src[1]
g.nodes[nids["dec"]].data["pos"] = graph.tgt[1]
# init step
device = next(self.parameters()).device
g.ndata['s'] = th.zeros(N, self.h * self.d_k, dtype=th.float, device=device) # accumulated state
g.ndata['p'] = th.zeros(N, 1, dtype=th.float, device=device) # halting prob
g.ndata['r'] = th.ones(N, 1, dtype=th.float, device=device) # remainder
g.ndata['sum_p'] = th.zeros(N, 1, dtype=th.float, device=device) # sum of pondering values
g.ndata['step'] = th.zeros(N, 1, dtype=th.long, device=device) # step
g.ndata['active'] = th.ones(N, 1, dtype=th.uint8, device=device) # active
g.ndata["s"] = th.zeros(
N, self.h * self.d_k, dtype=th.float, device=device
) # accumulated state
g.ndata["p"] = th.zeros(
N, 1, dtype=th.float, device=device
) # halting prob
g.ndata["r"] = th.ones(N, 1, dtype=th.float, device=device) # remainder
g.ndata["sum_p"] = th.zeros(
N, 1, dtype=th.float, device=device
) # sum of pondering values
g.ndata["step"] = th.zeros(N, 1, dtype=th.long, device=device) # step
g.ndata["active"] = th.ones(
N, 1, dtype=th.uint8, device=device
) # active
for step in range(self.MAX_DEPTH):
pre_func = self.encoder.pre_func('qkv')
pre_func = self.encoder.pre_func("qkv")
post_func = self.encoder.post_func()
nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['enc'])
if len(nodes) == 0: break
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ee'])
nodes = g.filter_nodes(
lambda v: v.data["active"].view(-1), nids["enc"]
)
if len(nodes) == 0:
break
edges = g.filter_edges(
lambda e: e.dst["active"].view(-1), eids["ee"]
)
end = step == self.MAX_DEPTH - 1
self.update_graph(g, edges,
self.update_graph(
g,
edges,
[(self.step_forward, nodes), (pre_func, nodes)],
[(post_func, nodes), (self.halt_and_accum('enc', end), nodes)])
[(post_func, nodes), (self.halt_and_accum("enc", end), nodes)],
)
g.nodes[nids['enc']].data['x'] = self.encoder.norm(g.nodes[nids['enc']].data['s'])
g.nodes[nids["enc"]].data["x"] = self.encoder.norm(
g.nodes[nids["enc"]].data["s"]
)
for step in range(self.MAX_DEPTH):
pre_func = self.decoder.pre_func('qkv')
pre_func = self.decoder.pre_func("qkv")
post_func = self.decoder.post_func()
nodes = g.filter_nodes(lambda v: v.data['active'].view(-1), nids['dec'])
if len(nodes) == 0: break
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['dd'])
self.update_graph(g, edges,
nodes = g.filter_nodes(
lambda v: v.data["active"].view(-1), nids["dec"]
)
if len(nodes) == 0:
break
edges = g.filter_edges(
lambda e: e.dst["active"].view(-1), eids["dd"]
)
self.update_graph(
g,
edges,
[(self.step_forward, nodes), (pre_func, nodes)],
[(post_func, nodes)])
[(post_func, nodes)],
)
pre_q = self.decoder.pre_func('q', 1)
pre_kv = self.decoder.pre_func('kv', 1)
pre_q = self.decoder.pre_func("q", 1)
pre_kv = self.decoder.pre_func("kv", 1)
post_func = self.decoder.post_func(1)
nodes_e = nids['enc']
edges = g.filter_edges(lambda e: e.dst['active'].view(-1), eids['ed'])
nodes_e = nids["enc"]
edges = g.filter_edges(
lambda e: e.dst["active"].view(-1), eids["ed"]
)
end = step == self.MAX_DEPTH - 1
self.update_graph(g, edges,
self.update_graph(
g,
edges,
[(pre_q, nodes), (pre_kv, nodes_e)],
[(post_func, nodes), (self.halt_and_accum('dec', end), nodes)])
[(post_func, nodes), (self.halt_and_accum("dec", end), nodes)],
)
g.nodes[nids['dec']].data['x'] = self.decoder.norm(g.nodes[nids['dec']].data['s'])
act_loss = th.mean(g.ndata['r']) # ACT loss
g.nodes[nids["dec"]].data["x"] = self.decoder.norm(
g.nodes[nids["dec"]].data["s"]
)
act_loss = th.mean(g.ndata["r"]) # ACT loss
self.stat[0] += N
for step in range(1, self.MAX_DEPTH + 1):
self.stat[step] += th.sum(g.ndata['step'] >= step).item()
self.stat[step] += th.sum(g.ndata["step"] >= step).item()
return self.generator(g.ndata['x'][nids['dec']]), act_loss * self.act_loss_weight
return (
self.generator(g.ndata["x"][nids["dec"]]),
act_loss * self.act_loss_weight,
)
def infer(self, *args, **kwargs):
raise NotImplementedError
def make_universal_model(src_vocab, tgt_vocab, dim_model=512, dim_ff=2048, h=8, dropout=0.1):
def make_universal_model(
src_vocab, tgt_vocab, dim_model=512, dim_ff=2048, h=8, dropout=0.1
):
c = copy.deepcopy
attn = MultiHeadAttention(h, dim_model)
ff = PositionwiseFeedForward(dim_model, dim_ff)
pos_enc = PositionalEncoding(dim_model, dropout)
time_enc = PositionalEncoding(dim_model, dropout)
encoder = UEncoder(EncoderLayer((dim_model), c(attn), c(ff), dropout))
decoder = UDecoder(DecoderLayer((dim_model), c(attn), c(attn), c(ff), dropout))
decoder = UDecoder(
DecoderLayer((dim_model), c(attn), c(attn), c(ff), dropout)
)
src_embed = Embeddings(src_vocab, dim_model)
tgt_embed = Embeddings(tgt_vocab, dim_model)
generator = Generator(dim_model, tgt_vocab)
model = UTransformer(
encoder, decoder, src_embed, tgt_embed, pos_enc, time_enc, generator, h, dim_model // h)
encoder,
decoder,
src_embed,
tgt_embed,
pos_enc,
time_enc,
generator,
h,
dim_model // h,
)
# xavier init
for p in model.parameters():
if p.dim() > 1:
......
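# Hypothetical usage sketch (not part of this diff; the vocabulary sizes are
# made up): the factory returns a UTransformer whose forward pass consumes a
# batched DGL sentence graph and also returns the weighted ACT loss.
# model = make_universal_model(src_vocab=1000, tgt_vocab=1000)
# out, act_loss = model(batch_graph)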
......@@ -6,10 +6,12 @@ from .layers import *
from .functions import *
from .embedding import *
import threading
import torch as th
import dgl.function as fn
import torch as th
import torch.nn.init as INIT
class Encoder(nn.Module):
def __init__(self, layer, N):
super(Encoder, self).__init__()
......@@ -17,24 +19,29 @@ class Encoder(nn.Module):
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def pre_func(self, i, fields='qkv'):
def pre_func(self, i, fields="qkv"):
layer = self.layers[i]
def func(nodes):
x = nodes.data['x']
x = nodes.data["x"]
norm_x = layer.sublayer[0].norm(x)
return layer.self_attn.get(norm_x, fields=fields)
return func
def post_func(self, i):
layer = self.layers[i]
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[0].dropout(o)
x = layer.sublayer[1](x, layer.feed_forward)
return {'x': x if i < self.N - 1 else self.norm(x)}
return {"x": x if i < self.N - 1 else self.norm(x)}
return func
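# These are pre-LayerNorm residual blocks: pre_func normalizes the input
# before attention and post_func adds the residual, so the final encoder
# layer (the `i < self.N - 1` check) applies self.norm once more to
# normalize the outgoing stream.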
class Decoder(nn.Module):
def __init__(self, layer, N):
super(Decoder, self).__init__()
......@@ -42,30 +49,37 @@ class Decoder(nn.Module):
self.layers = clones(layer, N)
self.norm = LayerNorm(layer.size)
def pre_func(self, i, fields='qkv', l=0):
def pre_func(self, i, fields="qkv", l=0):
layer = self.layers[i]
def func(nodes):
x = nodes.data['x']
norm_x = layer.sublayer[l].norm(x) if fields.startswith('q') else x
if fields != 'qkv':
x = nodes.data["x"]
norm_x = layer.sublayer[l].norm(x) if fields.startswith("q") else x
if fields != "qkv":
return layer.src_attn.get(norm_x, fields)
else:
return layer.self_attn.get(norm_x, fields)
return func
def post_func(self, i, l=0):
layer = self.layers[i]
def func(nodes):
x, wv, z = nodes.data['x'], nodes.data['wv'], nodes.data['z']
x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
o = layer.self_attn.get_o(wv / z)
x = x + layer.sublayer[l].dropout(o)
if l == 1:
x = layer.sublayer[2](x, layer.feed_forward)
return {'x': x if i < self.N - 1 else self.norm(x)}
return {"x": x if i < self.N - 1 else self.norm(x)}
return func
class Transformer(nn.Module):
def __init__(self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k):
def __init__(
self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k
):
super(Transformer, self).__init__()
self.encoder, self.decoder = encoder, decoder
self.src_embed, self.tgt_embed = src_embed, tgt_embed
......@@ -76,11 +90,11 @@ class Transformer(nn.Module):
def propagate_attention(self, g, eids):
# Compute attention score
g.apply_edges(src_dot_dst('k', 'q', 'score'), eids)
g.apply_edges(scaled_exp('score', np.sqrt(self.d_k)), eids)
g.apply_edges(src_dot_dst("k", "q", "score"), eids)
g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
# Send weighted values to target nodes
g.send_and_recv(eids, fn.u_mul_e('v', 'score', 'v'), fn.sum('v', 'wv'))
g.send_and_recv(eids, fn.copy_e('score', 'score'), fn.sum('score', 'z'))
g.send_and_recv(eids, fn.u_mul_e("v", "score", "v"), fn.sum("v", "wv"))
g.send_and_recv(eids, fn.copy_e("score", "score"), fn.sum("score", "z"))
def update_graph(self, g, eids, pre_pairs, post_pairs):
"Update the node states and edge states of the graph."
......@@ -98,27 +112,44 @@ class Transformer(nn.Module):
nids, eids = graph.nids, graph.eids
# embed
src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(graph.src[1])
tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(graph.tgt[1])
g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
g.nodes[nids['dec']].data['x'] = self.pos_enc.dropout(tgt_embed + tgt_pos)
src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(
graph.src[1]
)
tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(
graph.tgt[1]
)
g.nodes[nids["enc"]].data["x"] = self.pos_enc.dropout(
src_embed + src_pos
)
g.nodes[nids["dec"]].data["x"] = self.pos_enc.dropout(
tgt_embed + tgt_pos
)
for i in range(self.encoder.N):
pre_func = self.encoder.pre_func(i, 'qkv')
pre_func = self.encoder.pre_func(i, "qkv")
post_func = self.encoder.post_func(i)
nodes, edges = nids['enc'], eids['ee']
self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
nodes, edges = nids["enc"], eids["ee"]
self.update_graph(
g, edges, [(pre_func, nodes)], [(post_func, nodes)]
)
for i in range(self.decoder.N):
pre_func = self.decoder.pre_func(i, 'qkv')
pre_func = self.decoder.pre_func(i, "qkv")
post_func = self.decoder.post_func(i)
nodes, edges = nids['dec'], eids['dd']
self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
pre_q = self.decoder.pre_func(i, 'q', 1)
pre_kv = self.decoder.pre_func(i, 'kv', 1)
nodes, edges = nids["dec"], eids["dd"]
self.update_graph(
g, edges, [(pre_func, nodes)], [(post_func, nodes)]
)
pre_q = self.decoder.pre_func(i, "q", 1)
pre_kv = self.decoder.pre_func(i, "kv", 1)
post_func = self.decoder.post_func(i, 1)
nodes_e, edges = nids['enc'], eids['ed']
self.update_graph(g, edges, [(pre_q, nodes), (pre_kv, nodes_e)], [(post_func, nodes)])
nodes_e, edges = nids["enc"], eids["ed"]
self.update_graph(
g,
edges,
[(pre_q, nodes), (pre_kv, nodes_e)],
[(post_func, nodes)],
)
# visualize attention
"""
......@@ -126,9 +157,10 @@ class Transformer(nn.Module):
self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
"""
return self.generator(g.ndata['x'][nids['dec']])
return self.generator(g.ndata["x"][nids["dec"]])
def infer(self, graph, max_len, eos_id, k, alpha=1.0):
'''
"""
This function implements beam search in DGL, which is required in the inference phase.
Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
args:
......@@ -138,7 +170,7 @@ class Transformer(nn.Module):
k: beam size
return:
ret: a list of index arrays corresponding to the input sequences specified by `graph`.
'''
"""
g = graph.g
N, E = graph.n_nodes, graph.n_edges
nids, eids = graph.nids, graph.eids
......@@ -146,21 +178,25 @@ class Transformer(nn.Module):
# embed & pos
src_embed = self.src_embed(graph.src[0])
src_pos = self.pos_enc(graph.src[1])
g.nodes[nids['enc']].data['pos'] = graph.src[1]
g.nodes[nids['enc']].data['x'] = self.pos_enc.dropout(src_embed + src_pos)
g.nodes[nids["enc"]].data["pos"] = graph.src[1]
g.nodes[nids["enc"]].data["x"] = self.pos_enc.dropout(
src_embed + src_pos
)
tgt_pos = self.pos_enc(graph.tgt[1])
g.nodes[nids['dec']].data['pos'] = graph.tgt[1]
g.nodes[nids["dec"]].data["pos"] = graph.tgt[1]
# init mask
device = next(self.parameters()).device
g.ndata['mask'] = th.zeros(N, dtype=th.uint8, device=device)
g.ndata["mask"] = th.zeros(N, dtype=th.uint8, device=device)
# encode
for i in range(self.encoder.N):
pre_func = self.encoder.pre_func(i, 'qkv')
pre_func = self.encoder.pre_func(i, "qkv")
post_func = self.encoder.post_func(i)
nodes, edges = nids['enc'], eids['ee']
self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
nodes, edges = nids["enc"], eids["ee"]
self.update_graph(
g, edges, [(pre_func, nodes)], [(post_func, nodes)]
)
# decode
log_prob = None
......@@ -168,36 +204,76 @@ class Transformer(nn.Module):
for step in range(1, max_len):
y = y.view(-1)
tgt_embed = self.tgt_embed(y)
g.ndata['x'][nids['dec']] = self.pos_enc.dropout(tgt_embed + tgt_pos)
edges_ed = g.filter_edges(lambda e: (e.dst['pos'] < step) & ~e.dst['mask'].bool(), eids['ed'])
edges_dd = g.filter_edges(lambda e: (e.dst['pos'] < step) & ~e.dst['mask'].bool(), eids['dd'])
nodes_d = g.filter_nodes(lambda v: (v.data['pos'] < step) & ~v.data['mask'].bool(), nids['dec'])
g.ndata["x"][nids["dec"]] = self.pos_enc.dropout(
tgt_embed + tgt_pos
)
edges_ed = g.filter_edges(
lambda e: (e.dst["pos"] < step) & ~e.dst["mask"].bool(),
eids["ed"],
)
edges_dd = g.filter_edges(
lambda e: (e.dst["pos"] < step) & ~e.dst["mask"].bool(),
eids["dd"],
)
nodes_d = g.filter_nodes(
lambda v: (v.data["pos"] < step) & ~v.data["mask"].bool(),
nids["dec"],
)
for i in range(self.decoder.N):
pre_func, post_func = self.decoder.pre_func(i, 'qkv'), self.decoder.post_func(i)
pre_func, post_func = self.decoder.pre_func(
i, "qkv"
), self.decoder.post_func(i)
nodes, edges = nodes_d, edges_dd
self.update_graph(g, edges, [(pre_func, nodes)], [(post_func, nodes)])
pre_q, pre_kv = self.decoder.pre_func(i, 'q', 1), self.decoder.pre_func(i, 'kv', 1)
self.update_graph(
g, edges, [(pre_func, nodes)], [(post_func, nodes)]
)
pre_q, pre_kv = self.decoder.pre_func(
i, "q", 1
), self.decoder.pre_func(i, "kv", 1)
post_func = self.decoder.post_func(i, 1)
nodes_e, nodes_d, edges = nids['enc'], nodes_d, edges_ed
self.update_graph(g, edges, [(pre_q, nodes_d), (pre_kv, nodes_e)], [(post_func, nodes_d)])
nodes_e, nodes_d, edges = nids["enc"], nodes_d, edges_ed
self.update_graph(
g,
edges,
[(pre_q, nodes_d), (pre_kv, nodes_e)],
[(post_func, nodes_d)],
)
frontiers = g.filter_nodes(lambda v: v.data['pos'] == step - 1, nids['dec'])
out = self.generator(g.ndata['x'][frontiers])
frontiers = g.filter_nodes(
lambda v: v.data["pos"] == step - 1, nids["dec"]
)
out = self.generator(g.ndata["x"][frontiers])
batch_size = frontiers.shape[0] // k
vocab_size = out.shape[-1]
# Mask output for complete sequence
one_hot = th.zeros(vocab_size).fill_(-1e9).to(device)
one_hot[eos_id] = 0
mask = g.ndata['mask'][frontiers].unsqueeze(-1).float()
mask = g.ndata["mask"][frontiers].unsqueeze(-1).float()
out = out * (1 - mask) + one_hot.unsqueeze(0) * mask
if log_prob is None:
log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(k, dim=-1)
log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(
k, dim=-1
)
eos = th.zeros(batch_size, k).byte()
else:
norm_old = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((4. + step) / 6, alpha)
norm_new = eos.float().to(device) + (1 - eos.float().to(device)) * np.power((5. + step) / 6, alpha)
log_prob, pos = ((out.view(batch_size, k, -1) + (log_prob * norm_old).unsqueeze(-1)) / norm_new.unsqueeze(-1)).view(batch_size, -1).topk(k, dim=-1)
norm_old = eos.float().to(device) + (
1 - eos.float().to(device)
) * np.power((4.0 + step) / 6, alpha)
norm_new = eos.float().to(device) + (
1 - eos.float().to(device)
) * np.power((5.0 + step) / 6, alpha)
log_prob, pos = (
(
(
out.view(batch_size, k, -1)
+ (log_prob * norm_old).unsqueeze(-1)
)
/ norm_new.unsqueeze(-1)
)
.view(batch_size, -1)
.topk(k, dim=-1)
)
_y = y.view(batch_size * k, -1)
y = th.zeros_like(_y)
......@@ -206,14 +282,16 @@ class Transformer(nn.Module):
for j in range(k):
_j = pos[i, j].item() // vocab_size
token = pos[i, j].item() % vocab_size
y[i*k+j, :] = _y[i*k+_j, :]
y[i*k+j, step] = token
y[i * k + j, :] = _y[i * k + _j, :]
y[i * k + j, step] = token
eos[i, j] = _eos[i, _j] | (token == eos_id)
if eos.all():
break
else:
g.ndata['mask'][nids['dec']] = eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
g.ndata["mask"][nids["dec"]] = (
eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
)
return y.view(batch_size, k, -1)[:, 0, :].tolist()
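# `pos` indexes the flattened (k * vocab_size) candidate scores, so
# pos // vocab_size recovers the parent beam being extended and
# pos % vocab_size the emitted token id; y is then reassembled beam by beam.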
def _register_att_map(self, g, enc_ids, dec_ids):
......@@ -224,22 +302,42 @@ class Transformer(nn.Module):
]
def make_model(src_vocab, tgt_vocab, N=6,
dim_model=512, dim_ff=2048, h=8, dropout=0.1, universal=False):
def make_model(
src_vocab,
tgt_vocab,
N=6,
dim_model=512,
dim_ff=2048,
h=8,
dropout=0.1,
universal=False,
):
if universal:
return make_universal_model(src_vocab, tgt_vocab, dim_model, dim_ff, h, dropout)
return make_universal_model(
src_vocab, tgt_vocab, dim_model, dim_ff, h, dropout
)
c = copy.deepcopy
attn = MultiHeadAttention(h, dim_model)
ff = PositionwiseFeedForward(dim_model, dim_ff)
pos_enc = PositionalEncoding(dim_model, dropout)
encoder = Encoder(EncoderLayer(dim_model, c(attn), c(ff), dropout), N)
decoder = Decoder(DecoderLayer(dim_model, c(attn), c(attn), c(ff), dropout), N)
decoder = Decoder(
DecoderLayer(dim_model, c(attn), c(attn), c(ff), dropout), N
)
src_embed = Embeddings(src_vocab, dim_model)
tgt_embed = Embeddings(tgt_vocab, dim_model)
generator = Generator(dim_model, tgt_vocab)
model = Transformer(
encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, dim_model // h)
encoder,
decoder,
src_embed,
tgt_embed,
pos_enc,
generator,
h,
dim_model // h,
)
# xavier init
for p in model.parameters():
if p.dim() > 1:
......
import os
import numpy as np
import torch as th
import networkx as nx
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch as th
from networkx.algorithms import bipartite
def get_attention_map(g, src_nodes, dst_nodes, h):
"""
Visualize the attention scores between two sets of nodes.
......@@ -18,14 +20,15 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
if not g.has_edge_between(src, dst):
continue
eid = g.edge_ids(src, dst)
weight[i][j] = g.edata['score'][eid].squeeze(-1).cpu().detach()
weight[i][j] = g.edata["score"][eid].squeeze(-1).cpu().detach()
weight = weight.transpose(0, 2)
att = th.softmax(weight, -2)
return att.numpy()
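# weight is (n_src, n_dst, n_heads); transpose(0, 2) makes it
# (n_heads, n_dst, n_src), and the softmax over dim -2 normalizes each
# head's scores across destination nodes for a fixed source node.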
def draw_heatmap(array, input_seq, output_seq, dirname, name):
dirname = os.path.join('log', dirname)
dirname = os.path.join("log", dirname)
if not os.path.exists(dirname):
os.makedirs(dirname)
......@@ -38,30 +41,37 @@ def draw_heatmap(array, input_seq, output_seq, dirname, name):
axes[i, j].set_xticks(np.arange(len(output_seq)))
axes[i, j].set_yticklabels(input_seq, fontsize=4)
axes[i, j].set_xticklabels(output_seq, fontsize=4)
axes[i, j].set_title('head_{}'.format(cnt), fontsize=10)
plt.setp(axes[i, j].get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
axes[i, j].set_title("head_{}".format(cnt), fontsize=10)
plt.setp(
axes[i, j].get_xticklabels(),
rotation=45,
ha="right",
rotation_mode="anchor",
)
cnt += 1
fig.suptitle(name, fontsize=12)
plt.tight_layout()
plt.savefig(os.path.join(dirname, '{}.pdf'.format(name)))
plt.savefig(os.path.join(dirname, "{}.pdf".format(name)))
plt.close()
def draw_atts(maps, src, tgt, dirname, prefix):
'''
"""
maps[0]: encoder self-attention
maps[1]: encoder-decoder attention
maps[2]: decoder self-attention
'''
draw_heatmap(maps[0], src, src, dirname, '{}_enc_self_attn'.format(prefix))
draw_heatmap(maps[1], src, tgt, dirname, '{}_enc_dec_attn'.format(prefix))
draw_heatmap(maps[2], tgt, tgt, dirname, '{}_dec_self_attn'.format(prefix))
"""
draw_heatmap(maps[0], src, src, dirname, "{}_enc_self_attn".format(prefix))
draw_heatmap(maps[1], src, tgt, dirname, "{}_enc_dec_attn".format(prefix))
draw_heatmap(maps[2], tgt, tgt, dirname, "{}_dec_self_attn".format(prefix))
mode2id = {'e2e': 0, 'e2d': 1, 'd2d': 2}
mode2id = {"e2e": 0, "e2d": 1, "d2d": 2}
colorbar = None
def att_animation(maps_array, mode, src, tgt, head_id):
weights = [maps[mode2id[mode]][head_id] for maps in maps_array]
fig, axes = plt.subplots(1, 2)
......@@ -71,63 +81,112 @@ def att_animation(maps_array, mode, src, tgt, head_id):
if colorbar:
colorbar.remove()
plt.cla()
axes[0].set_title('heatmap')
axes[0].set_title("heatmap")
axes[0].set_yticks(np.arange(len(src)))
axes[0].set_xticks(np.arange(len(tgt)))
axes[0].set_yticklabels(src)
axes[0].set_xticklabels(tgt)
plt.setp(axes[0].get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
plt.setp(
axes[0].get_xticklabels(),
rotation=45,
ha="right",
rotation_mode="anchor",
)
fig.suptitle('epoch {}'.format(i))
fig.suptitle("epoch {}".format(i))
weight = weights[i].transpose(-1, -2)
heatmap = axes[0].pcolor(weight, vmin=0, vmax=1, cmap=plt.cm.Blues)
colorbar = plt.colorbar(heatmap, ax=axes[0], fraction=0.046, pad=0.04)
axes[0].set_aspect('equal')
axes[0].set_aspect("equal")
axes[1].axis("off")
graph_att_head(src, tgt, weight, axes[1], 'graph')
ani = animation.FuncAnimation(fig, weight_animate, frames=len(weights), interval=500, repeat_delay=2000)
graph_att_head(src, tgt, weight, axes[1], "graph")
ani = animation.FuncAnimation(
fig,
weight_animate,
frames=len(weights),
interval=500,
repeat_delay=2000,
)
return ani
def graph_att_head(M, N, weight, ax, title):
"credit: Jinjing Zhou"
in_nodes=len(M)
out_nodes=len(N)
in_nodes = len(M)
out_nodes = len(N)
g = nx.bipartite.generators.complete_bipartite_graph(in_nodes,out_nodes)
g = nx.bipartite.generators.complete_bipartite_graph(in_nodes, out_nodes)
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
height_out_y = np.linspace(
(height_in - height_out) / 2, height_out, out_nodes
)
pos = dict()
pos.update((n, (1, i)) for i, n in zip(height_in_y, X)) # put nodes from X at x=1
pos.update((n, (3, i)) for i, n in zip(height_out_y, Y)) # put nodes from Y at x=2
ax.axis('off')
ax.set_xlim(-1,4)
pos.update(
(n, (1, i)) for i, n in zip(height_in_y, X)
) # put nodes from X at x=1
pos.update(
(n, (3, i)) for i, n in zip(height_out_y, Y)
) # put nodes from Y at x=3
ax.axis("off")
ax.set_xlim(-1, 4)
ax.set_title(title)
nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=50, ax=ax)
nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=50, ax=ax)
nx.draw_networkx_nodes(
g, pos, nodelist=range(in_nodes), node_color="r", node_size=50, ax=ax
)
nx.draw_networkx_nodes(
g,
pos,
nodelist=range(in_nodes, in_nodes + out_nodes),
node_color="b",
node_size=50,
ax=ax,
)
for edge in g.edges():
nx.draw_networkx_edges(g, pos, edgelist=[edge], width=weight[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)
nx.draw_networkx_labels(g, pos, {i:label + ' ' for i,label in enumerate(M)},horizontalalignment='right', font_size=8, ax=ax)
nx.draw_networkx_labels(g, pos, {i+in_nodes:' ' + label for i,label in enumerate(N)},horizontalalignment='left', font_size=8, ax=ax)
nx.draw_networkx_edges(
g,
pos,
edgelist=[edge],
width=weight[edge[0], edge[1] - in_nodes] * 1.5,
ax=ax,
)
nx.draw_networkx_labels(
g,
pos,
{i: label + " " for i, label in enumerate(M)},
horizontalalignment="right",
font_size=8,
ax=ax,
)
nx.draw_networkx_labels(
g,
pos,
{i + in_nodes: " " + label for i, label in enumerate(N)},
horizontalalignment="left",
font_size=8,
ax=ax,
)
import networkx as nx
from matplotlib.patches import ConnectionStyle, FancyArrowPatch
from networkx.utils import is_string_like
from matplotlib.patches import ConnectionStyle,FancyArrowPatch
"The following function was modified from the source code of networkx"
def draw_networkx_edges(G, pos,
def draw_networkx_edges(
G,
pos,
edgelist=None,
width=1.0,
edge_color='k',
style='solid',
edge_color="k",
style="solid",
alpha=1.0,
arrowstyle='-|>',
arrowstyle="-|>",
arrowsize=10,
edge_cmap=None,
edge_vmin=None,
......@@ -138,8 +197,9 @@ def draw_networkx_edges(G, pos,
node_size=300,
nodelist=None,
node_shape="o",
connectionstyle='arc3',
**kwds):
connectionstyle="arc3",
**kwds
):
"""Draw the edges of the graph G.
This draws only the edges of the graph G.
......@@ -238,12 +298,12 @@ def draw_networkx_edges(G, pos,
"""
try:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cbook as cb
from matplotlib.colors import colorConverter, Colormap, Normalize
from matplotlib.collections import LineCollection
from matplotlib.patches import FancyArrowPatch, ConnectionStyle
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection
from matplotlib.colors import colorConverter, Colormap, Normalize
from matplotlib.patches import ConnectionStyle, FancyArrowPatch
except ImportError:
raise ImportError("Matplotlib required for draw()")
except RuntimeError:
......@@ -270,33 +330,38 @@ def draw_networkx_edges(G, pos,
else:
lw = width
if not is_string_like(edge_color) \
and cb.iterable(edge_color) \
and len(edge_color) == len(edge_pos):
if (
not is_string_like(edge_color)
and cb.iterable(edge_color)
and len(edge_color) == len(edge_pos)
):
if np.alltrue([is_string_like(c) for c in edge_color]):
# (should check ALL elements)
# list of color letters such as ['k','r','k',...]
edge_colors = tuple([colorConverter.to_rgba(c, alpha)
for c in edge_color])
edge_colors = tuple(
[colorConverter.to_rgba(c, alpha) for c in edge_color]
)
elif np.alltrue([not is_string_like(c) for c in edge_color]):
# If color specs are given as (rgb) or (rgba) tuples, we're OK
if np.alltrue([cb.iterable(c) and len(c) in (3, 4)
for c in edge_color]):
if np.alltrue(
[cb.iterable(c) and len(c) in (3, 4) for c in edge_color]
):
edge_colors = tuple(edge_color)
else:
# numbers (which are going to be mapped with a colormap)
edge_colors = None
else:
raise ValueError('edge_color must contain color names or numbers')
raise ValueError("edge_color must contain color names or numbers")
else:
if is_string_like(edge_color) or len(edge_color) == 1:
edge_colors = (colorConverter.to_rgba(edge_color, alpha), )
edge_colors = (colorConverter.to_rgba(edge_color, alpha),)
else:
msg = 'edge_color must be a color or list of one color per edge'
msg = "edge_color must be a color or list of one color per edge"
raise ValueError(msg)
if (not G.is_directed() or not arrows):
edge_collection = LineCollection(edge_pos,
if not G.is_directed() or not arrows:
edge_collection = LineCollection(
edge_pos,
colors=edge_colors,
linewidths=lw,
antialiaseds=(1,),
......@@ -318,7 +383,7 @@ def draw_networkx_edges(G, pos,
if edge_colors is None:
if edge_cmap is not None:
assert(isinstance(edge_cmap, Colormap))
assert isinstance(edge_cmap, Colormap)
edge_collection.set_array(np.asarray(edge_color))
edge_collection.set_cmap(edge_cmap)
if edge_vmin is not None or edge_vmax is not None:
......@@ -346,7 +411,7 @@ def draw_networkx_edges(G, pos,
arrow_colors = edge_colors
if arrow_colors is None:
if edge_cmap is not None:
assert(isinstance(edge_cmap, Colormap))
assert isinstance(edge_cmap, Colormap)
else:
edge_cmap = plt.get_cmap() # default matplotlib colormap
if edge_vmin is None:
......@@ -379,7 +444,9 @@ def draw_networkx_edges(G, pos,
line_width = lw[i]
else:
line_width = lw[0]
arrow = FancyArrowPatch((x1, y1), (x2, y2),
arrow = FancyArrowPatch(
(x1, y1),
(x2, y2),
arrowstyle=arrowstyle,
shrinkA=shrink_source,
shrinkB=shrink_target,
......@@ -387,7 +454,8 @@ def draw_networkx_edges(G, pos,
connectionstyle=connectionstyle,
color=arrow_color,
linewidth=line_width,
zorder=1) # arrows go behind nodes
zorder=1,
) # arrows go behind nodes
# There seems to be a bug in matplotlib to make collections of
# FancyArrowPatch instances. Until fixed, the patches are added
......@@ -412,44 +480,81 @@ def draw_networkx_edges(G, pos,
def draw_g(graph):
g=graph.g.to_networkx()
fig=plt.figure(figsize=(8,4),dpi=150)
ax=fig.subplots()
ax.axis('off')
ax.set_ylim(-1,1.5)
en_indx=graph.nids['enc'].tolist()
de_indx=graph.nids['dec'].tolist()
en_l={i:np.array([i,0]) for i in en_indx}
de_l={i:np.array([i+2,1]) for i in de_indx}
en_de_s=[]
g = graph.g.to_networkx()
fig = plt.figure(figsize=(8, 4), dpi=150)
ax = fig.subplots()
ax.axis("off")
ax.set_ylim(-1, 1.5)
en_indx = graph.nids["enc"].tolist()
de_indx = graph.nids["dec"].tolist()
en_l = {i: np.array([i, 0]) for i in en_indx}
de_l = {i: np.array([i + 2, 1]) for i in de_indx}
en_de_s = []
for i in en_indx:
for j in de_indx:
en_de_s.append((i,j))
g.add_edge(i,j)
en_s=[]
en_de_s.append((i, j))
g.add_edge(i, j)
en_s = []
for i in en_indx:
for j in en_indx:
g.add_edge(i,j)
en_s.append((i,j))
g.add_edge(i, j)
en_s.append((i, j))
de_s=[]
for idx,i in enumerate(de_indx):
de_s = []
for idx, i in enumerate(de_indx):
for j in de_indx[idx:]:
g.add_edge(i,j)
de_s.append((i,j))
g.add_edge(i, j)
de_s.append((i, j))
nx.draw_networkx_nodes(g, en_l, nodelist=en_indx, node_color='r', node_size=60, ax=ax)
nx.draw_networkx_nodes(g, de_l, nodelist=de_indx, node_color='r', node_size=60, ax=ax)
draw_networkx_edges(g,en_l,edgelist=en_s, ax=ax,connectionstyle="arc3,rad=-0.3",width=0.5)
draw_networkx_edges(g,de_l,edgelist=de_s, ax=ax,connectionstyle="arc3,rad=-0.3",width=0.5)
draw_networkx_edges(g,{**en_l,**de_l},edgelist=en_de_s,width=0.3, ax=ax)
nx.draw_networkx_nodes(
g, en_l, nodelist=en_indx, node_color="r", node_size=60, ax=ax
)
nx.draw_networkx_nodes(
g, de_l, nodelist=de_indx, node_color="r", node_size=60, ax=ax
)
draw_networkx_edges(
g,
en_l,
edgelist=en_s,
ax=ax,
connectionstyle="arc3,rad=-0.3",
width=0.5,
)
draw_networkx_edges(
g,
de_l,
edgelist=de_s,
ax=ax,
connectionstyle="arc3,rad=-0.3",
width=0.5,
)
draw_networkx_edges(g, {**en_l, **de_l}, edgelist=en_de_s, width=0.3, ax=ax)
# ax.add_patch()
ax.text(len(en_indx)+0.5,0,"Encoder", verticalalignment='center', horizontalalignment='left')
ax.text(
len(en_indx) + 0.5,
0,
"Encoder",
verticalalignment="center",
horizontalalignment="left",
)
ax.text(len(en_indx)+0.5,1,"Decoder", verticalalignment='center', horizontalalignment='right')
delta=0.03
for value in {**en_l,**de_l}.values():
x,y=value
ax.add_patch(FancyArrowPatch((x-delta,y+delta),(x-delta,y-delta),arrowstyle="->",mutation_scale=8,connectionstyle="arc3,rad=3"))
ax.text(
len(en_indx) + 0.5,
1,
"Decoder",
verticalalignment="center",
horizontalalignment="right",
)
delta = 0.03
for value in {**en_l, **de_l}.values():
x, y = value
ax.add_patch(
FancyArrowPatch(
(x - delta, y + delta),
(x - delta, y - delta),
arrowstyle="->",
mutation_scale=8,
connectionstyle="arc3,rad=3",
)
)
plt.show(fig)
......@@ -2,17 +2,17 @@ import argparse
import collections
import time
import dgl
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn.init as INIT
import torch.optim as optim
from dgl.data.tree import SSTDataset
from torch.utils.data import DataLoader
from tree_lstm import TreeLSTM
import dgl
from dgl.data.tree import SSTDataset
SSTBatch = collections.namedtuple(
"SSTBatch", ["graph", "mask", "wordid", "label"]
)
......
......@@ -2,14 +2,16 @@
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import time
import itertools
import time
import dgl
import networkx as nx
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
......@@ -20,21 +22,22 @@ class TreeLSTMCell(nn.Module):
self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']}
return {"h": edges.src["h"], "c": edges.src["c"]}
def reduce_func(self, nodes):
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c}
h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
c = th.sum(f * nodes.mailbox["c"], 1)
return {"iou": self.U_iou(h_cat), "c": c}
def apply_node_func(self, nodes):
iou = nodes.data['iou'] + self.b_iou
iou = nodes.data["iou"] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c']
c = i * u + nodes.data["c"]
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
return {"h": h, "c": c}
class ChildSumTreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
......@@ -45,41 +48,44 @@ class ChildSumTreeLSTMCell(nn.Module):
self.U_f = nn.Linear(h_size, h_size)
def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']}
return {"h": edges.src["h"], "c": edges.src["c"]}
def reduce_func(self, nodes):
h_tild = th.sum(nodes.mailbox['h'], 1)
f = th.sigmoid(self.U_f(nodes.mailbox['h']))
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_tild), 'c': c}
h_tild = th.sum(nodes.mailbox["h"], 1)
f = th.sigmoid(self.U_f(nodes.mailbox["h"]))
c = th.sum(f * nodes.mailbox["c"], 1)
return {"iou": self.U_iou(h_tild), "c": c}
def apply_node_func(self, nodes):
iou = nodes.data['iou'] + self.b_iou
iou = nodes.data["iou"] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c']
c = i * u + nodes.data["c"]
h = o * th.tanh(c)
return {'h': h, 'c': c}
return {"h": h, "c": c}
class TreeLSTM(nn.Module):
def __init__(self,
def __init__(
self,
num_vocabs,
x_size,
h_size,
num_classes,
dropout,
cell_type='nary',
pretrained_emb=None):
cell_type="nary",
pretrained_emb=None,
):
super(TreeLSTM, self).__init__()
self.x_size = x_size
self.embedding = nn.Embedding(num_vocabs, x_size)
if pretrained_emb is not None:
print('Using glove')
print("Using glove")
self.embedding.weight.data.copy_(pretrained_emb)
self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(h_size, num_classes)
cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell
cell = TreeLSTMCell if cell_type == "nary" else ChildSumTreeLSTMCell
self.cell = cell(x_size, h_size)
def forward(self, batch, g, h, c):
......@@ -101,12 +107,19 @@ class TreeLSTM(nn.Module):
"""
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
g.ndata["iou"] = self.cell.W_iou(
self.dropout(embeds)
) * batch.mask.float().unsqueeze(-1)
g.ndata["h"] = h
g.ndata["c"] = c
# propagate
dgl.prop_nodes_topo(g, self.cell.message_func, self.cell.reduce_func, apply_node_func=self.cell.apply_node_func)
dgl.prop_nodes_topo(
g,
self.cell.message_func,
self.cell.reduce_func,
apply_node_func=self.cell.apply_node_func,
)
# compute logits
h = self.dropout(g.ndata.pop('h'))
h = self.dropout(g.ndata.pop("h"))
logits = self.linear(h)
return logits
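# Hypothetical usage sketch (names made up): h and c start as zeros per
# node, and dgl.prop_nodes_topo schedules the cell from leaves to roots,
# reducing a node only after all of its children have been processed.
# n = batch.graph.num_nodes()
# h0 = th.zeros(n, h_size); c0 = th.zeros(n, h_size)
# logits = model(batch, batch.graph, h0, c0)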
import torch
import torch.nn as nn
import torch.nn.functional as F
from train import device
from dgl.nn.pytorch import GraphConv
from train import device
class VGAEModel(nn.Module):
......
......@@ -2,11 +2,14 @@ import argparse
import os
import time
import dgl
import model
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from input_data import load_data
from preprocess import (
mask_test_edges,
......@@ -16,9 +19,6 @@ from preprocess import (
)
from sklearn.metrics import average_precision_score, roc_auc_score
import dgl
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
parser = argparse.ArgumentParser(description="Variant Graph Auto Encoder")
......
import argparse
import time
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from torch.utils.data import DataLoader
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset
from torch.utils.data import DataLoader
class SAGEConvWithCV(nn.Module):
......
......@@ -3,6 +3,10 @@ import math
import time
import traceback
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import numpy as np
import torch as th
import torch.multiprocessing as mp
......@@ -10,14 +14,10 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset
class SAGEConvWithCV(nn.Module):
def __init__(self, in_feats, out_feats, activation):
......
......@@ -9,6 +9,7 @@ import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from torch.optim import Adam
###############################################################################
# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API
###############################################################################
......
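# A minimal dense sketch (an assumption, not the example's code, which uses
# DGL's sparse-matrix API) of the label-propagation update it highlights:
# Y <- alpha * D^-1/2 A D^-1/2 Y + (1 - alpha) * Y0.
import torch

def label_propagation(A, Y0, alpha=0.9, steps=10):
    """A: dense (n, n) adjacency; Y0: (n, c) one-hot seed labels."""
    deg = A.sum(dim=1).clamp(min=1)
    d_inv_sqrt = deg.pow(-0.5)
    # symmetrically normalized adjacency: A_hat[i, j] = A[i, j] / sqrt(d_i * d_j)
    A_hat = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)
    Y = Y0.clone()
    for _ in range(steps):
        Y = alpha * (A_hat @ Y) + (1 - alpha) * Y0
    return Y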