Unverified commit 704bcaf6, authored by Hongzhi (Steve), Chen, committed by GitHub
parent 6bc82161
import dgl
import torch
import torch.nn
import torch.nn.functional as F
from dgl.nn import AvgPooling, GraphConv, MaxPooling
from layer import ConvPoolBlock, SAGPool


class SAGNetworkHierarchical(torch.nn.Module):
...
@@ -20,7 +20,6 @@ def _transform_log_level(str_level):


class LightLogging(object):
    def __init__(self, log_path=None, log_name="lightlog", log_level="debug"):
        log_level = _transform_log_level(log_level)
        if log_path:
...
@@ -3,6 +3,9 @@ import time

import numpy as np
import torch
import torch.multiprocessing
from dgl import EID, NID
from dgl.dataloading import GraphDataLoader
from logger import LightLogging
from model import DGCNN, GCN
from sampler import SEALData
@@ -10,9 +13,6 @@ from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
from utils import evaluate_hits, load_ogb_dataset, parse_arguments

torch.multiprocessing.set_sharing_strategy("file_system")
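# The "file_system" sharing strategy is typically used to avoid
# "too many open files" errors when DataLoader workers exchange many
# tensors between processes.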
""" """
......
import os.path as osp
from copy import deepcopy

import dgl
import torch
from dgl import add_self_loop, DGLGraph, NID
from dgl.dataloading.negative_sampler import Uniform
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from utils import drnl_node_labeling


class GraphDataSet(Dataset):
    """
@@ -48,7 +48,6 @@ class PosNegEdgesGenerator(object):
        self.shuffle = shuffle

    def __call__(self, split_type):
        if split_type == "train":
            subsample_ratio = self.subsample_ratio
        else:
@@ -177,7 +176,6 @@ class SEALSampler(object):
        return subgraph

    def _collate(self, batch):
        batch_graphs, batch_labels = map(list, zip(*batch))
        batch_graphs = dgl.batch(batch_graphs)
@@ -272,7 +270,6 @@ class SEALData(object):
        )

    def __call__(self, split_type):
        if split_type == "train":
            subsample_ratio = self.subsample_ratio
        else:
...
import argparse

import dgl
import numpy as np
import pandas as pd
import torch
from ogb.linkproppred import DglLinkPropPredDataset, Evaluator
from scipy.sparse.csgraph import shortest_path


def parse_arguments():
    """
...
@@ -9,13 +9,13 @@ import argparse
import math
import time

import dgl
import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import (
    CiteseerGraphDataset,
    CoraGraphDataset,
...
@@ -9,12 +9,12 @@ import argparse
import math
import time

import dgl.function as fn
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.nn.pytorch.conv import SGConv
...
import dgl
import numpy as np
import torch


def load_dataset(name):
    dataset = name.lower()
...
@@ -2,14 +2,14 @@ import argparse
import os
import time

import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import load_dataset


class FeedForwardNet(nn.Module):
    def __init__(self, in_feats, hidden, out_feats, n_layers, dropout):
...
@@ -6,10 +6,10 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from tagcn import TAGCN


def evaluate(model, features, labels, mask):
...
@@ -2,32 +2,37 @@ from .attention import *
from .layers import *
from .functions import *
from .embedding import *

import dgl.function as fn
import torch as th
import torch.nn.init as INIT


class UEncoder(nn.Module):
    def __init__(self, layer):
        super(UEncoder, self).__init__()
        self.layer = layer
        self.norm = LayerNorm(layer.size)

    def pre_func(self, fields="qkv"):
        layer = self.layer

        def func(nodes):
            x = nodes.data["x"]
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)

        return func

    def post_func(self):
        layer = self.layer

        def func(nodes):
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {"x": x}

        return func
@@ -37,31 +42,36 @@ class UDecoder(nn.Module):
        self.layer = layer
        self.norm = LayerNorm(layer.size)

    def pre_func(self, fields="qkv", l=0):
        layer = self.layer

        def func(nodes):
            x = nodes.data["x"]
            if fields == "kv":
                norm_x = x
            else:
                norm_x = layer.sublayer[l].norm(x)
            return layer.self_attn.get(norm_x, fields)

        return func

    def post_func(self, l=0):
        layer = self.layer

        def func(nodes):
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {"x": x}

        return func


class HaltingUnit(nn.Module):
    halting_bias_init = 1.0

    def __init__(self, dim_model):
        super(HaltingUnit, self).__init__()
        self.linear = nn.Linear(dim_model, 1)
@@ -71,14 +81,27 @@ class HaltingUnit(nn.Module):
    def forward(self, x):
        return th.sigmoid(self.linear(self.norm(x)))
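
    # Maps each node's normalized state to a scalar halting probability in
    # (0, 1); consumed by UTransformer.halt_and_accum below.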


class UTransformer(nn.Module):
    "Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
    MAX_DEPTH = 8
    thres = 0.99
    act_loss_weight = 0.01

    def __init__(
        self,
        encoder,
        decoder,
        src_embed,
        tgt_embed,
        pos_enc,
        time_enc,
        generator,
        h,
        d_k,
    ):
        super(UTransformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc, self.time_enc = pos_enc, time_enc
        self.halt_enc = HaltingUnit(h * d_k)
@@ -91,34 +114,45 @@ class UTransformer(nn.Module):
        self.stat = [0] * (self.MAX_DEPTH + 1)

    def step_forward(self, nodes):
        x = nodes.data["x"]
        step = nodes.data["step"]
        pos = nodes.data["pos"]
        return {
            "x": self.pos_enc.dropout(
                x + self.pos_enc(pos.view(-1)) + self.time_enc(step.view(-1))
            ),
            "step": step + 1,
        }

    def halt_and_accum(self, name, end=False):
        "field: 'enc' or 'dec'"
        halt = self.halt_enc if name == "enc" else self.halt_dec
        thres = self.thres

        def func(nodes):
            p = halt(nodes.data["x"])
            sum_p = nodes.data["sum_p"] + p
            active = (sum_p < thres) & (1 - end)
            _continue = active.float()
            r = nodes.data["r"] * (1 - _continue) + (1 - sum_p) * _continue
            s = (
                nodes.data["s"]
                + ((1 - _continue) * r + _continue * p) * nodes.data["x"]
            )
            return {"p": p, "sum_p": sum_p, "r": r, "s": s, "active": active}

        return func
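
    # ACT bookkeeping in halt_and_accum above: p is the per-node halting
    # probability, sum_p its running total, r the remainder frozen at the
    # step where a node halts, and s the probability-weighted accumulated
    # state; the mean of r later serves as the ponder (ACT) loss in forward.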

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst("k", "q", "score"), eids)
        g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(
            eids,
            [fn.u_mul_e("v", "score", "v"), fn.copy_e("score", "score")],
            [fn.sum("v", "wv"), fn.sum("score", "z")],
        )
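
        # src_dot_dst and scaled_exp come from this package's functions module
        # (not shown in this diff); they are assumed to be edge UDF factories
        # roughly of the form:
        #
        #   def src_dot_dst(src_field, dst_field, out_field):
        #       def func(edges):
        #           return {out_field: (edges.src[src_field] * edges.dst[dst_field]).sum(-1, keepdim=True)}
        #       return func
        #
        #   def scaled_exp(field, scale_constant):
        #       def func(edges):
        #           return {field: th.exp((edges.data[field] / scale_constant).clamp(-10, 10))}
        #       return func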

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

@@ -136,79 +170,128 @@ class UTransformer(nn.Module):
        nids, eids = graph.nids, graph.eids

        # embed & pos
        g.nodes[nids["enc"]].data["x"] = self.src_embed(graph.src[0])
        g.nodes[nids["dec"]].data["x"] = self.tgt_embed(graph.tgt[0])
        g.nodes[nids["enc"]].data["pos"] = graph.src[1]
        g.nodes[nids["dec"]].data["pos"] = graph.tgt[1]

        # init step
        device = next(self.parameters()).device
        g.ndata["s"] = th.zeros(
            N, self.h * self.d_k, dtype=th.float, device=device
        )  # accumulated state
        g.ndata["p"] = th.zeros(
            N, 1, dtype=th.float, device=device
        )  # halting prob
        g.ndata["r"] = th.ones(N, 1, dtype=th.float, device=device)  # remainder
        g.ndata["sum_p"] = th.zeros(
            N, 1, dtype=th.float, device=device
        )  # sum of pondering values
        g.ndata["step"] = th.zeros(N, 1, dtype=th.long, device=device)  # step
        g.ndata["active"] = th.ones(
            N, 1, dtype=th.uint8, device=device
        )  # active

        for step in range(self.MAX_DEPTH):
            pre_func = self.encoder.pre_func("qkv")
            post_func = self.encoder.post_func()
            nodes = g.filter_nodes(
                lambda v: v.data["active"].view(-1), nids["enc"]
            )
            if len(nodes) == 0:
                break
            edges = g.filter_edges(
                lambda e: e.dst["active"].view(-1), eids["ee"]
            )
            end = step == self.MAX_DEPTH - 1
            self.update_graph(
                g,
                edges,
                [(self.step_forward, nodes), (pre_func, nodes)],
                [(post_func, nodes), (self.halt_and_accum("enc", end), nodes)],
            )

        g.nodes[nids["enc"]].data["x"] = self.encoder.norm(
            g.nodes[nids["enc"]].data["s"]
        )

        for step in range(self.MAX_DEPTH):
            pre_func = self.decoder.pre_func("qkv")
            post_func = self.decoder.post_func()
            nodes = g.filter_nodes(
                lambda v: v.data["active"].view(-1), nids["dec"]
            )
            if len(nodes) == 0:
                break
            edges = g.filter_edges(
                lambda e: e.dst["active"].view(-1), eids["dd"]
            )
            self.update_graph(
                g,
                edges,
                [(self.step_forward, nodes), (pre_func, nodes)],
                [(post_func, nodes)],
            )

            pre_q = self.decoder.pre_func("q", 1)
            pre_kv = self.decoder.pre_func("kv", 1)
            post_func = self.decoder.post_func(1)
            nodes_e = nids["enc"]
            edges = g.filter_edges(
                lambda e: e.dst["active"].view(-1), eids["ed"]
            )
            end = step == self.MAX_DEPTH - 1
            self.update_graph(
                g,
                edges,
                [(pre_q, nodes), (pre_kv, nodes_e)],
                [(post_func, nodes), (self.halt_and_accum("dec", end), nodes)],
            )

        g.nodes[nids["dec"]].data["x"] = self.decoder.norm(
            g.nodes[nids["dec"]].data["s"]
        )
        act_loss = th.mean(g.ndata["r"])  # ACT loss

        self.stat[0] += N
        for step in range(1, self.MAX_DEPTH + 1):
            self.stat[step] += th.sum(g.ndata["step"] >= step).item()
        return (
            self.generator(g.ndata["x"][nids["dec"]]),
            act_loss * self.act_loss_weight,
        )

    def infer(self, *args, **kwargs):
        raise NotImplementedError


def make_universal_model(
    src_vocab, tgt_vocab, dim_model=512, dim_ff=2048, h=8, dropout=0.1
):
    c = copy.deepcopy
    attn = MultiHeadAttention(h, dim_model)
    ff = PositionwiseFeedForward(dim_model, dim_ff)
    pos_enc = PositionalEncoding(dim_model, dropout)
    time_enc = PositionalEncoding(dim_model, dropout)
    encoder = UEncoder(EncoderLayer((dim_model), c(attn), c(ff), dropout))
    decoder = UDecoder(
        DecoderLayer((dim_model), c(attn), c(attn), c(ff), dropout)
    )
    src_embed = Embeddings(src_vocab, dim_model)
    tgt_embed = Embeddings(tgt_vocab, dim_model)
    generator = Generator(dim_model, tgt_vocab)
    model = UTransformer(
        encoder,
        decoder,
        src_embed,
        tgt_embed,
        pos_enc,
        time_enc,
        generator,
        h,
        dim_model // h,
    )
    # xavier init
    for p in model.parameters():
        if p.dim() > 1:
...
@@ -6,10 +6,12 @@ from .layers import *
from .functions import *
from .embedding import *

import threading

import dgl.function as fn
import torch as th
import torch.nn.init as INIT


class Encoder(nn.Module):
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
@@ -17,24 +19,29 @@ class Encoder(nn.Module):
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields="qkv"):
        layer = self.layers[i]

        def func(nodes):
            x = nodes.data["x"]
            norm_x = layer.sublayer[0].norm(x)
            return layer.self_attn.get(norm_x, fields=fields)

        return func

    def post_func(self, i):
        layer = self.layers[i]

        def func(nodes):
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[0].dropout(o)
            x = layer.sublayer[1](x, layer.feed_forward)
            return {"x": x if i < self.N - 1 else self.norm(x)}

        return func


class Decoder(nn.Module):
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
@@ -42,32 +49,39 @@ class Decoder(nn.Module):
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def pre_func(self, i, fields="qkv", l=0):
        layer = self.layers[i]

        def func(nodes):
            x = nodes.data["x"]
            norm_x = layer.sublayer[l].norm(x) if fields.startswith("q") else x
            if fields != "qkv":
                return layer.src_attn.get(norm_x, fields)
            else:
                return layer.self_attn.get(norm_x, fields)

        return func

    def post_func(self, i, l=0):
        layer = self.layers[i]

        def func(nodes):
            x, wv, z = nodes.data["x"], nodes.data["wv"], nodes.data["z"]
            o = layer.self_attn.get_o(wv / z)
            x = x + layer.sublayer[l].dropout(o)
            if l == 1:
                x = layer.sublayer[2](x, layer.feed_forward)
            return {"x": x if i < self.N - 1 else self.norm(x)}

        return func


class Transformer(nn.Module):
    def __init__(
        self, encoder, decoder, src_embed, tgt_embed, pos_enc, generator, h, d_k
    ):
        super(Transformer, self).__init__()
        self.encoder, self.decoder = encoder, decoder
        self.src_embed, self.tgt_embed = src_embed, tgt_embed
        self.pos_enc = pos_enc
        self.generator = generator
@@ -76,11 +90,11 @@ class Transformer(nn.Module):

    def propagate_attention(self, g, eids):
        # Compute attention score
        g.apply_edges(src_dot_dst("k", "q", "score"), eids)
        g.apply_edges(scaled_exp("score", np.sqrt(self.d_k)), eids)
        # Send weighted values to target nodes
        g.send_and_recv(eids, fn.u_mul_e("v", "score", "v"), fn.sum("v", "wv"))
        g.send_and_recv(eids, fn.copy_e("score", "score"), fn.sum("score", "z"))

    def update_graph(self, g, eids, pre_pairs, post_pairs):
        "Update the node states and edge states of the graph."

@@ -98,27 +112,44 @@ class Transformer(nn.Module):
        nids, eids = graph.nids, graph.eids

        # embed
        src_embed, src_pos = self.src_embed(graph.src[0]), self.pos_enc(
            graph.src[1]
        )
        tgt_embed, tgt_pos = self.tgt_embed(graph.tgt[0]), self.pos_enc(
            graph.tgt[1]
        )
        g.nodes[nids["enc"]].data["x"] = self.pos_enc.dropout(
            src_embed + src_pos
        )
        g.nodes[nids["dec"]].data["x"] = self.pos_enc.dropout(
            tgt_embed + tgt_pos
        )

        for i in range(self.encoder.N):
            pre_func = self.encoder.pre_func(i, "qkv")
            post_func = self.encoder.post_func(i)
            nodes, edges = nids["enc"], eids["ee"]
            self.update_graph(
                g, edges, [(pre_func, nodes)], [(post_func, nodes)]
            )

        for i in range(self.decoder.N):
            pre_func = self.decoder.pre_func(i, "qkv")
            post_func = self.decoder.post_func(i)
            nodes, edges = nids["dec"], eids["dd"]
            self.update_graph(
                g, edges, [(pre_func, nodes)], [(post_func, nodes)]
            )
            pre_q = self.decoder.pre_func(i, "q", 1)
            pre_kv = self.decoder.pre_func(i, "kv", 1)
            post_func = self.decoder.post_func(i, 1)
            nodes_e, edges = nids["enc"], eids["ed"]
            self.update_graph(
                g,
                edges,
                [(pre_q, nodes), (pre_kv, nodes_e)],
                [(post_func, nodes)],
            )

        # visualize attention
        """
@@ -126,9 +157,10 @@ class Transformer(nn.Module):
        self._register_att_map(g, graph.nid_arr['enc'][VIZ_IDX], graph.nid_arr['dec'][VIZ_IDX])
        """

        return self.generator(g.ndata["x"][nids["dec"]])

    def infer(self, graph, max_len, eos_id, k, alpha=1.0):
        """
        This function implements beam search in DGL, which is required in the inference phase.
        Length normalization is given by (5 + len) ^ alpha / 6 ^ alpha. Please refer to https://arxiv.org/pdf/1609.08144.pdf.
        args:
@@ -138,7 +170,7 @@ class Transformer(nn.Module):
            k: beam size
        return:
            ret: a list of index arrays corresponding to the input sequences specified by `graph`.
        """
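        # Length-penalty sketch, restating the docstring formula above with
        # illustrative names: lp(length) = ((5 + length) ** alpha) / (6 ** alpha),
        # and beam scores are divided by lp(length) before comparison (GNMT-style).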
        g = graph.g
        N, E = graph.n_nodes, graph.n_edges
        nids, eids = graph.nids, graph.eids
@@ -146,21 +178,25 @@ class Transformer(nn.Module):

        # embed & pos
        src_embed = self.src_embed(graph.src[0])
        src_pos = self.pos_enc(graph.src[1])
        g.nodes[nids["enc"]].data["pos"] = graph.src[1]
        g.nodes[nids["enc"]].data["x"] = self.pos_enc.dropout(
            src_embed + src_pos
        )
        tgt_pos = self.pos_enc(graph.tgt[1])
        g.nodes[nids["dec"]].data["pos"] = graph.tgt[1]

        # init mask
        device = next(self.parameters()).device
        g.ndata["mask"] = th.zeros(N, dtype=th.uint8, device=device)

        # encode
        for i in range(self.encoder.N):
            pre_func = self.encoder.pre_func(i, "qkv")
            post_func = self.encoder.post_func(i)
            nodes, edges = nids["enc"], eids["ee"]
            self.update_graph(
                g, edges, [(pre_func, nodes)], [(post_func, nodes)]
            )

        # decode
        log_prob = None
@@ -168,36 +204,76 @@
        for step in range(1, max_len):
            y = y.view(-1)
            tgt_embed = self.tgt_embed(y)
            g.ndata["x"][nids["dec"]] = self.pos_enc.dropout(
                tgt_embed + tgt_pos
            )
            edges_ed = g.filter_edges(
                lambda e: (e.dst["pos"] < step) & ~e.dst["mask"].bool(),
                eids["ed"],
            )
            edges_dd = g.filter_edges(
                lambda e: (e.dst["pos"] < step) & ~e.dst["mask"].bool(),
                eids["dd"],
            )
            nodes_d = g.filter_nodes(
                lambda v: (v.data["pos"] < step) & ~v.data["mask"].bool(),
                nids["dec"],
            )
            for i in range(self.decoder.N):
                pre_func, post_func = self.decoder.pre_func(
                    i, "qkv"
                ), self.decoder.post_func(i)
                nodes, edges = nodes_d, edges_dd
                self.update_graph(
                    g, edges, [(pre_func, nodes)], [(post_func, nodes)]
                )
                pre_q, pre_kv = self.decoder.pre_func(
                    i, "q", 1
                ), self.decoder.pre_func(i, "kv", 1)
                post_func = self.decoder.post_func(i, 1)
                nodes_e, nodes_d, edges = nids["enc"], nodes_d, edges_ed
                self.update_graph(
                    g,
                    edges,
                    [(pre_q, nodes_d), (pre_kv, nodes_e)],
                    [(post_func, nodes_d)],
                )

            frontiers = g.filter_nodes(
                lambda v: v.data["pos"] == step - 1, nids["dec"]
            )
            out = self.generator(g.ndata["x"][frontiers])
            batch_size = frontiers.shape[0] // k
            vocab_size = out.shape[-1]

            # Mask output for complete sequence
            one_hot = th.zeros(vocab_size).fill_(-1e9).to(device)
            one_hot[eos_id] = 0
            mask = g.ndata["mask"][frontiers].unsqueeze(-1).float()
            out = out * (1 - mask) + one_hot.unsqueeze(0) * mask

            if log_prob is None:
                log_prob, pos = out.view(batch_size, k, -1)[:, 0, :].topk(
                    k, dim=-1
                )
                eos = th.zeros(batch_size, k).byte()
            else:
                norm_old = eos.float().to(device) + (
                    1 - eos.float().to(device)
                ) * np.power((4.0 + step) / 6, alpha)
                norm_new = eos.float().to(device) + (
                    1 - eos.float().to(device)
                ) * np.power((5.0 + step) / 6, alpha)
                log_prob, pos = (
                    (
                        (
                            out.view(batch_size, k, -1)
                            + (log_prob * norm_old).unsqueeze(-1)
                        )
                        / norm_new.unsqueeze(-1)
                    )
                    .view(batch_size, -1)
                    .topk(k, dim=-1)
                )

            _y = y.view(batch_size * k, -1)
            y = th.zeros_like(_y)
@@ -206,14 +282,16 @@ class Transformer(nn.Module):
                for j in range(k):
                    _j = pos[i, j].item() // vocab_size
                    token = pos[i, j].item() % vocab_size
                    y[i * k + j, :] = _y[i * k + _j, :]
                    y[i * k + j, step] = token
                    eos[i, j] = _eos[i, _j] | (token == eos_id)
            if eos.all():
                break
            else:
                g.ndata["mask"][nids["dec"]] = (
                    eos.unsqueeze(-1).repeat(1, 1, max_len).view(-1).to(device)
                )
        return y.view(batch_size, k, -1)[:, 0, :].tolist()

    def _register_att_map(self, g, enc_ids, dec_ids):
@@ -224,22 +302,42 @@ class Transformer(nn.Module):
        ]


def make_model(
    src_vocab,
    tgt_vocab,
    N=6,
    dim_model=512,
    dim_ff=2048,
    h=8,
    dropout=0.1,
    universal=False,
):
    if universal:
        return make_universal_model(
            src_vocab, tgt_vocab, dim_model, dim_ff, h, dropout
        )
    c = copy.deepcopy
    attn = MultiHeadAttention(h, dim_model)
    ff = PositionwiseFeedForward(dim_model, dim_ff)
    pos_enc = PositionalEncoding(dim_model, dropout)
    encoder = Encoder(EncoderLayer(dim_model, c(attn), c(ff), dropout), N)
    decoder = Decoder(
        DecoderLayer(dim_model, c(attn), c(attn), c(ff), dropout), N
    )
    src_embed = Embeddings(src_vocab, dim_model)
    tgt_embed = Embeddings(tgt_vocab, dim_model)
    generator = Generator(dim_model, tgt_vocab)
    model = Transformer(
        encoder,
        decoder,
        src_embed,
        tgt_embed,
        pos_enc,
        generator,
        h,
        dim_model // h,
    )
    # xavier init
    for p in model.parameters():
        if p.dim() > 1:
...
import os

import matplotlib as mpl
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch as th
from networkx.algorithms import bipartite


def get_attention_map(g, src_nodes, dst_nodes, h):
    """
    To visualize the attention score between two sets of nodes.
@@ -18,14 +20,15 @@ def get_attention_map(g, src_nodes, dst_nodes, h):
            if not g.has_edge_between(src, dst):
                continue
            eid = g.edge_ids(src, dst)
            weight[i][j] = g.edata["score"][eid].squeeze(-1).cpu().detach()

    weight = weight.transpose(0, 2)
    att = th.softmax(weight, -2)
    return att.numpy()
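
# Hypothetical usage of the helper above together with draw_heatmap defined
# below (variable names are illustrative, not taken from this diff):
#
#     att = get_attention_map(g, graph.nids["enc"], graph.nids["dec"], h=8)
#     draw_heatmap(att, src_tokens, tgt_tokens, dirname="viz", name="enc_dec_attn")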

def draw_heatmap(array, input_seq, output_seq, dirname, name):
    dirname = os.path.join("log", dirname)
    if not os.path.exists(dirname):
        os.makedirs(dirname)
@@ -38,30 +41,37 @@ def draw_heatmap(array, input_seq, output_seq, dirname, name):
            axes[i, j].set_xticks(np.arange(len(output_seq)))
            axes[i, j].set_yticklabels(input_seq, fontsize=4)
            axes[i, j].set_xticklabels(output_seq, fontsize=4)
            axes[i, j].set_title("head_{}".format(cnt), fontsize=10)
            plt.setp(
                axes[i, j].get_xticklabels(),
                rotation=45,
                ha="right",
                rotation_mode="anchor",
            )
            cnt += 1

    fig.suptitle(name, fontsize=12)
    plt.tight_layout()
    plt.savefig(os.path.join(dirname, "{}.pdf".format(name)))
    plt.close()

def draw_atts(maps, src, tgt, dirname, prefix):
    """
    maps[0]: encoder self-attention
    maps[1]: encoder-decoder attention
    maps[2]: decoder self-attention
    """
    draw_heatmap(maps[0], src, src, dirname, "{}_enc_self_attn".format(prefix))
    draw_heatmap(maps[1], src, tgt, dirname, "{}_enc_dec_attn".format(prefix))
    draw_heatmap(maps[2], tgt, tgt, dirname, "{}_dec_self_attn".format(prefix))


mode2id = {"e2e": 0, "e2d": 1, "d2d": 2}

colorbar = None

def att_animation(maps_array, mode, src, tgt, head_id):
    weights = [maps[mode2id[mode]][head_id] for maps in maps_array]
    fig, axes = plt.subplots(1, 2)
@@ -71,75 +81,125 @@ def att_animation(maps_array, mode, src, tgt, head_id):
        if colorbar:
            colorbar.remove()
        plt.cla()
        axes[0].set_title("heatmap")
        axes[0].set_yticks(np.arange(len(src)))
        axes[0].set_xticks(np.arange(len(tgt)))
        axes[0].set_yticklabels(src)
        axes[0].set_xticklabels(tgt)
        plt.setp(
            axes[0].get_xticklabels(),
            rotation=45,
            ha="right",
            rotation_mode="anchor",
        )

        fig.suptitle("epoch {}".format(i))
        weight = weights[i].transpose(-1, -2)
        heatmap = axes[0].pcolor(weight, vmin=0, vmax=1, cmap=plt.cm.Blues)
        colorbar = plt.colorbar(heatmap, ax=axes[0], fraction=0.046, pad=0.04)
        axes[0].set_aspect("equal")
        axes[1].axis("off")
        graph_att_head(src, tgt, weight, axes[1], "graph")

    ani = animation.FuncAnimation(
        fig,
        weight_animate,
        frames=len(weights),
        interval=500,
        repeat_delay=2000,
    )
    return ani

def graph_att_head(M, N, weight, ax, title):
    "credit: Jinjing Zhou"
    in_nodes = len(M)
    out_nodes = len(N)
    g = nx.bipartite.generators.complete_bipartite_graph(in_nodes, out_nodes)
    X, Y = bipartite.sets(g)
    height_in = 10
    height_out = height_in
    height_in_y = np.linspace(0, height_in, in_nodes)
    height_out_y = np.linspace(
        (height_in - height_out) / 2, height_out, out_nodes
    )
    pos = dict()
    pos.update(
        (n, (1, i)) for i, n in zip(height_in_y, X)
    )  # put nodes from X at x=1
    pos.update(
        (n, (3, i)) for i, n in zip(height_out_y, Y)
    )  # put nodes from Y at x=2
    ax.axis("off")
    ax.set_xlim(-1, 4)
    ax.set_title(title)
    nx.draw_networkx_nodes(
        g, pos, nodelist=range(in_nodes), node_color="r", node_size=50, ax=ax
    )
    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=range(in_nodes, in_nodes + out_nodes),
        node_color="b",
        node_size=50,
        ax=ax,
    )
    for edge in g.edges():
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=weight[edge[0], edge[1] - in_nodes] * 1.5,
            ax=ax,
        )
    nx.draw_networkx_labels(
        g,
        pos,
        {i: label + " " for i, label in enumerate(M)},
        horizontalalignment="right",
        font_size=8,
        ax=ax,
    )
    nx.draw_networkx_labels(
        g,
        pos,
        {i + in_nodes: " " + label for i, label in enumerate(N)},
        horizontalalignment="left",
        font_size=8,
        ax=ax,
    )

import networkx as nx
from matplotlib.patches import ConnectionStyle, FancyArrowPatch
from networkx.utils import is_string_like

"The following function was modified from the source code of networkx"


def draw_networkx_edges(
    G,
    pos,
    edgelist=None,
    width=1.0,
    edge_color="k",
    style="solid",
    alpha=1.0,
    arrowstyle="-|>",
    arrowsize=10,
    edge_cmap=None,
    edge_vmin=None,
    edge_vmax=None,
    ax=None,
    arrows=True,
    label=None,
    node_size=300,
    nodelist=None,
    node_shape="o",
    connectionstyle="arc3",
    **kwds
):
"""Draw the edges of the graph G. """Draw the edges of the graph G.
This draws only the edges of the graph G. This draws only the edges of the graph G.
...@@ -238,12 +298,12 @@ def draw_networkx_edges(G, pos, ...@@ -238,12 +298,12 @@ def draw_networkx_edges(G, pos,
""" """
    try:
        import matplotlib
        import matplotlib.cbook as cb
        import matplotlib.pyplot as plt
        import numpy as np
        from matplotlib.collections import LineCollection
        from matplotlib.colors import colorConverter, Colormap, Normalize
        from matplotlib.patches import ConnectionStyle, FancyArrowPatch
    except ImportError:
        raise ImportError("Matplotlib required for draw()")
    except RuntimeError:
@@ -270,39 +330,44 @@ def draw_networkx_edges(G, pos,
    else:
        lw = width

    if (
        not is_string_like(edge_color)
        and cb.iterable(edge_color)
        and len(edge_color) == len(edge_pos)
    ):
        if np.alltrue([is_string_like(c) for c in edge_color]):
            # (should check ALL elements)
            # list of color letters such as ['k','r','k',...]
            edge_colors = tuple(
                [colorConverter.to_rgba(c, alpha) for c in edge_color]
            )
        elif np.alltrue([not is_string_like(c) for c in edge_color]):
            # If color specs are given as (rgb) or (rgba) tuples, we're OK
            if np.alltrue(
                [cb.iterable(c) and len(c) in (3, 4) for c in edge_color]
            ):
                edge_colors = tuple(edge_color)
            else:
                # numbers (which are going to be mapped with a colormap)
                edge_colors = None
        else:
            raise ValueError("edge_color must contain color names or numbers")
    else:
        if is_string_like(edge_color) or len(edge_color) == 1:
            edge_colors = (colorConverter.to_rgba(edge_color, alpha),)
        else:
            msg = "edge_color must be a color or list of one color per edge"
            raise ValueError(msg)

    if not G.is_directed() or not arrows:
        edge_collection = LineCollection(
            edge_pos,
            colors=edge_colors,
            linewidths=lw,
            antialiaseds=(1,),
            linestyle=style,
            transOffset=ax.transData,
        )

        edge_collection.set_zorder(1)  # edges go behind nodes
        edge_collection.set_label(label)
@@ -318,7 +383,7 @@ def draw_networkx_edges(G, pos,
        if edge_colors is None:
            if edge_cmap is not None:
                assert isinstance(edge_cmap, Colormap)
            edge_collection.set_array(np.asarray(edge_color))
            edge_collection.set_cmap(edge_cmap)
            if edge_vmin is not None or edge_vmax is not None:
@@ -346,7 +411,7 @@ def draw_networkx_edges(G, pos,
        arrow_colors = edge_colors
        if arrow_colors is None:
            if edge_cmap is not None:
                assert isinstance(edge_cmap, Colormap)
            else:
                edge_cmap = plt.get_cmap()  # default matplotlib colormap
            if edge_vmin is None:
@@ -379,15 +444,18 @@ def draw_networkx_edges(G, pos,
                line_width = lw[i]
            else:
                line_width = lw[0]
            arrow = FancyArrowPatch(
                (x1, y1),
                (x2, y2),
                arrowstyle=arrowstyle,
                shrinkA=shrink_source,
                shrinkB=shrink_target,
                mutation_scale=mutation_scale,
                connectionstyle=connectionstyle,
                color=arrow_color,
                linewidth=line_width,
                zorder=1,
            )  # arrows go behind nodes

            # There seems to be a bug in matplotlib to make collections of
            # FancyArrowPatch instances. Until fixed, the patches are added
@@ -403,7 +471,7 @@ def draw_networkx_edges(G, pos,
    w = maxx - minx
    h = maxy - miny
    padx, pady = 0.05 * w, 0.05 * h
    corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
    ax.update_datalim(corners)
    ax.autoscale_view()
@@ -412,44 +480,81 @@ def draw_networkx_edges(G, pos,


def draw_g(graph):
    g = graph.g.to_networkx()
    fig = plt.figure(figsize=(8, 4), dpi=150)
    ax = fig.subplots()
    ax.axis("off")
    ax.set_ylim(-1, 1.5)
    en_indx = graph.nids["enc"].tolist()
    de_indx = graph.nids["dec"].tolist()
    en_l = {i: np.array([i, 0]) for i in en_indx}
    de_l = {i: np.array([i + 2, 1]) for i in de_indx}
    en_de_s = []
    for i in en_indx:
        for j in de_indx:
            en_de_s.append((i, j))
            g.add_edge(i, j)
    en_s = []
    for i in en_indx:
        for j in en_indx:
            g.add_edge(i, j)
            en_s.append((i, j))
    de_s = []
    for idx, i in enumerate(de_indx):
        for j in de_indx[idx:]:
            g.add_edge(i, j)
            de_s.append((i, j))

    nx.draw_networkx_nodes(
        g, en_l, nodelist=en_indx, node_color="r", node_size=60, ax=ax
    )
    nx.draw_networkx_nodes(
        g, de_l, nodelist=de_indx, node_color="r", node_size=60, ax=ax
    )
    draw_networkx_edges(
        g,
        en_l,
        edgelist=en_s,
        ax=ax,
        connectionstyle="arc3,rad=-0.3",
        width=0.5,
    )
    draw_networkx_edges(
        g,
        de_l,
        edgelist=de_s,
        ax=ax,
        connectionstyle="arc3,rad=-0.3",
        width=0.5,
    )
    draw_networkx_edges(g, {**en_l, **de_l}, edgelist=en_de_s, width=0.3, ax=ax)
    # ax.add_patch()
    ax.text(
        len(en_indx) + 0.5,
        0,
        "Encoder",
        verticalalignment="center",
        horizontalalignment="left",
    )
    ax.text(
        len(en_indx) + 0.5,
        1,
        "Decoder",
        verticalalignment="center",
        horizontalalignment="right",
    )
    delta = 0.03
    for value in {**en_l, **de_l}.values():
        x, y = value
        ax.add_patch(
            FancyArrowPatch(
                (x - delta, y + delta),
                (x - delta, y - delta),
                arrowstyle="->",
                mutation_scale=8,
                connectionstyle="arc3,rad=3",
            )
        )
    plt.show(fig)
@@ -2,17 +2,17 @@ import argparse
import collections
import time

import dgl
import numpy as np
import torch as th
import torch.nn.functional as F
import torch.nn.init as INIT
import torch.optim as optim
from dgl.data.tree import SSTDataset
from torch.utils.data import DataLoader
from tree_lstm import TreeLSTM

SSTBatch = collections.namedtuple(
    "SSTBatch", ["graph", "mask", "wordid", "label"]
)
...
@@ -2,14 +2,16 @@
Improved Semantic Representations From Tree-Structured Long Short-Term Memory Networks
https://arxiv.org/abs/1503.00075
"""
import itertools
import time

import dgl
import networkx as nx
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
...@@ -20,21 +22,22 @@ class TreeLSTMCell(nn.Module): ...@@ -20,21 +22,22 @@ class TreeLSTMCell(nn.Module):
self.U_f = nn.Linear(2 * h_size, 2 * h_size) self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges): def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']} return {"h": edges.src["h"], "c": edges.src["c"]}
def reduce_func(self, nodes): def reduce_func(self, nodes):
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1) h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size()) f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
c = th.sum(f * nodes.mailbox['c'], 1) c = th.sum(f * nodes.mailbox["c"], 1)
return {'iou': self.U_iou(h_cat), 'c': c} return {"iou": self.U_iou(h_cat), "c": c}
def apply_node_func(self, nodes): def apply_node_func(self, nodes):
iou = nodes.data['iou'] + self.b_iou iou = nodes.data["iou"] + self.b_iou
i, o, u = th.chunk(iou, 3, 1) i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u) i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c'] c = i * u + nodes.data["c"]
h = o * th.tanh(c) h = o * th.tanh(c)
return {'h' : h, 'c' : c} return {"h": h, "c": c}
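As a reading aid (not part of this diff): the three callbacks above implement the binary N-ary Tree-LSTM cell of arXiv:1503.00075. In the code's own notation, with children j1, j2 of node j and concatenation written as ∥,

\begin{aligned}
\tilde h_j &= h_{j1} \,\Vert\, h_{j2}, &
(f_{j1} \Vert f_{j2}) &= \sigma\!\big(U_f \tilde h_j\big), &
\tilde c_j &= f_{j1}\odot c_{j1} + f_{j2}\odot c_{j2},\\
(i_j, o_j, u_j) &= U_{iou}\tilde h_j + b_{iou}, &
c_j &= \sigma(i_j)\odot\tanh(u_j) + \tilde c_j, &
h_j &= \sigma(o_j)\odot\tanh(c_j),
\end{aligned}

where leaf nodes receive their iou term from the masked W_iou projection of the word embedding in TreeLSTM.forward instead of from reduce_func. The ChildSumTreeLSTMCell below differs only in the aggregation: it sums the children's h rather than concatenating them, and applies U_f to each child's h separately.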
class ChildSumTreeLSTMCell(nn.Module): class ChildSumTreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size): def __init__(self, x_size, h_size):
...@@ -45,41 +48,44 @@ class ChildSumTreeLSTMCell(nn.Module): ...@@ -45,41 +48,44 @@ class ChildSumTreeLSTMCell(nn.Module):
self.U_f = nn.Linear(h_size, h_size) self.U_f = nn.Linear(h_size, h_size)
def message_func(self, edges): def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']} return {"h": edges.src["h"], "c": edges.src["c"]}
def reduce_func(self, nodes): def reduce_func(self, nodes):
h_tild = th.sum(nodes.mailbox['h'], 1) h_tild = th.sum(nodes.mailbox["h"], 1)
f = th.sigmoid(self.U_f(nodes.mailbox['h'])) f = th.sigmoid(self.U_f(nodes.mailbox["h"]))
c = th.sum(f * nodes.mailbox['c'], 1) c = th.sum(f * nodes.mailbox["c"], 1)
return {'iou': self.U_iou(h_tild), 'c': c} return {"iou": self.U_iou(h_tild), "c": c}
def apply_node_func(self, nodes): def apply_node_func(self, nodes):
iou = nodes.data['iou'] + self.b_iou iou = nodes.data["iou"] + self.b_iou
i, o, u = th.chunk(iou, 3, 1) i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u) i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
c = i * u + nodes.data['c'] c = i * u + nodes.data["c"]
h = o * th.tanh(c) h = o * th.tanh(c)
return {'h': h, 'c': c} return {"h": h, "c": c}
class TreeLSTM(nn.Module): class TreeLSTM(nn.Module):
def __init__(
    self,
    num_vocabs,
    x_size,
    h_size,
    num_classes,
    dropout,
    cell_type="nary",
    pretrained_emb=None,
):
super(TreeLSTM, self).__init__() super(TreeLSTM, self).__init__()
self.x_size = x_size self.x_size = x_size
self.embedding = nn.Embedding(num_vocabs, x_size) self.embedding = nn.Embedding(num_vocabs, x_size)
if pretrained_emb is not None: if pretrained_emb is not None:
print('Using glove') print("Using glove")
self.embedding.weight.data.copy_(pretrained_emb) self.embedding.weight.data.copy_(pretrained_emb)
self.embedding.weight.requires_grad = True self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.linear = nn.Linear(h_size, num_classes) self.linear = nn.Linear(h_size, num_classes)
cell = TreeLSTMCell if cell_type == 'nary' else ChildSumTreeLSTMCell cell = TreeLSTMCell if cell_type == "nary" else ChildSumTreeLSTMCell
self.cell = cell(x_size, h_size) self.cell = cell(x_size, h_size)
def forward(self, batch, g, h, c): def forward(self, batch, g, h, c):
...@@ -101,12 +107,19 @@ class TreeLSTM(nn.Module): ...@@ -101,12 +107,19 @@ class TreeLSTM(nn.Module):
""" """
# feed embedding # feed embedding
embeds = self.embedding(batch.wordid * batch.mask) embeds = self.embedding(batch.wordid * batch.mask)
g.ndata["iou"] = self.cell.W_iou(
    self.dropout(embeds)
) * batch.mask.float().unsqueeze(-1)
g.ndata["h"] = h
g.ndata["c"] = c
# propagate
dgl.prop_nodes_topo(
    g,
    self.cell.message_func,
    self.cell.reduce_func,
    apply_node_func=self.cell.apply_node_func,
)
# compute logits # compute logits
h = self.dropout(g.ndata.pop('h')) h = self.dropout(g.ndata.pop("h"))
logits = self.linear(h) logits = self.linear(h)
return logits return logits
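A minimal, hypothetical usage sketch (not part of this commit) of how these callbacks drive bottom-up propagation. The 3-node tree, feature sizes, and zero initial states are made-up; TreeLSTMCell.W_iou comes from the full class definition in tree_lstm.py, and the import path is assumed.

import dgl
import torch as th
from tree_lstm import TreeLSTMCell  # the cell class shown above

x_size, h_size = 4, 5
cell = TreeLSTMCell(x_size, h_size)

# Binary tree with edges pointing child -> parent: nodes 1 and 2 feed node 0.
g = dgl.graph(([1, 2], [0, 0]), num_nodes=3)
x = th.randn(3, x_size)                 # stand-in word embeddings
g.ndata["iou"] = cell.W_iou(x)          # leaves keep this; parents overwrite it in reduce_func
g.ndata["h"] = th.zeros(3, h_size)
g.ndata["c"] = th.zeros(3, h_size)

# Visit frontiers in topological order (leaves first), calling the three callbacks.
dgl.prop_nodes_topo(
    g,
    cell.message_func,
    cell.reduce_func,
    apply_node_func=cell.apply_node_func,
)
root_h = g.ndata["h"][0]                # root representation, shape (h_size,)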
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from train import device
from dgl.nn.pytorch import GraphConv from dgl.nn.pytorch import GraphConv
from train import device
class VGAEModel(nn.Module): class VGAEModel(nn.Module):
......
...@@ -2,11 +2,14 @@ import argparse ...@@ -2,11 +2,14 @@ import argparse
import os import os
import time import time
import dgl
import model import model
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
from input_data import load_data from input_data import load_data
from preprocess import ( from preprocess import (
mask_test_edges, mask_test_edges,
...@@ -16,9 +19,6 @@ from preprocess import ( ...@@ -16,9 +19,6 @@ from preprocess import (
) )
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
import dgl
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset
os.environ["KMP_DUPLICATE_LIB_OK"] = "True" os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
parser = argparse.ArgumentParser(description="Variant Graph Auto Encoder") parser = argparse.ArgumentParser(description="Variant Graph Auto Encoder")
......
import argparse import argparse
import time import time
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import numpy as np import numpy as np
import torch as th import torch as th
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from torch.utils.data import DataLoader
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset from dgl.data import RedditDataset
from torch.utils.data import DataLoader
class SAGEConvWithCV(nn.Module): class SAGEConvWithCV(nn.Module):
......
...@@ -3,6 +3,10 @@ import math ...@@ -3,6 +3,10 @@ import math
import time import time
import traceback import traceback
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import numpy as np import numpy as np
import torch as th import torch as th
import torch.multiprocessing as mp import torch.multiprocessing as mp
...@@ -10,14 +14,10 @@ import torch.nn as nn ...@@ -10,14 +14,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.data import RedditDataset
class SAGEConvWithCV(nn.Module): class SAGEConvWithCV(nn.Module):
def __init__(self, in_feats, out_feats, activation): def __init__(self, in_feats, out_feats, activation):
......
...@@ -9,6 +9,7 @@ import torch.nn.functional as F ...@@ -9,6 +9,7 @@ import torch.nn.functional as F
from dgl.data import CoraGraphDataset from dgl.data import CoraGraphDataset
from torch.optim import Adam from torch.optim import Adam
############################################################################### ###############################################################################
# (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API # (HIGHLIGHT) Compute Label Propagation with Sparse Matrix API
############################################################################### ###############################################################################
......
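The highlighted section is only partially shown here, so as a hedged illustration of the label-propagation update it refers to (written with dense torch tensors for clarity, not DGL's sparse-matrix API; the damping factor alpha and the step count are arbitrary choices):

import torch

def label_propagation(A, Y0, num_steps=50, alpha=0.9):
    """One common LP variant: Y <- alpha * D^-1/2 A D^-1/2 Y + (1 - alpha) * Y0.

    A:  (N, N) dense adjacency matrix (illustration only).
    Y0: (N, C) one-hot labels for seed nodes, zeros elsewhere.
    """
    deg = A.sum(dim=1).clamp(min=1)
    d_inv_sqrt = deg.pow(-0.5)
    A_hat = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)  # D^-1/2 A D^-1/2
    Y = Y0.clone()
    for _ in range(num_steps):
        Y = alpha * (A_hat @ Y) + (1 - alpha) * Y0
    return Y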