Unverified commit 23d09057, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4642)

* [Misc] Black auto fix.
* sort

Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
import torch.nn as nn


def GlorotOrthogonal(tensor, scale=2.0):
    if tensor is not None:
        nn.init.orthogonal_(tensor.data)
        scale /= (tensor.size(-2) + tensor.size(-1)) * tensor.var()
        tensor.data *= scale.sqrt()
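A minimal usage sketch for the initializer above (modules/initializers.py, per the imports that follow); the layer shape is illustrative:

import torch.nn as nn

layer = nn.Linear(64, 128, bias=False)  # hypothetical layer
GlorotOrthogonal(layer.weight, scale=2.0)  # rescales an orthogonal init in place
# The weight variance now approximates Glorot's scale / (fan_in + fan_out),
# i.e. 2 / (64 + 128) here.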
import torch
import torch.nn as nn
import dgl.function as fn
from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal


class InteractionBlock(nn.Module):
    def __init__(
        self,
        emb_size,
        num_radial,
        num_spherical,
        num_bilinear,
        num_before_skip,
        num_after_skip,
        activation=None,
    ):
        super(InteractionBlock, self).__init__()

        self.activation = activation
        # Transformations of Bessel and spherical basis representations
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.dense_sbf = nn.Linear(
            num_radial * num_spherical, num_bilinear, bias=False
        )
        # Dense transformations of input messages
        self.dense_ji = nn.Linear(emb_size, emb_size)
        self.dense_kj = nn.Linear(emb_size, emb_size)
        # Bilinear layer
        bilin_initializer = torch.empty(
            (emb_size, num_bilinear, emb_size)
        ).normal_(mean=0, std=2 / emb_size)
        self.W_bilin = nn.Parameter(bilin_initializer)
        # Residual layers before skip connection
        self.layers_before_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_before_skip)
            ]
        )
        self.final_before_skip = nn.Linear(emb_size, emb_size)
        # Residual layers after skip connection
        self.layers_after_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_after_skip)
            ]
        )

        self.reset_params()

    def reset_params(self):
        GlorotOrthogonal(self.dense_rbf.weight)
        GlorotOrthogonal(self.dense_sbf.weight)

@@ -47,50 +60,52 @@ class InteractionBlock(nn.Module):

    def edge_transfer(self, edges):
        # Transform from Bessel basis to dense vector
        rbf = self.dense_rbf(edges.data["rbf"])
        # Initial transformation
        x_ji = self.dense_ji(edges.data["m"])
        x_kj = self.dense_kj(edges.data["m"])
        if self.activation is not None:
            x_ji = self.activation(x_ji)
            x_kj = self.activation(x_kj)
        # w: W * e_RBF \bigodot \sigma(W * m + b)
        return {"x_kj": x_kj * rbf, "x_ji": x_ji}

    def msg_func(self, edges):
        sbf = self.dense_sbf(edges.data["sbf"])
        # Apply bilinear layer to interactions and basis function activation
        # [None, 8] * [128, 8, 128] * [None, 128] -> [None, 128]
        x_kj = torch.einsum(
            "wj,wl,ijl->wi", sbf, edges.src["x_kj"], self.W_bilin
        )
        return {"x_kj": x_kj}

    def forward(self, g, l_g):
        g.apply_edges(self.edge_transfer)

        # nodes correspond to edges and edges correspond to nodes in the original graphs
        # node: d, rbf, o, rbf_env, x_kj, x_ji
        for k, v in g.edata.items():
            l_g.ndata[k] = v

        l_g.update_all(self.msg_func, fn.sum("x_kj", "m_update"))

        for k, v in l_g.ndata.items():
            g.edata[k] = v

        # Transformations before skip connection
        g.edata["m_update"] = g.edata["m_update"] + g.edata["x_ji"]
        for layer in self.layers_before_skip:
            g.edata["m_update"] = layer(g.edata["m_update"])
        g.edata["m_update"] = self.final_before_skip(g.edata["m_update"])
        if self.activation is not None:
            g.edata["m_update"] = self.activation(g.edata["m_update"])

        # Skip connection
        g.edata["m"] = g.edata["m"] + g.edata["m_update"]

        # Transformations after skip connection
        for layer in self.layers_after_skip:
            g.edata["m"] = layer(g.edata["m"])

        return g
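A construction sketch for the block above; the hyperparameter values are illustrative (they mirror common DimeNet settings, not values taken from this commit), and the forward pass assumes g and its line graph l_g already carry the "rbf", "sbf" and "m" edge features read by the methods:

import torch.nn.functional as F

block = InteractionBlock(
    emb_size=128,
    num_radial=6,
    num_spherical=7,
    num_bilinear=8,
    num_before_skip=1,
    num_after_skip=2,
    activation=F.silu,
)
# g = block(g, l_g)  # g: molecular graph, l_g: its line graph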
import torch.nn as nn
import dgl
import dgl.function as fn
from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal


class InteractionPPBlock(nn.Module):
    def __init__(
        self,
        emb_size,
        int_emb_size,
        basis_emb_size,
        num_radial,
        num_spherical,
        num_before_skip,
        num_after_skip,
        activation=None,
    ):
        super(InteractionPPBlock, self).__init__()

        self.activation = activation
        # Transformations of Bessel and spherical basis representations
        self.dense_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False)
        self.dense_rbf2 = nn.Linear(basis_emb_size, emb_size, bias=False)
        self.dense_sbf1 = nn.Linear(
            num_radial * num_spherical, basis_emb_size, bias=False
        )
        self.dense_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False)
        # Dense transformations of input messages
        self.dense_ji = nn.Linear(emb_size, emb_size)

@@ -30,17 +35,23 @@ class InteractionPPBlock(nn.Module):

        self.down_projection = nn.Linear(emb_size, int_emb_size, bias=False)
        self.up_projection = nn.Linear(int_emb_size, emb_size, bias=False)
        # Residual layers before skip connection
        self.layers_before_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_before_skip)
            ]
        )
        self.final_before_skip = nn.Linear(emb_size, emb_size)
        # Residual layers after skip connection
        self.layers_after_skip = nn.ModuleList(
            [
                ResidualLayer(emb_size, activation=activation)
                for _ in range(num_after_skip)
            ]
        )

        self.reset_params()

    def reset_params(self):
        GlorotOrthogonal(self.dense_rbf1.weight)
        GlorotOrthogonal(self.dense_rbf2.weight)

@@ -55,11 +66,11 @@ class InteractionPPBlock(nn.Module):

    def edge_transfer(self, edges):
        # Transform from Bessel basis to dense vector
        rbf = self.dense_rbf1(edges.data["rbf"])
        rbf = self.dense_rbf2(rbf)
        # Initial transformation
        x_ji = self.dense_ji(edges.data["m"])
        x_kj = self.dense_kj(edges.data["m"])
        if self.activation is not None:
            x_ji = self.activation(x_ji)
            x_kj = self.activation(x_kj)

@@ -67,41 +78,41 @@ class InteractionPPBlock(nn.Module):

        x_kj = self.down_projection(x_kj * rbf)
        if self.activation is not None:
            x_kj = self.activation(x_kj)
        return {"x_kj": x_kj, "x_ji": x_ji}

    def msg_func(self, edges):
        sbf = self.dense_sbf1(edges.data["sbf"])
        sbf = self.dense_sbf2(sbf)
        x_kj = edges.src["x_kj"] * sbf
        return {"x_kj": x_kj}

    def forward(self, g, l_g):
        g.apply_edges(self.edge_transfer)

        # nodes correspond to edges and edges correspond to nodes in the original graphs
        # node: d, rbf, o, rbf_env, x_kj, x_ji
        for k, v in g.edata.items():
            l_g.ndata[k] = v

        l_g_reverse = dgl.reverse(l_g, copy_edata=True)
        l_g_reverse.update_all(self.msg_func, fn.sum("x_kj", "m_update"))
        g.edata["m_update"] = self.up_projection(l_g_reverse.ndata["m_update"])
        if self.activation is not None:
            g.edata["m_update"] = self.activation(g.edata["m_update"])

        # Transformations before skip connection
        g.edata["m_update"] = g.edata["m_update"] + g.edata["x_ji"]
        for layer in self.layers_before_skip:
            g.edata["m_update"] = layer(g.edata["m_update"])
        g.edata["m_update"] = self.final_before_skip(g.edata["m_update"])
        if self.activation is not None:
            g.edata["m_update"] = self.activation(g.edata["m_update"])

        # Skip connection
        g.edata["m"] = g.edata["m"] + g.edata["m_update"]

        # Transformations after skip connection
        for layer in self.layers_after_skip:
            g.edata["m"] = layer(g.edata["m"])

        return g
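InteractionPPBlock is the DimeNet++ variant of the block above: the bilinear layer is replaced by cheaper down/up projections through int_emb_size. A hedged construction sketch with illustrative values:

import torch.nn.functional as F

pp_block = InteractionPPBlock(
    emb_size=128,
    int_emb_size=64,
    basis_emb_size=8,
    num_radial=6,
    num_spherical=7,
    num_before_skip=1,
    num_after_skip=2,
    activation=F.silu,
)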
import torch.nn as nn
import dgl
import dgl.function as fn
from modules.initializers import GlorotOrthogonal


class OutputBlock(nn.Module):
    def __init__(
        self,
        emb_size,
        num_radial,
        num_dense,
        num_targets,
        activation=None,
        output_init=nn.init.zeros_,
    ):
        super(OutputBlock, self).__init__()

        self.activation = activation
        self.output_init = output_init
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.dense_layers = nn.ModuleList(
            [nn.Linear(emb_size, emb_size) for _ in range(num_dense)]
        )
        self.dense_final = nn.Linear(emb_size, num_targets, bias=False)
        self.reset_params()

    def reset_params(self):
        GlorotOrthogonal(self.dense_rbf.weight)
        for layer in self.dense_layers:

@@ -31,12 +34,12 @@ class OutputBlock(nn.Module):

    def forward(self, g):
        with g.local_scope():
            g.edata["tmp"] = g.edata["m"] * self.dense_rbf(g.edata["rbf"])
            g.update_all(fn.copy_e("tmp", "x"), fn.sum("x", "t"))
            for layer in self.dense_layers:
                g.ndata["t"] = layer(g.ndata["t"])
                if self.activation is not None:
                    g.ndata["t"] = self.activation(g.ndata["t"])
            g.ndata["t"] = self.dense_final(g.ndata["t"])
            return dgl.readout_nodes(g, "t")
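OutputBlock projects the radial basis onto edge messages, sums them onto nodes, and reads the result out per graph. An illustrative instantiation (values are assumptions, not from the commit):

import torch.nn.functional as F

out_block = OutputBlock(
    emb_size=128, num_radial=6, num_dense=3, num_targets=12, activation=F.silu
)
# preds = out_block(g)  # one row of num_targets values per graph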
import torch.nn as nn
import dgl
import dgl.function as fn
from modules.initializers import GlorotOrthogonal


class OutputPPBlock(nn.Module):
    def __init__(
        self,
        emb_size,
        out_emb_size,
        num_radial,
        num_dense,
        num_targets,
        activation=None,
        output_init=nn.init.zeros_,
        extensive=True,
    ):
        super(OutputPPBlock, self).__init__()

        self.activation = activation

@@ -21,12 +24,12 @@ class OutputPPBlock(nn.Module):

        self.extensive = extensive
        self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
        self.up_projection = nn.Linear(emb_size, out_emb_size, bias=False)
        self.dense_layers = nn.ModuleList(
            [nn.Linear(out_emb_size, out_emb_size) for _ in range(num_dense)]
        )
        self.dense_final = nn.Linear(out_emb_size, num_targets, bias=False)
        self.reset_params()

    def reset_params(self):
        GlorotOrthogonal(self.dense_rbf.weight)
        GlorotOrthogonal(self.up_projection.weight)

@@ -36,14 +39,16 @@ class OutputPPBlock(nn.Module):

    def forward(self, g):
        with g.local_scope():
            g.edata["tmp"] = g.edata["m"] * self.dense_rbf(g.edata["rbf"])
            g_reverse = dgl.reverse(g, copy_edata=True)
            g_reverse.update_all(fn.copy_e("tmp", "x"), fn.sum("x", "t"))
            g.ndata["t"] = self.up_projection(g_reverse.ndata["t"])
            for layer in self.dense_layers:
                g.ndata["t"] = layer(g.ndata["t"])
                if self.activation is not None:
                    g.ndata["t"] = self.activation(g.ndata["t"])
            g.ndata["t"] = self.dense_final(g.ndata["t"])
            return dgl.readout_nodes(
                g, "t", op="sum" if self.extensive else "mean"
            )
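The main functional difference from OutputBlock is the up-projection and the extensive flag: readout sums node outputs for size-extensive targets (e.g. energies) and averages them otherwise. An illustrative instantiation:

import torch.nn.functional as F

out_pp = OutputPPBlock(
    emb_size=128,
    out_emb_size=256,
    num_radial=6,
    num_dense=3,
    num_targets=1,
    activation=F.silu,
    extensive=True,
)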
import torch.nn as nn
from modules.initializers import GlorotOrthogonal


class ResidualLayer(nn.Module):
    def __init__(self, units, activation=None):
        super(ResidualLayer, self).__init__()

@@ -9,9 +9,9 @@ class ResidualLayer(nn.Module):

        self.activation = activation
        self.dense_1 = nn.Linear(units, units)
        self.dense_2 = nn.Linear(units, units)

        self.reset_params()

    def reset_params(self):
        GlorotOrthogonal(self.dense_1.weight)
        nn.init.zeros_(self.dense_1.bias)

@@ -25,4 +25,4 @@ class ResidualLayer(nn.Module):

        x = self.dense_2(x)
        if self.activation is not None:
            x = self.activation(x)
        return inputs + x
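ResidualLayer (modules/residual_layer.py, per the imports above) is shape-preserving, so it can be checked standalone; a small runnable sketch with an illustrative width:

import torch
import torch.nn.functional as F

layer = ResidualLayer(units=128, activation=F.silu)
x = torch.randn(32, 128)  # hypothetical batch of messages
assert layer(x).shape == x.shape  # the residual form preserves shape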
import sympy as sym
import torch
import torch.nn as nn
from modules.basis_utils import bessel_basis, real_sph_harm
from modules.envelope import Envelope


class SphericalBasisLayer(nn.Module):
    def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent=5):
        super(SphericalBasisLayer, self).__init__()

        assert num_radial <= 64

@@ -20,26 +16,38 @@ class SphericalBasisLayer(nn.Module):

        self.envelope = Envelope(envelope_exponent)

        # retrieve formulas
        self.bessel_formulas = bessel_basis(
            num_spherical, num_radial
        )  # x, [num_spherical, num_radial] sympy functions
        self.sph_harm_formulas = real_sph_harm(
            num_spherical
        )  # theta, [num_spherical, ] sympy functions
        self.sph_funcs = []
        self.bessel_funcs = []

        # convert to torch functions
        x = sym.symbols("x")
        theta = sym.symbols("theta")
        modules = {"sin": torch.sin, "cos": torch.cos}
        for i in range(num_spherical):
            if i == 0:
                first_sph = sym.lambdify(
                    [theta], self.sph_harm_formulas[i][0], modules
                )(0)
                self.sph_funcs.append(
                    lambda tensor: torch.zeros_like(tensor) + first_sph
                )
            else:
                self.sph_funcs.append(
                    sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)
                )
            for j in range(num_radial):
                self.bessel_funcs.append(
                    sym.lambdify([x], self.bessel_formulas[i][j], modules)
                )

    def get_bessel_funcs(self):
        return self.bessel_funcs

    def get_sph_funcs(self):
        return self.sph_funcs
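A hedged sketch of evaluating the lambdified basis functions; the real example applies the envelope and distance scaling elsewhere, so the raw calls below are illustrative only:

import torch

sbl = SphericalBasisLayer(num_spherical=7, num_radial=6, cutoff=5.0)
d = torch.rand(10)       # hypothetical (already scaled) distances
angles = torch.rand(10)  # hypothetical angles
rbf = [f(d) for f in sbl.get_bessel_funcs()]    # 7 * 6 = 42 functions
cbf = [f(angles) for f in sbl.get_sph_funcs()]  # 7 functions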
"""QM9 dataset for graph property prediction (regression).""" """QM9 dataset for graph property prediction (regression)."""
import os import os
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
import torch import torch
import dgl
from tqdm import trange from tqdm import trange
import dgl
from dgl.convert import graph as dgl_graph
from dgl.data import QM9Dataset from dgl.data import QM9Dataset
from dgl.data.utils import load_graphs, save_graphs from dgl.data.utils import load_graphs, save_graphs
from dgl.convert import graph as dgl_graph
class QM9(QM9Dataset): class QM9(QM9Dataset):
r"""QM9 dataset for graph property prediction (regression) r"""QM9 dataset for graph property prediction (regression)
...@@ -16,11 +18,11 @@ class QM9(QM9Dataset): ...@@ -16,11 +18,11 @@ class QM9(QM9Dataset):
This dataset consists of 130,831 molecules with 12 regression targets. This dataset consists of 130,831 molecules with 12 regression targets.
Nodes correspond to atoms and edges correspond to bonds. Nodes correspond to atoms and edges correspond to bonds.
Reference: Reference:
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_ - `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_ - `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics: Statistics:
- Number of graphs: 130,831 - Number of graphs: 130,831
...@@ -53,7 +55,7 @@ class QM9(QM9Dataset): ...@@ -53,7 +55,7 @@ class QM9(QM9Dataset):
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Cv | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` | | Cv | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters Parameters
---------- ----------
label_keys: list label_keys: list
...@@ -75,12 +77,12 @@ class QM9(QM9Dataset): ...@@ -75,12 +77,12 @@ class QM9(QM9Dataset):
---------- ----------
num_labels : int num_labels : int
Number of labels for each graph, i.e. number of prediction tasks Number of labels for each graph, i.e. number of prediction tasks
Raises Raises
------ ------
UserWarning UserWarning
If the raw data is changed in the remote server by the author. If the raw data is changed in the remote server by the author.
Examples Examples
-------- --------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0) >>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
...@@ -95,70 +97,94 @@ class QM9(QM9Dataset): ...@@ -95,70 +97,94 @@ class QM9(QM9Dataset):
>>> >>>
""" """
def __init__(self, def __init__(
label_keys, self,
edge_funcs=None, label_keys,
cutoff=5.0, edge_funcs=None,
raw_dir=None, cutoff=5.0,
force_reload=False, raw_dir=None,
verbose=False): force_reload=False,
verbose=False,
):
self.edge_funcs = edge_funcs self.edge_funcs = edge_funcs
self._keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv'] self._keys = [
"mu",
super(QM9, self).__init__(label_keys=label_keys, "alpha",
cutoff=cutoff, "homo",
raw_dir=raw_dir, "lumo",
force_reload=force_reload, "gap",
verbose=verbose) "r2",
"zpve",
"U0",
"U",
"H",
"G",
"Cv",
]
super(QM9, self).__init__(
label_keys=label_keys,
cutoff=cutoff,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
)
def has_cache(self): def has_cache(self):
""" step 1, if True, goto step 5; else goto download(step 2), then step 3""" """step 1, if True, goto step 5; else goto download(step 2), then step 3"""
graph_path = f'{self.save_path}/dgl_graph.bin' graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f'{self.save_path}/dgl_line_graph.bin' line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
return os.path.exists(graph_path) and os.path.exists(line_graph_path) return os.path.exists(graph_path) and os.path.exists(line_graph_path)
def process(self): def process(self):
""" step 3 """ """step 3"""
npz_path = f'{self.raw_dir}/qm9_eV.npz' npz_path = f"{self.raw_dir}/qm9_eV.npz"
data_dict = np.load(npz_path, allow_pickle=True) data_dict = np.load(npz_path, allow_pickle=True)
# data_dict['N'] contains the number of atoms in each molecule, # data_dict['N'] contains the number of atoms in each molecule,
# data_dict['R'] consists of the atomic coordinates, # data_dict['R'] consists of the atomic coordinates,
# data_dict['Z'] consists of the atomic numbers. # data_dict['Z'] consists of the atomic numbers.
# Atomic properties (Z and R) of all molecules are concatenated as single tensors, # Atomic properties (Z and R) of all molecules are concatenated as single tensors,
# so you need this value to select the correct atoms for each molecule. # so you need this value to select the correct atoms for each molecule.
self.N = data_dict['N'] self.N = data_dict["N"]
self.R = data_dict['R'] self.R = data_dict["R"]
self.Z = data_dict['Z'] self.Z = data_dict["Z"]
self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)]) self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
# graph labels # graph labels
self.label_dict = {} self.label_dict = {}
for k in self._keys: for k in self._keys:
self.label_dict[k] = torch.tensor(data_dict[k], dtype=torch.float32) self.label_dict[k] = torch.tensor(data_dict[k], dtype=torch.float32)
self.label = torch.stack([self.label_dict[key] for key in self.label_keys], dim=1) self.label = torch.stack(
[self.label_dict[key] for key in self.label_keys], dim=1
)
# graphs & features # graphs & features
self.graphs, self.line_graphs = self._load_graph() self.graphs, self.line_graphs = self._load_graph()
def _load_graph(self): def _load_graph(self):
num_graphs = self.label.shape[0] num_graphs = self.label.shape[0]
graphs = [] graphs = []
line_graphs = [] line_graphs = []
for idx in trange(num_graphs): for idx in trange(num_graphs):
n_atoms = self.N[idx] n_atoms = self.N[idx]
# get all the atomic coordinates of the idx-th molecular graph # get all the atomic coordinates of the idx-th molecular graph
R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]] R = self.R[self.N_cumsum[idx] : self.N_cumsum[idx + 1]]
# calculate the distance between all atoms # calculate the distance between all atoms
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1) dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
# keep all edges that don't exceed the cutoff and delete self-loops # keep all edges that don't exceed the cutoff and delete self-loops
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool) adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(
n_atoms, dtype=np.bool
)
adj = adj.tocoo() adj = adj.tocoo()
u, v = torch.tensor(adj.row), torch.tensor(adj.col) u, v = torch.tensor(adj.row), torch.tensor(adj.col)
g = dgl_graph((u, v)) g = dgl_graph((u, v))
g.ndata['R'] = torch.tensor(R, dtype=torch.float32) g.ndata["R"] = torch.tensor(R, dtype=torch.float32)
g.ndata['Z'] = torch.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]], dtype=torch.long) g.ndata["Z"] = torch.tensor(
self.Z[self.N_cumsum[idx] : self.N_cumsum[idx + 1]],
dtype=torch.long,
)
# add user-defined features # add user-defined features
if self.edge_funcs is not None: if self.edge_funcs is not None:
for func in self.edge_funcs: for func in self.edge_funcs:
...@@ -167,32 +193,34 @@ class QM9(QM9Dataset): ...@@ -167,32 +193,34 @@ class QM9(QM9Dataset):
graphs.append(g) graphs.append(g)
l_g = dgl.line_graph(g, backtracking=False) l_g = dgl.line_graph(g, backtracking=False)
line_graphs.append(l_g) line_graphs.append(l_g)
return graphs, line_graphs return graphs, line_graphs
def save(self): def save(self):
""" step 4 """ """step 4"""
graph_path = f'{self.save_path}/dgl_graph.bin' graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f'{self.save_path}/dgl_line_graph.bin' line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
save_graphs(str(graph_path), self.graphs, self.label_dict) save_graphs(str(graph_path), self.graphs, self.label_dict)
save_graphs(str(line_graph_path), self.line_graphs) save_graphs(str(line_graph_path), self.line_graphs)
def load(self): def load(self):
""" step 5 """ """step 5"""
graph_path = f'{self.save_path}/dgl_graph.bin' graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f'{self.save_path}/dgl_line_graph.bin' line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
self.graphs, label_dict = load_graphs(graph_path) self.graphs, label_dict = load_graphs(graph_path)
self.line_graphs, _ = load_graphs(line_graph_path) self.line_graphs, _ = load_graphs(line_graph_path)
self.label = torch.stack([label_dict[key] for key in self.label_keys], dim=1) self.label = torch.stack(
[label_dict[key] for key in self.label_keys], dim=1
)
def __getitem__(self, idx): def __getitem__(self, idx):
r""" Get graph and label by index r"""Get graph and label by index
Parameters Parameters
---------- ----------
idx : int idx : int
Item index Item index
Returns Returns
------- -------
dgl.DGLGraph dgl.DGLGraph
......
import os
import ssl
import numpy as np
import torch
from six.moves import urllib
from torch.utils.data import DataLoader, Dataset
import dgl


def download_file(dataset):
    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset
    )
    print("Start Downloading File....")
    context = ssl._create_unverified_context()
    data = urllib.request.urlopen(url, context=context)

@@ -20,13 +23,13 @@ def download_file(dataset):

class SnapShotDataset(Dataset):
    def __init__(self, path, npz_file):
        if not os.path.exists(path + "/" + npz_file):
            if not os.path.exists(path):
                os.mkdir(path)
            download_file(npz_file)
        zipfile = np.load(path + "/" + npz_file)
        self.x = zipfile["x"]
        self.y = zipfile["y"]

    def __len__(self):
        return len(self.x)

@@ -39,54 +42,52 @@ class SnapShotDataset(Dataset):

def METR_LAGraphDataset():
    if not os.path.exists("data/graph_la.bin"):
        if not os.path.exists("data"):
            os.mkdir("data")
        download_file("graph_la.bin")
    g, _ = dgl.load_graphs("data/graph_la.bin")
    return g[0]


class METR_LATrainDataset(SnapShotDataset):
    def __init__(self):
        super(METR_LATrainDataset, self).__init__("data", "metr_la_train.npz")
        self.mean = self.x[..., 0].mean()
        self.std = self.x[..., 0].std()


class METR_LATestDataset(SnapShotDataset):
    def __init__(self):
        super(METR_LATestDataset, self).__init__("data", "metr_la_test.npz")


class METR_LAValidDataset(SnapShotDataset):
    def __init__(self):
        super(METR_LAValidDataset, self).__init__("data", "metr_la_valid.npz")


def PEMS_BAYGraphDataset():
    if not os.path.exists("data/graph_bay.bin"):
        if not os.path.exists("data"):
            os.mkdir("data")
        download_file("graph_bay.bin")
    g, _ = dgl.load_graphs("data/graph_bay.bin")
    return g[0]


class PEMS_BAYTrainDataset(SnapShotDataset):
    def __init__(self):
        super(PEMS_BAYTrainDataset, self).__init__("data", "pems_bay_train.npz")
        self.mean = self.x[..., 0].mean()
        self.std = self.x[..., 0].std()


class PEMS_BAYTestDataset(SnapShotDataset):
    def __init__(self):
        super(PEMS_BAYTestDataset, self).__init__("data", "pems_bay_test.npz")


class PEMS_BAYValidDataset(SnapShotDataset):
    def __init__(self):
        super(PEMS_BAYValidDataset, self).__init__("data", "pems_bay_valid.npz")
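A usage sketch for the traffic datasets above, assuming the elided SnapShotDataset.__getitem__ returns an (x, y) snapshot pair:

from torch.utils.data import DataLoader

g = METR_LAGraphDataset()  # road-sensor graph with edge weights
train_data = METR_LATrainDataset()  # snapshots; also exposes .mean / .std
loader = DataLoader(train_data, batch_size=64, shuffle=True)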
@@ -2,13 +2,14 @@ import numpy as np

import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
from dgl.base import DGLError


class DiffConv(nn.Module):
    """DiffConv is an implementation of the diffusion convolution from the
    DCRNN paper. It computes multiple diffusion matrices and performs a
    diffusion convolution with each; the layer can be used for traffic
    prediction and epidemic modeling.

    Parameter

@@ -25,20 +26,23 @@ class DiffConv(nn.Module):

    dir : str [both/in/out]
        direction of diffusion convolution;
        from the paper, the default is both directions
    """

    def __init__(
        self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir="both"
    ):
        super(DiffConv, self).__init__()
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.dir = dir
        self.num_graphs = self.k - 1 if self.dir == "both" else 2 * self.k - 2
        self.project_fcs = nn.ModuleList()
        for i in range(self.num_graphs):
            self.project_fcs.append(
                nn.Linear(self.in_feats, self.out_feats, bias=False)
            )
        self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))
        self.in_graph_list = in_graph_list
        self.out_graph_list = out_graph_list

@@ -48,62 +52,68 @@ class DiffConv(nn.Module):

        out_graph_list = []
        in_graph_list = []
        wadj, ind, outd = DiffConv.get_weight_matrix(g)
        adj = sparse.coo_matrix(wadj / outd.cpu().numpy())
        outg = dgl.from_scipy(adj, eweight_name="weight").to(device)
        outg.edata["weight"] = outg.edata["weight"].float().to(device)
        out_graph_list.append(outg)
        for i in range(k - 1):
            out_graph_list.append(
                DiffConv.diffuse(out_graph_list[-1], wadj, outd)
            )
        adj = sparse.coo_matrix(wadj.T / ind.cpu().numpy())
        ing = dgl.from_scipy(adj, eweight_name="weight").to(device)
        ing.edata["weight"] = ing.edata["weight"].float().to(device)
        in_graph_list.append(ing)
        for i in range(k - 1):
            in_graph_list.append(
                DiffConv.diffuse(in_graph_list[-1], wadj.T, ind)
            )
        return out_graph_list, in_graph_list

    @staticmethod
    def get_weight_matrix(g):
        adj = g.adj(scipy_fmt="coo")
        ind = g.in_degrees()
        outd = g.out_degrees()
        weight = g.edata["weight"]
        adj.data = weight.cpu().numpy()
        return adj, ind, outd

    @staticmethod
    def diffuse(progress_g, weighted_adj, degree):
        device = progress_g.device
        progress_adj = progress_g.adj(scipy_fmt="coo")
        progress_adj.data = progress_g.edata["weight"].cpu().numpy()
        ret_adj = sparse.coo_matrix(
            progress_adj @ (weighted_adj / degree.cpu().numpy())
        )
        ret_graph = dgl.from_scipy(ret_adj, eweight_name="weight").to(device)
        ret_graph.edata["weight"] = ret_graph.edata["weight"].float().to(device)
        return ret_graph

    def forward(self, g, x):
        feat_list = []
        if self.dir == "both":
            graph_list = self.in_graph_list + self.out_graph_list
        elif self.dir == "in":
            graph_list = self.in_graph_list
        elif self.dir == "out":
            graph_list = self.out_graph_list

        for i in range(self.num_graphs):
            g = graph_list[i]
            with g.local_scope():
                g.ndata["n"] = self.project_fcs[i](x)
                g.update_all(
                    fn.u_mul_e("n", "weight", "e"), fn.sum("e", "feat")
                )
                feat_list.append(g.ndata["feat"])
        # Each feat has shape [N,q_feats]
        feat_list.append(self.project_fcs[-1](x))
        feat_list = torch.cat(feat_list).view(
            len(feat_list), -1, self.out_feats
        )
        ret = (
            (self.merger * feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
        )
        return ret
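The static builder that produces the two diffusion-graph lists is elided above, so the sketch below assumes in_graph_list / out_graph_list were already built from a weighted graph g; k and the feature sizes are illustrative:

conv = DiffConv(
    in_feats=2,
    out_feats=64,
    k=3,
    in_graph_list=in_graph_list,
    out_graph_list=out_graph_list,
    dir="both",
)
# h = conv(g, x)  # x: [num_nodes, 2] -> h: [num_nodes, 64]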
import numpy as np
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.base import DGLError
from dgl.nn.functional import edge_softmax


class WeightedGATConv(dglnn.GATConv):
    """
    This model inherits from DGL's GATConv for the traffic prediction task;
    it adds edge weights when aggregating the node features.
    """

    def forward(self, graph, feat, get_attention=False):
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, "fc_src"):
                    feat_src = self.fc(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
                else:
                    feat_src = self.fc_src(h_src).view(
                        -1, self._num_heads, self._out_feats
                    )
                    feat_dst = self.fc_dst(h_dst).view(
                        -1, self._num_heads, self._out_feats
                    )
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(
                    -1, self._num_heads, self._out_feats
                )
                if graph.is_block:
                    feat_dst = feat_src[: graph.number_of_dst_nodes()]
            # NOTE: GAT paper uses "first concatenation then linear projection"
            # to compute attention scores, while ours is "first projection then
            # addition", the two approaches are mathematically equivalent:

@@ -59,37 +67,38 @@ class WeightedGATConv(dglnn.GATConv):

            # which further speeds up computation and saves memory footprint.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({"ft": feat_src, "el": el})
            graph.dstdata.update({"er": er})
            # compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))
            # compute softmax
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            # compute weighted attention
            graph.edata["a"] = (
                graph.edata["a"].permute(1, 2, 0) * graph.edata["weight"]
            ).permute(2, 0, 1)
            # message passing
            graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
            rst = graph.dstdata["ft"]
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(
                    h_dst.shape[0], -1, self._out_feats
                )
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata["a"]
            else:
                return rst


class GatedGAT(nn.Module):
    """Gated graph attention module: a general-purpose
    graph attention module proposed in the GaAN paper, which uses
    it for the traffic prediction task.

    Parameter

@@ -105,7 +114,7 @@ class GatedGAT(nn.Module):

    num_heads : int
        number of heads for multi-head attention
    """

    def __init__(self, in_feats, out_feats, map_feats, num_heads):
        super(GatedGAT, self).__init__()

@@ -113,28 +122,32 @@ class GatedGAT(nn.Module):

        self.out_feats = out_feats
        self.map_feats = map_feats
        self.num_heads = num_heads
        self.gatlayer = WeightedGATConv(
            self.in_feats, self.out_feats, self.num_heads
        )
        self.gate_fn = nn.Linear(
            2 * self.in_feats + self.map_feats, self.num_heads
        )
        self.gate_m = nn.Linear(self.in_feats, self.map_feats)
        self.merger_layer = nn.Linear(
            self.in_feats + self.out_feats, self.out_feats
        )

    def forward(self, g, x):
        with g.local_scope():
            g.ndata["x"] = x
            g.ndata["z"] = self.gate_m(x)
            g.update_all(fn.copy_u("x", "x"), fn.mean("x", "mean_z"))
            g.update_all(fn.copy_u("z", "z"), fn.max("z", "max_z"))
            nft = torch.cat(
                [g.ndata["x"], g.ndata["max_z"], g.ndata["mean_z"]], dim=1
            )
            gate = self.gate_fn(nft).sigmoid()
            attn_out = self.gatlayer(g, x)
            node_num = g.num_nodes()
            gated_out = (
                (gate.view(-1) * attn_out.view(-1, self.out_feats).T).T
            ).view(node_num, self.num_heads, self.out_feats)
            gated_out = gated_out.mean(1)
            merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
            return merge
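An illustrative instantiation of the gated attention module above; the input graph must carry per-edge weights in g.edata["weight"], since the wrapped WeightedGATConv multiplies attention scores by them:

gaan = GatedGAT(in_feats=2, out_feats=64, map_feats=64, num_heads=4)
# h = gaan(g, x)  # x: [num_nodes, 2] -> h: [num_nodes, 64]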
...@@ -2,15 +2,16 @@ import numpy as np ...@@ -2,15 +2,16 @@ import numpy as np
import scipy.sparse as sparse import scipy.sparse as sparse
import torch import torch
import torch.nn as nn import torch.nn as nn
import dgl import dgl
import dgl.function as fn
import dgl.nn as dglnn import dgl.nn as dglnn
from dgl.base import DGLError from dgl.base import DGLError
import dgl.function as fn
from dgl.nn.functional import edge_softmax from dgl.nn.functional import edge_softmax
class GraphGRUCell(nn.Module): class GraphGRUCell(nn.Module):
'''Graph GRU unit which can use any message passing """Graph GRU unit which can use any message passing
net to replace the linear layer in the original GRU net to replace the linear layer in the original GRU
Parameter Parameter
========== ==========
...@@ -22,7 +23,7 @@ class GraphGRUCell(nn.Module): ...@@ -22,7 +23,7 @@ class GraphGRUCell(nn.Module):
net : torch.nn.Module net : torch.nn.Module
message passing network message passing network
''' """
def __init__(self, in_feats, out_feats, net): def __init__(self, in_feats, out_feats, net):
super(GraphGRUCell, self).__init__() super(GraphGRUCell, self).__init__()
...@@ -30,28 +31,27 @@ class GraphGRUCell(nn.Module): ...@@ -30,28 +31,27 @@ class GraphGRUCell(nn.Module):
self.out_feats = out_feats self.out_feats = out_feats
self.dir = dir self.dir = dir
# net can be any GNN model # net can be any GNN model
self.r_net = net(in_feats+out_feats, out_feats) self.r_net = net(in_feats + out_feats, out_feats)
self.u_net = net(in_feats+out_feats, out_feats) self.u_net = net(in_feats + out_feats, out_feats)
self.c_net = net(in_feats+out_feats, out_feats) self.c_net = net(in_feats + out_feats, out_feats)
# Manually add bias Bias # Manually add bias Bias
self.r_bias = nn.Parameter(torch.rand(out_feats)) self.r_bias = nn.Parameter(torch.rand(out_feats))
self.u_bias = nn.Parameter(torch.rand(out_feats)) self.u_bias = nn.Parameter(torch.rand(out_feats))
self.c_bias = nn.Parameter(torch.rand(out_feats)) self.c_bias = nn.Parameter(torch.rand(out_feats))
def forward(self, g, x, h): def forward(self, g, x, h):
r = torch.sigmoid(self.r_net( r = torch.sigmoid(self.r_net(g, torch.cat([x, h], dim=1)) + self.r_bias)
g, torch.cat([x, h], dim=1)) + self.r_bias) u = torch.sigmoid(self.u_net(g, torch.cat([x, h], dim=1)) + self.u_bias)
u = torch.sigmoid(self.u_net( h_ = r * h
g, torch.cat([x, h], dim=1)) + self.u_bias) c = torch.sigmoid(
h_ = r*h self.c_net(g, torch.cat([x, h_], dim=1)) + self.c_bias
c = torch.sigmoid(self.c_net( )
g, torch.cat([x, h_], dim=1)) + self.c_bias) new_h = u * h + (1 - u) * c
new_h = u*h + (1-u)*c
return new_h return new_h
class StackedEncoder(nn.Module): class StackedEncoder(nn.Module):
'''One step encoder unit for hidden representation generation """One step encoder unit for hidden representation generation
it can stack multiple vertical layers to increase the depth. it can stack multiple vertical layers to increase the depth.
Parameter Parameter
...@@ -67,7 +67,7 @@ class StackedEncoder(nn.Module): ...@@ -67,7 +67,7 @@ class StackedEncoder(nn.Module):
net : torch.nn.Module net : torch.nn.Module
message passing network for graph computation message passing network for graph computation
''' """
def __init__(self, in_feats, out_feats, num_layers, net): def __init__(self, in_feats, out_feats, num_layers, net):
super(StackedEncoder, self).__init__() super(StackedEncoder, self).__init__()
...@@ -78,11 +78,13 @@ class StackedEncoder(nn.Module): ...@@ -78,11 +78,13 @@ class StackedEncoder(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
if self.num_layers <= 0: if self.num_layers <= 0:
raise DGLError("Layer Number must be greater than 0! ") raise DGLError("Layer Number must be greater than 0! ")
self.layers.append(GraphGRUCell( self.layers.append(
self.in_feats, self.out_feats, self.net)) GraphGRUCell(self.in_feats, self.out_feats, self.net)
for _ in range(self.num_layers-1): )
self.layers.append(GraphGRUCell( for _ in range(self.num_layers - 1):
self.out_feats, self.out_feats, self.net)) self.layers.append(
GraphGRUCell(self.out_feats, self.out_feats, self.net)
)
# hidden_states should be a list which for different layer # hidden_states should be a list which for different layer
def forward(self, g, x, hidden_states): def forward(self, g, x, hidden_states):
...@@ -94,7 +96,7 @@ class StackedEncoder(nn.Module): ...@@ -94,7 +96,7 @@ class StackedEncoder(nn.Module):
class StackedDecoder(nn.Module): class StackedDecoder(nn.Module):
'''One step decoder unit for hidden representation generation """One step decoder unit for hidden representation generation
it can stack multiple vertical layers to increase the depth. it can stack multiple vertical layers to increase the depth.
Parameter Parameter
...@@ -113,7 +115,7 @@ class StackedDecoder(nn.Module): ...@@ -113,7 +115,7 @@ class StackedDecoder(nn.Module):
net : torch.nn.Module net : torch.nn.Module
message passing network for graph computation message passing network for graph computation
''' """
def __init__(self, in_feats, hid_feats, out_feats, num_layers, net): def __init__(self, in_feats, hid_feats, out_feats, num_layers, net):
super(StackedDecoder, self).__init__() super(StackedDecoder, self).__init__()
...@@ -127,9 +129,10 @@ class StackedDecoder(nn.Module): ...@@ -127,9 +129,10 @@ class StackedDecoder(nn.Module):
if self.num_layers <= 0: if self.num_layers <= 0:
raise DGLError("Layer Number must be greater than 0!") raise DGLError("Layer Number must be greater than 0!")
self.layers.append(GraphGRUCell(self.in_feats, self.hid_feats, net)) self.layers.append(GraphGRUCell(self.in_feats, self.hid_feats, net))
for _ in range(self.num_layers-1): for _ in range(self.num_layers - 1):
self.layers.append(GraphGRUCell( self.layers.append(
self.hid_feats, self.hid_feats, net)) GraphGRUCell(self.hid_feats, self.hid_feats, net)
)
def forward(self, g, x, hidden_states): def forward(self, g, x, hidden_states):
hiddens = [] hiddens = []
...@@ -141,7 +144,7 @@ class StackedDecoder(nn.Module): ...@@ -141,7 +144,7 @@ class StackedDecoder(nn.Module):
class GraphRNN(nn.Module): class GraphRNN(nn.Module):
'''Graph sequence-to-sequence prediction framework """Graph sequence-to-sequence prediction framework
Supports multiple backbone GNNs. Mainly used for traffic prediction. Supports multiple backbone GNNs. Mainly used for traffic prediction.
Parameter Parameter
...@@ -163,15 +166,11 @@ class GraphRNN(nn.Module): ...@@ -163,15 +166,11 @@ class GraphRNN(nn.Module):
decay_steps : int decay_steps : int
number of steps for the teacher forcing probability to decay number of steps for the teacher forcing probability to decay
''' """
def __init__(self, def __init__(
in_feats, self, in_feats, out_feats, seq_len, num_layers, net, decay_steps
out_feats, ):
seq_len,
num_layers,
net,
decay_steps):
super(GraphRNN, self).__init__() super(GraphRNN, self).__init__()
self.in_feats = in_feats self.in_feats = in_feats
self.out_feats = out_feats self.out_feats = out_feats
...@@ -180,24 +179,30 @@ class GraphRNN(nn.Module): ...@@ -180,24 +179,30 @@ class GraphRNN(nn.Module):
self.net = net self.net = net
self.decay_steps = decay_steps self.decay_steps = decay_steps
self.encoder = StackedEncoder(self.in_feats, self.encoder = StackedEncoder(
self.out_feats, self.in_feats, self.out_feats, self.num_layers, self.net
self.num_layers, )
self.net)
self.decoder = StackedDecoder(
self.in_feats,
self.out_feats,
self.in_feats,
self.num_layers,
self.net,
)
self.decoder = StackedDecoder(self.in_feats,
self.out_feats,
self.in_feats,
self.num_layers,
self.net)
# Threshold for teacher forcing # Threshold for teacher forcing
def compute_thresh(self, batch_cnt): def compute_thresh(self, batch_cnt):
return self.decay_steps/(self.decay_steps + np.exp(batch_cnt / self.decay_steps)) return self.decay_steps / (
self.decay_steps + np.exp(batch_cnt / self.decay_steps)
)
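# compute_thresh implements the inverse-sigmoid decay used for scheduled
# sampling in DCRNN: eps_k = tau / (tau + exp(k / tau)) with tau = decay_steps.
# For tau = 2000: eps_0 ~= 2000/2001 ~= 1.0, eps_10000 ~= 2000/(2000 + e^5) ~= 0.93,
# and eps_k -> 0 as k grows, so teacher forcing is frequent early and rare late.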
def encode(self, g, inputs, device): def encode(self, g, inputs, device):
hidden_states = [torch.zeros(g.num_nodes(), self.out_feats).to( hidden_states = [
device) for _ in range(self.num_layers)] torch.zeros(g.num_nodes(), self.out_feats).to(device)
for _ in range(self.num_layers)
]
for i in range(self.seq_len): for i in range(self.seq_len):
_, hidden_states = self.encoder(g, inputs[i], hidden_states) _, hidden_states = self.encoder(g, inputs[i], hidden_states)
...@@ -207,9 +212,13 @@ class GraphRNN(nn.Module): ...@@ -207,9 +212,13 @@ class GraphRNN(nn.Module):
outputs = [] outputs = []
inputs = torch.zeros(g.num_nodes(), self.in_feats).to(device) inputs = torch.zeros(g.num_nodes(), self.in_feats).to(device)
for i in range(self.seq_len): for i in range(self.seq_len):
if np.random.random() < self.compute_thresh(batch_cnt) and self.training: if (
np.random.random() < self.compute_thresh(batch_cnt)
and self.training
):
inputs, hidden_states = self.decoder( inputs, hidden_states = self.decoder(
g, teacher_states[i], hidden_states) g, teacher_states[i], hidden_states
)
else: else:
inputs, hidden_states = self.decoder(g, inputs, hidden_states) inputs, hidden_states = self.decoder(g, inputs, hidden_states)
outputs.append(inputs) outputs.append(inputs)
......
from functools import partial
import argparse import argparse
from functools import partial
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader from dataloading import (METR_LAGraphDataset, METR_LATestDataset,
import dgl METR_LATrainDataset, METR_LAValidDataset,
from model import GraphRNN PEMS_BAYGraphDataset, PEMS_BAYTestDataset,
PEMS_BAYTrainDataset, PEMS_BAYValidDataset)
from dcrnn import DiffConv from dcrnn import DiffConv
from gaan import GatedGAT from gaan import GatedGAT
from dataloading import METR_LAGraphDataset, METR_LATrainDataset,\ from model import GraphRNN
METR_LATestDataset, METR_LAValidDataset,\ from torch.utils.data import DataLoader
PEMS_BAYGraphDataset, PEMS_BAYTrainDataset,\ from utils import NormalizationLayer, get_learning_rate, masked_mae_loss
PEMS_BAYValidDataset, PEMS_BAYTestDataset
from utils import NormalizationLayer, masked_mae_loss, get_learning_rate import dgl
batch_cnt = [0] batch_cnt = [0]
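# batch_cnt is a one-element list rather than a plain int so that train() can
# mutate the global step counter in place (it drives the teacher-forcing
# schedule via model(..., batch_cnt[0], ...) below); presumably it is
# incremented inside the elided training loop.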
def train(model, graph, dataloader, optimizer, scheduler, normalizer, loss_fn, device, args): def train(
model,
graph,
dataloader,
optimizer,
scheduler,
normalizer,
loss_fn,
device,
args,
):
total_loss = [] total_loss = []
graph = graph.to(device) graph = graph.to(device)
model.train() model.train()
...@@ -27,29 +39,37 @@ def train(model, graph, dataloader, optimizer, scheduler, normalizer, loss_fn, d ...@@ -27,29 +39,37 @@ def train(model, graph, dataloader, optimizer, scheduler, normalizer, loss_fn, d
# Padding: since the diffusion graph is precomputed, we need to pad the batch so that # Padding: since the diffusion graph is precomputed, we need to pad the batch so that
# each batch has the same batch size # each batch has the same batch size
if x.shape[0] != batch_size: if x.shape[0] != batch_size:
x_buff = torch.zeros( x_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
batch_size, x.shape[1], x.shape[2], x.shape[3]) y_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros( x_buff[: x.shape[0], :, :, :] = x
batch_size, x.shape[1], x.shape[2], x.shape[3]) x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(
x_buff[:x.shape[0], :, :, :] = x batch_size - x.shape[0], 1, 1, 1
x_buff[x.shape[0]:, :, :, )
:] = x[-1].repeat(batch_size-x.shape[0], 1, 1, 1) y_buff[: x.shape[0], :, :, :] = y
y_buff[:x.shape[0], :, :, :] = y y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(
y_buff[x.shape[0]:, :, :, batch_size - x.shape[0], 1, 1, 1
:] = y[-1].repeat(batch_size-x.shape[0], 1, 1, 1) )
x = x_buff x = x_buff
y = y_buff y = y_buff
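# Example of the padding above: with batch_size=64 and a final partial batch of
# 50 samples, x[-1] and y[-1] are each repeated 64 - 50 = 14 times so that every
# batch matches the fixed size the precomputed diffusion graph expects.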
# Permute the dimensions for reshaping # Permute the dimensions for reshaping
x = x.permute(1, 0, 2, 3) x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3) y = y.permute(1, 0, 2, 3)
x_norm = normalizer.normalize(x).reshape( x_norm = (
x.shape[0], -1, x.shape[3]).float().to(device) normalizer.normalize(x)
y_norm = normalizer.normalize(y).reshape( .reshape(x.shape[0], -1, x.shape[3])
x.shape[0], -1, x.shape[3]).float().to(device) .float()
.to(device)
)
y_norm = (
normalizer.normalize(y)
.reshape(x.shape[0], -1, x.shape[3])
.float()
.to(device)
)
y = y.reshape(y.shape[0], -1, y.shape[3]).float().to(device) y = y.reshape(y.shape[0], -1, y.shape[3]).float().to(device)
batch_graph = dgl.batch([graph]*batch_size) batch_graph = dgl.batch([graph] * batch_size)
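# dgl.batch replicates the shared traffic graph once per sample, yielding one
# block-diagonal batched graph so the whole minibatch runs in a single pass.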
output = model(batch_graph, x_norm, y_norm, batch_cnt[0], device) output = model(batch_graph, x_norm, y_norm, batch_cnt[0], device)
# Denormalization for loss compute # Denormalization for loss compute
y_pred = normalizer.denormalize(output) y_pred = normalizer.denormalize(output)
...@@ -74,29 +94,37 @@ def eval(model, graph, dataloader, normalizer, loss_fn, device, args): ...@@ -74,29 +94,37 @@ def eval(model, graph, dataloader, normalizer, loss_fn, device, args):
# Padding: since the diffusion graph is precomputed, we need to pad the batch so that # Padding: since the diffusion graph is precomputed, we need to pad the batch so that
# each batch has the same batch size # each batch has the same batch size
if x.shape[0] != batch_size: if x.shape[0] != batch_size:
x_buff = torch.zeros( x_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
batch_size, x.shape[1], x.shape[2], x.shape[3]) y_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros( x_buff[: x.shape[0], :, :, :] = x
batch_size, x.shape[1], x.shape[2], x.shape[3]) x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(
x_buff[:x.shape[0], :, :, :] = x batch_size - x.shape[0], 1, 1, 1
x_buff[x.shape[0]:, :, :, )
:] = x[-1].repeat(batch_size-x.shape[0], 1, 1, 1) y_buff[: x.shape[0], :, :, :] = y
y_buff[:x.shape[0], :, :, :] = y y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(
y_buff[x.shape[0]:, :, :, batch_size - x.shape[0], 1, 1, 1
:] = y[-1].repeat(batch_size-x.shape[0], 1, 1, 1) )
x = x_buff x = x_buff
y = y_buff y = y_buff
# Permute the order of dimensions # Permute the order of dimensions
x = x.permute(1, 0, 2, 3) x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3) y = y.permute(1, 0, 2, 3)
x_norm = normalizer.normalize(x).reshape( x_norm = (
x.shape[0], -1, x.shape[3]).float().to(device) normalizer.normalize(x)
y_norm = normalizer.normalize(y).reshape( .reshape(x.shape[0], -1, x.shape[3])
x.shape[0], -1, x.shape[3]).float().to(device) .float()
.to(device)
)
y_norm = (
normalizer.normalize(y)
.reshape(x.shape[0], -1, x.shape[3])
.float()
.to(device)
)
y = y.reshape(x.shape[0], -1, x.shape[3]).to(device) y = y.reshape(x.shape[0], -1, x.shape[3]).to(device)
batch_graph = dgl.batch([graph]*batch_size) batch_graph = dgl.batch([graph] * batch_size)
output = model(batch_graph, x_norm, y_norm, i, device) output = model(batch_graph, x_norm, y_norm, i, device)
y_pred = normalizer.denormalize(output) y_pred = normalizer.denormalize(output)
loss = loss_fn(y_pred, y) loss = loss_fn(y_pred, y)
...@@ -107,70 +135,124 @@ def eval(model, graph, dataloader, normalizer, loss_fn, device, args): ...@@ -107,70 +135,124 @@ def eval(model, graph, dataloader, normalizer, loss_fn, device, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# Define the arguments # Define the arguments
parser.add_argument('--batch_size', type=int, default=64, parser.add_argument(
help="Size of batch for minibatch Training") "--batch_size",
parser.add_argument('--num_workers', type=int, default=0, type=int,
help="Number of workers for parallel dataloading") default=64,
parser.add_argument('--model', type=str, default='dcrnn', help="Size of batch for minibatch training",
help="Which model to use: DCRNN or GaAN") )
parser.add_argument('--gpu', type=int, default=-1, parser.add_argument(
help="GPU indexm -1 for CPU training") "--num_workers",
parser.add_argument('--diffsteps', type=int, default=2, type=int,
help="Step of constructing the diffusiob matrix") default=0,
parser.add_argument('--num_heads', type=int, default=2, help="Number of workers for parallel dataloading",
help="Number of multiattention head") )
parser.add_argument('--decay_steps', type=int, default=2000, parser.add_argument(
help="Teacher forcing probability decay ratio") "--model",
parser.add_argument('--lr', type=float, default=0.01, type=str,
help="Initial learning rate") default="dcrnn",
parser.add_argument('--minimum_lr', type=float, default=2e-6, help="Which model to use: DCRNN or GaAN",
help="Lower bound of learning rate") )
parser.add_argument('--dataset', type=str, default='LA', parser.add_argument(
help="dataset LA for METR_LA; BAY for PEMS_BAY") "--gpu", type=int, default=-1, help="GPU indexm -1 for CPU training"
parser.add_argument('--epochs', type=int, default=100, )
help="Number of epoches for training") parser.add_argument(
parser.add_argument('--max_grad_norm', type=float, default=5.0, "--diffsteps",
help="Maximum gradient norm for update parameters") type=int,
default=2,
help="Step of constructing the diffusiob matrix",
)
parser.add_argument(
"--num_heads", type=int, default=2, help="Number of multiattention head"
)
parser.add_argument(
"--decay_steps",
type=int,
default=2000,
help="Teacher forcing probability decay ratio",
)
parser.add_argument(
"--lr", type=float, default=0.01, help="Initial learning rate"
)
parser.add_argument(
"--minimum_lr",
type=float,
default=2e-6,
help="Lower bound of learning rate",
)
parser.add_argument(
"--dataset",
type=str,
default="LA",
help="dataset LA for METR_LA; BAY for PEMS_BAY",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epoches for training"
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=5.0,
help="Maximum gradient norm for update parameters",
)
args = parser.parse_args() args = parser.parse_args()
# Load the datasets # Load the datasets
if args.dataset == 'LA': if args.dataset == "LA":
g = METR_LAGraphDataset() g = METR_LAGraphDataset()
train_data = METR_LATrainDataset() train_data = METR_LATrainDataset()
test_data = METR_LATestDataset() test_data = METR_LATestDataset()
valid_data = METR_LAValidDataset() valid_data = METR_LAValidDataset()
elif args.dataset == 'BAY': elif args.dataset == "BAY":
g = PEMS_BAYGraphDataset() g = PEMS_BAYGraphDataset()
train_data = PEMS_BAYTrainDataset() train_data = PEMS_BAYTrainDataset()
test_data = PEMS_BAYTestDataset() test_data = PEMS_BAYTestDataset()
valid_data = PEMS_BAYValidDataset() valid_data = PEMS_BAYValidDataset()
if args.gpu == -1: if args.gpu == -1:
device = torch.device('cpu') device = torch.device("cpu")
else: else:
device = torch.device('cuda:{}'.format(args.gpu)) device = torch.device("cuda:{}".format(args.gpu))
train_loader = DataLoader( train_loader = DataLoader(
train_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) train_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
)
valid_loader = DataLoader( valid_loader = DataLoader(
valid_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) valid_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
)
test_loader = DataLoader( test_loader = DataLoader(
test_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True) test_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
)
normalizer = NormalizationLayer(train_data.mean, train_data.std) normalizer = NormalizationLayer(train_data.mean, train_data.std)
if args.model == 'dcrnn': if args.model == "dcrnn":
batch_g = dgl.batch([g]*args.batch_size).to(device) batch_g = dgl.batch([g] * args.batch_size).to(device)
out_gs, in_gs = DiffConv.attach_graph(batch_g, args.diffsteps) out_gs, in_gs = DiffConv.attach_graph(batch_g, args.diffsteps)
net = partial(DiffConv, k=args.diffsteps, net = partial(
in_graph_list=in_gs, out_graph_list=out_gs) DiffConv,
elif args.model == 'gaan': k=args.diffsteps,
in_graph_list=in_gs,
out_graph_list=out_gs,
)
elif args.model == "gaan":
net = partial(GatedGAT, map_feats=64, num_heads=args.num_heads) net = partial(GatedGAT, map_feats=64, num_heads=args.num_heads)
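# functools.partial freezes the model-specific hyperparameters here, so the
# downstream GraphGRUCell can instantiate the backbone uniformly; presumably it
# only needs to supply the input/output feature sizes when it builds each cell.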
dcrnn = GraphRNN(in_feats=2, dcrnn = GraphRNN(
out_feats=64, in_feats=2,
seq_len=12, out_feats=64,
num_layers=2, seq_len=12,
net=net, num_layers=2,
decay_steps=args.decay_steps).to(device) net=net,
decay_steps=args.decay_steps,
).to(device)
optimizer = torch.optim.Adam(dcrnn.parameters(), lr=args.lr) optimizer = torch.optim.Adam(dcrnn.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99) scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
...@@ -178,13 +260,25 @@ if __name__ == "__main__": ...@@ -178,13 +260,25 @@ if __name__ == "__main__":
loss_fn = masked_mae_loss loss_fn = masked_mae_loss
for e in range(args.epochs): for e in range(args.epochs):
train_loss = train(dcrnn, g, train_loader, optimizer, scheduler, train_loss = train(
normalizer, loss_fn, device, args) dcrnn,
valid_loss = eval(dcrnn, g, valid_loader, g,
normalizer, loss_fn, device, args) train_loader,
test_loss = eval(dcrnn, g, test_loader, optimizer,
normalizer, loss_fn, device, args) scheduler,
print("Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(e, normalizer,
train_loss, loss_fn,
valid_loss, device,
test_loss)) args,
)
valid_loss = eval(
dcrnn, g, valid_loader, normalizer, loss_fn, device, args
)
test_loss = eval(
dcrnn, g, test_loader, normalizer, loss_fn, device, args
)
print(
"Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(
e, train_loss, valid_loss, test_loss
)
)
import dgl
import scipy.sparse as sparse
import numpy as np import numpy as np
import torch.nn as nn import scipy.sparse as sparse
import torch import torch
import torch.nn as nn
import dgl
class NormalizationLayer(nn.Module): class NormalizationLayer(nn.Module):
...@@ -12,10 +13,10 @@ class NormalizationLayer(nn.Module): ...@@ -12,10 +13,10 @@ class NormalizationLayer(nn.Module):
# Here we expect mean and std to be scalars # Here we expect mean and std to be scalars
def normalize(self, x): def normalize(self, x):
return (x-self.mean)/self.std return (x - self.mean) / self.std
def denormalize(self, x): def denormalize(self, x):
return x*self.std + self.mean return x * self.std + self.mean
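# A minimal round-trip sketch (hypothetical mean/std values; the main script
# takes them from the training split):
#   norm = NormalizationLayer(54.3, 19.5)
#   z = norm.normalize(x)              # (x - mean) / std
#   assert torch.allclose(norm.denormalize(z), x)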
def masked_mae_loss(y_pred, y_true): def masked_mae_loss(y_pred, y_true):
...@@ -30,4 +31,4 @@ def masked_mae_loss(y_pred, y_true): ...@@ -30,4 +31,4 @@ def masked_mae_loss(y_pred, y_true):
def get_learning_rate(optimizer): def get_learning_rate(optimizer):
for param in optimizer.param_groups: for param in optimizer.param_groups:
return param['lr'] return param["lr"]
import torch
import numpy as np
import math import math
from itertools import product
import numpy as np
import pandas as pd import pandas as pd
import torch
import dgl import dgl
from dgl.data import DGLDataset from dgl.data import DGLDataset
from itertools import product
class EEGGraphDataset(DGLDataset): class EEGGraphDataset(DGLDataset):
""" Build graph, treat all nodes as the same type """Build graph, treat all nodes as the same type
Parameters Parameters
---------- ----------
x: edge weights of 8-node complete graph x: edge weights of 8-node complete graph
There are 1 x 64 edges There are 1 x 64 edges
y: labels (diseased/healthy) y: labels (diseased/healthy)
num_nodes: the number of nodes of the graph. In our case, it is 8. num_nodes: the number of nodes of the graph. In our case, it is 8.
indices: Patient level indices. They are used to generate edge weights. indices: Patient level indices. They are used to generate edge weights.
Output Output
------ ------
a complete 8-node DGLGraph with node features and edge weights a complete 8-node DGLGraph with node features and edge weights
""" """
def __init__(self, x, y, num_nodes, indices): def __init__(self, x, y, num_nodes, indices):
# CAUTION - x and labels are memory-mapped, used as if they are in RAM. # CAUTION - x and labels are memory-mapped, used as if they are in RAM.
self.x = x self.x = x
...@@ -37,27 +40,24 @@ class EEGGraphDataset(DGLDataset): ...@@ -37,27 +40,24 @@ class EEGGraphDataset(DGLDataset):
"P7-P3", "P7-P3",
"P8-P4", "P8-P4",
"O1-P3", "O1-P3",
"O2-P4" "O2-P4",
] ]
# 10-10-system electrodes lying between the two 10-20 electrodes of each ch_names pair, used for calculating edge weights # 10-10-system electrodes lying between the two 10-20 electrodes of each ch_names pair, used for calculating edge weights
# Note: "O1" is used in place of "PO3", and "O2" in place of "PO4". # Note: "O1" is used in place of "PO3", and "O2" in place of "PO4".
self.ref_names = [ self.ref_names = ["F5", "F6", "C5", "C6", "P5", "P6", "O1", "O2"]
"F5",
"F6",
"C5",
"C6",
"P5",
"P6",
"O1",
"O2"
]
# edge indices source to target - 2 x E = 2 x 64 # edge indices source to target - 2 x E = 2 x 64
# fully connected undirected graph so 8*8=64 edges # fully connected undirected graph so 8*8=64 edges
self.node_ids = range(len(self.ch_names)) self.node_ids = range(len(self.ch_names))
self.edge_index = torch.tensor([[a, b] for a, b in product(self.node_ids, self.node_ids)], self.edge_index = (
dtype=torch.long).t().contiguous() torch.tensor(
[[a, b] for a, b in product(self.node_ids, self.node_ids)],
dtype=torch.long,
)
.t()
.contiguous()
)
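# itertools.product over the 8 node ids yields all 8 * 8 = 64 ordered pairs
# (self-loops included); .t() then gives the conventional 2 x E source/target
# layout. Both directions of every edge are present, which is how the
# undirected complete graph is represented here.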
# edge attributes - E x 1 # edge attributes - E x 1
# only the spatial distance between electrodes for now - standardize between 0 and 1 # only the spatial distance between electrodes for now - standardize between 0 and 1
...@@ -68,18 +68,22 @@ class EEGGraphDataset(DGLDataset): ...@@ -68,18 +68,22 @@ class EEGGraphDataset(DGLDataset):
# sensor distances don't depend on window ID # sensor distances don't depend on window ID
def get_sensor_distances(self): def get_sensor_distances(self):
coords_1010 = pd.read_csv("standard_1010.tsv.txt", sep='\t') coords_1010 = pd.read_csv("standard_1010.tsv.txt", sep="\t")
num_edges = self.edge_index.shape[1] num_edges = self.edge_index.shape[1]
distances = [] distances = []
for edge_idx in range(num_edges): for edge_idx in range(num_edges):
sensor1_idx = self.edge_index[0, edge_idx] sensor1_idx = self.edge_index[0, edge_idx]
sensor2_idx = self.edge_index[1, edge_idx] sensor2_idx = self.edge_index[1, edge_idx]
dist = self.get_geodesic_distance(sensor1_idx, sensor2_idx, coords_1010) dist = self.get_geodesic_distance(
sensor1_idx, sensor2_idx, coords_1010
)
distances.append(dist) distances.append(dist)
assert len(distances) == num_edges assert len(distances) == num_edges
return distances return distances
def get_geodesic_distance(self, montage_sensor1_idx, montage_sensor2_idx, coords_1010): def get_geodesic_distance(
self, montage_sensor1_idx, montage_sensor2_idx, coords_1010
):
# get the reference sensor in the 10-10 system for the current montage pair in 10-20 system # get the reference sensor in the 10-10 system for the current montage pair in 10-20 system
ref_sensor1 = self.ref_names[montage_sensor1_idx] ref_sensor1 = self.ref_names[montage_sensor1_idx]
...@@ -96,7 +100,9 @@ class EEGGraphDataset(DGLDataset): ...@@ -96,7 +100,9 @@ class EEGGraphDataset(DGLDataset):
# https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere # https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere
r = 1 # since coords are on unit sphere r = 1 # since coords are on unit sphere
# rounding is for numerical stability, domain is [-1, 1] # rounding is for numerical stability, domain is [-1, 1]
dist = r * math.acos(round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r ** 2), 2)) dist = r * math.acos(
round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r**2), 2)
)
return dist return dist
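# The arc length above follows dist = r * arccos(dot / r^2) on the unit sphere.
# For example, two sensors with dot product 0 (orthogonal positions) are
# arccos(0) = pi/2 ~= 1.571 apart, while identical sensors are arccos(1) = 0 apart.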
# returns size of dataset = number of indices # returns size of dataset = number of indices
...@@ -123,19 +129,23 @@ class EEGGraphDataset(DGLDataset): ...@@ -123,19 +129,23 @@ class EEGGraphDataset(DGLDataset):
edge_weights = torch.tensor(edge_weights) # truncated to integer edge_weights = torch.tensor(edge_weights) # truncated to integer
# create 8-node complete graph # create 8-node complete graph
src = [[0 for i in range(self.num_nodes)] for j in range(self.num_nodes)] src = [
[0 for i in range(self.num_nodes)] for j in range(self.num_nodes)
]
for i in range(len(src)): for i in range(len(src)):
for j in range(len(src[i])): for j in range(len(src[i])):
src[i][j] = i src[i][j] = i
src = np.array(src).flatten() src = np.array(src).flatten()
det = [[i for i in range(self.num_nodes)] for j in range(self.num_nodes)] det = [
[i for i in range(self.num_nodes)] for j in range(self.num_nodes)
]
det = np.array(det).flatten() det = np.array(det).flatten()
u, v = (torch.tensor(src), torch.tensor(det)) u, v = (torch.tensor(src), torch.tensor(det))
g = dgl.graph((u, v)) g = dgl.graph((u, v))
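# src ("source") and det ("destination") enumerate every ordered pair of the
# 8 nodes, so this builds the same complete graph as edge_index in __init__.
# A terser equivalent sketch (assuming num_nodes = 8):
#   u, v = torch.meshgrid(torch.arange(8), torch.arange(8), indexing="ij")
#   g = dgl.graph((u.flatten(), v.flatten()))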
# add node features and edge features # add node features and edge features
g.ndata['x'] = node_features g.ndata["x"] = node_features
g.edata['edge_weights'] = edge_weights g.edata["edge_weights"] = edge_weights
return g, torch.tensor(idx), torch.tensor(self.labels[idx]) return g, torch.tensor(idx), torch.tensor(self.labels[idx])
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as function import torch.nn.functional as function
from dgl.nn import GraphConv, SumPooling
from torch.nn import BatchNorm1d from torch.nn import BatchNorm1d
from dgl.nn import GraphConv, SumPooling
class EEGGraphConvNet(nn.Module): class EEGGraphConvNet(nn.Module):
""" EEGGraph Convolution Net """EEGGraph Convolution Net
Parameters Parameters
---------- ----------
num_feats: the number of features per node. In our case, it is 6. num_feats: the number of features per node. In our case, it is 6.
""" """
def __init__(self, num_feats): def __init__(self, num_feats):
super(EEGGraphConvNet, self).__init__() super(EEGGraphConvNet, self).__init__()
...@@ -17,7 +19,9 @@ class EEGGraphConvNet(nn.Module): ...@@ -17,7 +19,9 @@ class EEGGraphConvNet(nn.Module):
self.conv2 = GraphConv(16, 32) self.conv2 = GraphConv(16, 32)
self.conv3 = GraphConv(32, 64) self.conv3 = GraphConv(32, 64)
self.conv4 = GraphConv(64, 50) self.conv4 = GraphConv(64, 50)
self.conv4_bn = BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) self.conv4_bn = BatchNorm1d(
50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
)
self.fc_block1 = nn.Linear(50, 30) self.fc_block1 = nn.Linear(50, 30)
self.fc_block2 = nn.Linear(30, 10) self.fc_block2 = nn.Linear(30, 10)
...@@ -31,8 +35,8 @@ class EEGGraphConvNet(nn.Module): ...@@ -31,8 +35,8 @@ class EEGGraphConvNet(nn.Module):
self.sumpool = SumPooling() self.sumpool = SumPooling()
def forward(self, g, return_graph_embedding=False): def forward(self, g, return_graph_embedding=False):
x = g.ndata['x'] x = g.ndata["x"]
edge_weight = g.edata['edge_weights'] edge_weight = g.edata["edge_weights"]
x = self.conv1(g, x, edge_weight=edge_weight) x = self.conv1(g, x, edge_weight=edge_weight)
x = function.leaky_relu(x, negative_slope=0.01) x = function.leaky_relu(x, negative_slope=0.01)
......
import argparse import argparse
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import torch import torch
import torch.nn as nn import torch.nn as nn
from sklearn.model_selection import train_test_split
from joblib import load
from EEGGraphDataset import EEGGraphDataset from EEGGraphDataset import EEGGraphDataset
from dgl.dataloading import GraphDataLoader from joblib import load
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import balanced_accuracy_score
from sklearn import preprocessing from sklearn import preprocessing
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import WeightedRandomSampler
from dgl.dataloading import GraphDataLoader
if __name__ == "__main__": if __name__ == "__main__":
# argparse commandline args # argparse commandline args
parser = argparse.ArgumentParser(description='Execute training pipeline on a given train/val subject split') parser = argparse.ArgumentParser(
parser.add_argument('--num_feats', type=int, default=6, help='Number of features per node for the graph') description="Execute training pipeline on a given train/val subject split"
parser.add_argument('--num_nodes', type=int, default=8, help='Number of nodes in the graph') )
parser.add_argument('--gpu_idx', type=int, default=0, parser.add_argument(
help='Index of the GPU device to use for this run; defaults to 0.') "--num_feats",
parser.add_argument('--num_epochs', type=int, default=40, help='Number of epochs used to train') type=int,
parser.add_argument('--exp_name', type=str, default='default', help='Name for the test.') default=6,
parser.add_argument('--batch_size', type=int, default=512, help='Batch Size. Default is 512.') help="Number of features per node for the graph",
parser.add_argument('--model', type=str, default='shallow', )
help='type shallow to use shallow_EEGGraphDataset; ' parser.add_argument(
'type deep to use deep_EEGGraphDataset. Default is shallow') "--num_nodes", type=int, default=8, help="Number of nodes in the graph"
)
parser.add_argument(
"--gpu_idx",
type=int,
default=0,
help="index of GPU device that should be used for this run, defaults to 0.",
)
parser.add_argument(
"--num_epochs",
type=int,
default=40,
help="Number of epochs used to train",
)
parser.add_argument(
"--exp_name", type=str, default="default", help="Name for the test."
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="Batch Size. Default is 512.",
)
parser.add_argument(
"--model",
type=str,
default="shallow",
help="type shallow to use shallow_EEGGraphDataset; "
"type deep to use deep_EEGGraphDataset. Default is shallow",
)
args = parser.parse_args() args = parser.parse_args()
# choose model # choose model
if args.model == 'shallow': if args.model == "shallow":
from shallow_EEGGraphConvNet import EEGGraphConvNet from shallow_EEGGraphConvNet import EEGGraphConvNet
if args.model == 'deep': if args.model == "deep":
from deep_EEGGraphConvNet import EEGGraphConvNet from deep_EEGGraphConvNet import EEGGraphConvNet
# set the random seed so that we can reproduce the results # set the random seed so that we can reproduce the results
...@@ -41,9 +70,11 @@ if __name__ == "__main__": ...@@ -41,9 +70,11 @@ if __name__ == "__main__":
# use GPU when available # use GPU when available
_GPU_IDX = args.gpu_idx _GPU_IDX = args.gpu_idx
_DEVICE = torch.device(f'cuda:{_GPU_IDX}' if torch.cuda.is_available() else 'cpu') _DEVICE = torch.device(
f"cuda:{_GPU_IDX}" if torch.cuda.is_available() else "cpu"
)
torch.cuda.set_device(_DEVICE) torch.cuda.set_device(_DEVICE)
print(f' Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}') print(f" Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}")
# load patient level indices # load patient level indices
_DATASET_INDEX = pd.read_csv("master_metadata_index.csv") _DATASET_INDEX = pd.read_csv("master_metadata_index.csv")
...@@ -58,10 +89,10 @@ if __name__ == "__main__": ...@@ -58,10 +89,10 @@ if __name__ == "__main__":
num_feats = args.num_feats num_feats = args.num_feats
# set up input and targets from files # set up input and targets from files
memmap_x = 'psd_features_data_X' memmap_x = "psd_features_data_X"
memmap_y = 'labels_y' memmap_y = "labels_y"
x = load(memmap_x, mmap_mode='r') x = load(memmap_x, mmap_mode="r")
y = load(memmap_y, mmap_mode='r') y = load(memmap_y, mmap_mode="r")
# normalize psd features data # normalize psd features data
normd_x = [] normd_x = []
...@@ -79,22 +110,33 @@ if __name__ == "__main__": ...@@ -79,22 +110,33 @@ if __name__ == "__main__":
print(f"Unique labels 0/1 mapping: {label_mapping}") print(f"Unique labels 0/1 mapping: {label_mapping}")
# split the dataset into train and test; the test ratio is 0.3. # split the dataset into train and test; the test ratio is 0.3.
train_and_val_subjects, heldout_subjects = train_test_split(all_subjects, test_size=0.3, random_state=42) train_and_val_subjects, heldout_subjects = train_test_split(
all_subjects, test_size=0.3, random_state=42
)
# split the dataset using patient indices # split the dataset using patient indices
train_window_indices = _DATASET_INDEX.index[ train_window_indices = _DATASET_INDEX.index[
_DATASET_INDEX["patient_ID"].astype("str").isin(train_and_val_subjects)].tolist() _DATASET_INDEX["patient_ID"].astype("str").isin(train_and_val_subjects)
].tolist()
heldout_test_window_indices = _DATASET_INDEX.index[ heldout_test_window_indices = _DATASET_INDEX.index[
_DATASET_INDEX["patient_ID"].astype("str").isin(heldout_subjects)].tolist() _DATASET_INDEX["patient_ID"].astype("str").isin(heldout_subjects)
].tolist()
# define model, optimizer, scheduler # define model, optimizer, scheduler
model = EEGGraphConvNet(num_feats) model = EEGGraphConvNet(num_feats)
loss_function = nn.CrossEntropyLoss() loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[i * 10 for i in range(1, 26)], gamma=0.1) scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[i * 10 for i in range(1, 26)], gamma=0.1
)
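# With milestones [10, 20, ..., 250] and gamma=0.1, the learning rate is cut by
# a factor of 10 every 10 epochs: 0.01 -> 1e-3 -> 1e-4 -> ...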
model = model.to(_DEVICE).double() model = model.to(_DEVICE).double()
num_trainable_params = np.sum([np.prod(p.size()) if p.requires_grad else 0 for p in model.parameters()]) num_trainable_params = np.sum(
[
np.prod(p.size()) if p.requires_grad else 0
for p in model.parameters()
]
)
# Dataloader======================================================================================================== # Dataloader========================================================================================================
...@@ -109,7 +151,8 @@ if __name__ == "__main__": ...@@ -109,7 +151,8 @@ if __name__ == "__main__":
# the sampler must draw as many samples as there are in the training set # the sampler must draw as many samples as there are in the training set
weighted_sampler = WeightedRandomSampler( weighted_sampler = WeightedRandomSampler(
weights=sample_weights, weights=sample_weights,
num_samples=len(train_window_indices), replacement=True num_samples=len(train_window_indices),
replacement=True,
) )
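# WeightedRandomSampler draws len(train_window_indices) samples per epoch with
# replacement, each with probability proportional to its weight; assuming
# sample_weights (computed in the elided code above) are inverse class
# frequencies, this rebalances the diseased/healthy classes within each epoch.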
# train data loader # train data loader
...@@ -118,17 +161,20 @@ if __name__ == "__main__": ...@@ -118,17 +161,20 @@ if __name__ == "__main__":
) )
train_loader = GraphDataLoader( train_loader = GraphDataLoader(
dataset=train_dataset, batch_size=_BATCH_SIZE, dataset=train_dataset,
batch_size=_BATCH_SIZE,
sampler=weighted_sampler, sampler=weighted_sampler,
num_workers=NUM_WORKERS, num_workers=NUM_WORKERS,
pin_memory=True pin_memory=True,
) )
# this loader is used without weighted sampling, to evaluate metrics on the full training set after each epoch # this loader is used without weighted sampling, to evaluate metrics on the full training set after each epoch
train_metrics_loader = GraphDataLoader( train_metrics_loader = GraphDataLoader(
dataset=train_dataset, batch_size=_BATCH_SIZE, dataset=train_dataset,
shuffle=False, num_workers=NUM_WORKERS, batch_size=_BATCH_SIZE,
pin_memory=True shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=True,
) )
# test data loader # test data loader
...@@ -137,9 +183,11 @@ if __name__ == "__main__": ...@@ -137,9 +183,11 @@ if __name__ == "__main__":
) )
test_loader = GraphDataLoader( test_loader = GraphDataLoader(
dataset=test_dataset, batch_size=_BATCH_SIZE, dataset=test_dataset,
shuffle=False, num_workers=NUM_WORKERS, batch_size=_BATCH_SIZE,
pin_memory=True shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=True,
) )
auroc_train_history = [] auroc_train_history = []
...@@ -173,7 +221,7 @@ if __name__ == "__main__": ...@@ -173,7 +221,7 @@ if __name__ == "__main__":
# update learning rate # update learning rate
scheduler.step() scheduler.step()
# evaluate model after each epoch for train-metric data============================================================ # evaluate model after each epoch for train-metric data============================================================
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
y_probs_train = torch.empty(0, 2).to(_DEVICE) y_probs_train = torch.empty(0, 2).to(_DEVICE)
...@@ -194,10 +242,12 @@ if __name__ == "__main__": ...@@ -194,10 +242,12 @@ if __name__ == "__main__":
y_true_train += y_batch.cpu().numpy().tolist() y_true_train += y_batch.cpu().numpy().tolist()
# return the prob distribution over target classes by taking softmax over dim=1 # return the prob distribution over target classes by taking softmax over dim=1
y_probs_train = nn.functional.softmax(y_probs_train, dim=1).cpu().numpy() y_probs_train = (
nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()
)
y_true_train = np.array(y_true_train) y_true_train = np.array(y_true_train)
# evaluate model after each epoch for validation data ============================================================== # evaluate model after each epoch for validation data ==============================================================
y_probs_test = torch.empty(0, 2).to(_DEVICE) y_probs_test = torch.empty(0, 2).to(_DEVICE)
y_true_test, minibatch_loss, y_pred_test = [], [], [] y_true_test, minibatch_loss, y_pred_test = [], [], []
...@@ -217,32 +267,54 @@ if __name__ == "__main__": ...@@ -217,32 +267,54 @@ if __name__ == "__main__":
y_true_test += y_batch.cpu().numpy().tolist() y_true_test += y_batch.cpu().numpy().tolist()
# return the prob distribution over target classes by taking softmax over dim=1 # return the prob distribution over target classes by taking softmax over dim=1
y_probs_test = torch.nn.functional.softmax(y_probs_test, dim=1).cpu().numpy() y_probs_test = (
torch.nn.functional.softmax(y_probs_test, dim=1).cpu().numpy()
)
y_true_test = np.array(y_true_test) y_true_test = np.array(y_true_test)
# record training auroc and testing auroc # record training auroc and testing auroc
auroc_train_history.append(roc_auc_score(y_true_train, y_probs_train[:, 1])) auroc_train_history.append(
auroc_test_history.append(roc_auc_score(y_true_test, y_probs_test[:, 1])) roc_auc_score(y_true_train, y_probs_train[:, 1])
)
auroc_test_history.append(
roc_auc_score(y_true_test, y_probs_test[:, 1])
)
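# roc_auc_score consumes the positive-class probability column [:, 1]; the
# balanced-accuracy metrics below instead use the hard predictions
# (y_pred_train / y_pred_test, presumably collected as argmaxes in the loops above).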
# record training balanced accuracy and testing balanced accuracy # record training balanced accuracy and testing balanced accuracy
balACC_train_history.append(balanced_accuracy_score(y_true_train, y_pred_train)) balACC_train_history.append(
balACC_test_history.append(balanced_accuracy_score(y_true_test, y_pred_test)) balanced_accuracy_score(y_true_train, y_pred_train)
)
balACC_test_history.append(
balanced_accuracy_score(y_true_test, y_pred_test)
)
# LOSS - epoch loss is defined as mean of minibatch losses within epoch # LOSS - epoch loss is defined as mean of minibatch losses within epoch
loss_train_history.append(np.mean(train_loss)) loss_train_history.append(np.mean(train_loss))
loss_test_history.append(np.mean(minibatch_loss)) loss_test_history.append(np.mean(minibatch_loss))
# print the metrics # print the metrics
print("Train loss: {}, test loss: {}".format(loss_train_history[-1], loss_test_history[-1])) print(
print("Train AUC: {}, test AUC: {}".format(auroc_train_history[-1], auroc_test_history[-1])) "Train loss: {}, test loss: {}".format(
print("Train Bal.ACC: {}, test Bal.ACC: {}".format(balACC_train_history[-1], balACC_test_history[-1])) loss_train_history[-1], loss_test_history[-1]
)
)
print(
"Train AUC: {}, test AUC: {}".format(
auroc_train_history[-1], auroc_test_history[-1]
)
)
print(
"Train Bal.ACC: {}, test Bal.ACC: {}".format(
balACC_train_history[-1], balACC_test_history[-1]
)
)
# save model from each epoch==================================================================================== # save model from each epoch====================================================================================
state = { state = {
'epochs': _NUM_EPOCHS, "epochs": _NUM_EPOCHS,
'experiment_name': _EXPERIMENT_NAME, "experiment_name": _EXPERIMENT_NAME,
'model_description': str(model), "model_description": str(model),
'state_dict': model.state_dict(), "state_dict": model.state_dict(),
'optimizer': optimizer.state_dict() "optimizer": optimizer.state_dict(),
} }
torch.save(state, f"{_EXPERIMENT_NAME}_Epoch_{epoch}.ckpt") torch.save(state, f"{_EXPERIMENT_NAME}_Epoch_{epoch}.ckpt")
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as function import torch.nn.functional as function
from dgl.nn import GraphConv, SumPooling from dgl.nn import GraphConv, SumPooling
class EEGGraphConvNet(nn.Module): class EEGGraphConvNet(nn.Module):
""" EEGGraph Convolution Net """EEGGraph Convolution Net
Parameters Parameters
---------- ----------
num_feats: the number of features per node. In our case, it is 6. num_feats: the number of features per node. In our case, it is 6.
""" """
def __init__(self, num_feats): def __init__(self, num_feats):
super(EEGGraphConvNet, self).__init__() super(EEGGraphConvNet, self).__init__()
self.conv1 = GraphConv(num_feats, 32) self.conv1 = GraphConv(num_feats, 32)
self.conv2 = GraphConv(32, 20) self.conv2 = GraphConv(32, 20)
self.conv2_bn = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) self.conv2_bn = nn.BatchNorm1d(
20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
)
self.fc_block1 = nn.Linear(20, 10) self.fc_block1 = nn.Linear(20, 10)
self.fc_block2 = nn.Linear(10, 2) self.fc_block2 = nn.Linear(10, 2)
...@@ -23,11 +27,13 @@ class EEGGraphConvNet(nn.Module): ...@@ -23,11 +27,13 @@ class EEGGraphConvNet(nn.Module):
self.fc_block2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1)) self.fc_block2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))
def forward(self, g, return_graph_embedding=False): def forward(self, g, return_graph_embedding=False):
x = g.ndata['x'] x = g.ndata["x"]
edge_weight = g.edata['edge_weights'] edge_weight = g.edata["edge_weights"]
x = function.leaky_relu(self.conv1(g, x, edge_weight=edge_weight)) x = function.leaky_relu(self.conv1(g, x, edge_weight=edge_weight))
x = function.leaky_relu(self.conv2_bn(self.conv2(g, x, edge_weight=edge_weight))) x = function.leaky_relu(
self.conv2_bn(self.conv2(g, x, edge_weight=edge_weight))
)
# NOTE: this takes node-level features/"embeddings" # NOTE: this takes node-level features/"embeddings"
# and aggregates to graph-level - use for graph-level classification # and aggregates to graph-level - use for graph-level classification
......
import dgl
import torch as th import torch as th
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics
import utils import utils
from model import EGES from model import EGES
from sampler import Sampler from sampler import Sampler
from sklearn import metrics
from torch.utils.data import DataLoader
import dgl
def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates): def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):
sampler = Sampler( sampler = Sampler(
train_g, train_g,
args.walk_length, args.walk_length,
args.num_walks, args.num_walks,
args.window_size, args.window_size,
args.num_negative args.num_negative,
) )
# for each node in the graph, we sample pos and neg # for each node in the graph, we sample pos and neg
# pairs for it, and feed these sampled pairs into the model. # pairs for it, and feed these sampled pairs into the model.
...@@ -25,7 +25,7 @@ def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates): ...@@ -25,7 +25,7 @@ def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):
# this is the batch_size of input nodes # this is the batch_size of input nodes
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
collate_fn=lambda x: sampler.sample(x, sku_info) collate_fn=lambda x: sampler.sample(x, sku_info),
) )
model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates) model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)
optimizer = optim.Adam(model.parameters(), lr=args.lr) optimizer = optim.Adam(model.parameters(), lr=args.lr)
...@@ -45,8 +45,11 @@ def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates): ...@@ -45,8 +45,11 @@ def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):
epoch_total_loss += loss.item() epoch_total_loss += loss.item()
if step % args.log_every == 0: if step % args.log_every == 0:
print('Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}'.format( print(
epoch, step, loss.item(), epoch_total_loss / (step + 1))) "Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}".format(
epoch, step, loss.item(), epoch_total_loss / (step + 1)
)
)
eval(model, test_g, sku_info) eval(model, test_g, sku_info)
...@@ -77,15 +80,14 @@ if __name__ == "__main__": ...@@ -77,15 +80,14 @@ if __name__ == "__main__":
valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data) valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)
g, sku_encoder, sku_decoder = utils.construct_graph( g, sku_encoder, sku_decoder = utils.construct_graph(
args.action_data, args.action_data, args.session_interval_sec, valid_sku_raw_ids
args.session_interval_sec,
valid_sku_raw_ids
) )
train_g, test_g = utils.split_train_test_graph(g) train_g, test_g = utils.split_train_test_graph(g)
sku_info_encoder, sku_info_decoder, sku_info = \ sku_info_encoder, sku_info_decoder, sku_info = utils.encode_sku_fields(
utils.encode_sku_fields(args.item_info_data, sku_encoder, sku_decoder) args.item_info_data, sku_encoder, sku_decoder
)
num_skus = len(sku_encoder) num_skus = len(sku_encoder)
num_brands = len(sku_info_encoder["brand"]) num_brands = len(sku_info_encoder["brand"])
...@@ -93,8 +95,11 @@ if __name__ == "__main__": ...@@ -93,8 +95,11 @@ if __name__ == "__main__":
num_cates = len(sku_info_encoder["cate"]) num_cates = len(sku_info_encoder["cate"])
print( print(
"Num skus: {}, num brands: {}, num shops: {}, num cates: {}".\ "Num skus: {}, num brands: {}, num shops: {}, num cates: {}".format(
format(num_skus, num_brands, num_shops, num_cates) num_skus, num_brands, num_shops, num_cates
)
) )
model = train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates) model = train(
args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates
)
...@@ -18,12 +18,12 @@ class EGES(th.nn.Module): ...@@ -18,12 +18,12 @@ class EGES(th.nn.Module):
# srcs: sku_id, brand_id, shop_id, cate_id # srcs: sku_id, brand_id, shop_id, cate_id
srcs = self.query_node_embed(srcs) srcs = self.query_node_embed(srcs)
dsts = self.query_node_embed(dsts) dsts = self.query_node_embed(dsts)
return srcs, dsts return srcs, dsts
def query_node_embed(self, nodes): def query_node_embed(self, nodes):
""" """
@nodes: tensor of shape (batch_size, num_side_info) @nodes: tensor of shape (batch_size, num_side_info)
""" """
batch_size = nodes.shape[0] batch_size = nodes.shape[0]
# query side info weights, (batch_size, 4) # query side info weights, (batch_size, 4)
...@@ -33,21 +33,31 @@ class EGES(th.nn.Module): ...@@ -33,21 +33,31 @@ class EGES(th.nn.Module):
side_info_weights_sum = [] side_info_weights_sum = []
for i in range(4): for i in range(4):
# weights for i-th side info, (batch_size, ) -> (batch_size, 1) # weights for i-th side info, (batch_size, ) -> (batch_size, 1)
i_th_side_info_weights = side_info_weights[:, i].view((batch_size, 1)) i_th_side_info_weights = side_info_weights[:, i].view(
(batch_size, 1)
)
# batch of i-th side info embedding * its weight, (batch_size, dim) # batch of i-th side info embedding * its weight, (batch_size, dim)
side_info_weighted_embeds_sum.append(i_th_side_info_weights * self.embeds[i](nodes[:, i])) side_info_weighted_embeds_sum.append(
i_th_side_info_weights * self.embeds[i](nodes[:, i])
)
side_info_weights_sum.append(i_th_side_info_weights) side_info_weights_sum.append(i_th_side_info_weights)
# stack: (batch_size, 4, dim), sum: (batch_size, dim) # stack: (batch_size, 4, dim), sum: (batch_size, dim)
side_info_weighted_embeds_sum = th.sum(th.stack(side_info_weighted_embeds_sum, axis=1), axis=1) side_info_weighted_embeds_sum = th.sum(
th.stack(side_info_weighted_embeds_sum, axis=1), axis=1
)
# stack: (batch_size, 4), sum: (batch_size, ) # stack: (batch_size, 4), sum: (batch_size, )
side_info_weights_sum = th.sum(th.stack(side_info_weights_sum, axis=1), axis=1) side_info_weights_sum = th.sum(
th.stack(side_info_weights_sum, axis=1), axis=1
)
# (batch_size, dim) # (batch_size, dim)
H = side_info_weighted_embeds_sum / side_info_weights_sum H = side_info_weighted_embeds_sum / side_info_weights_sum
return H return H
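# The aggregation above is the weighted average used by EGES:
#   H_v = (sum_i a_i * e_i) / (sum_i a_i)
# where e_i are the four side-info embeddings (sku, brand, shop, category) and
# a_i are the per-node side-info weights queried above.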
def loss(self, srcs, dsts, labels): def loss(self, srcs, dsts, labels):
dots = th.sigmoid(th.sum(srcs * dsts, axis=1)) dots = th.sigmoid(th.sum(srcs * dsts, axis=1))
dots = th.clamp(dots, min=1e-7, max=1 - 1e-7) dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)
return th.mean(- (labels * th.log(dots) + (1 - labels) * th.log(1 - dots))) return th.mean(
-(labels * th.log(dots) + (1 - labels) * th.log(1 - dots))
)
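# loss() is binary cross-entropy written out by hand: p = sigmoid(<src, dst>),
# clamped away from 0 and 1 for numerical stability, then
#   -mean(y * log(p) + (1 - y) * log(1 - p))
# over the sampled positive (y = 1) and negative (y = 0) node pairs.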