Unverified commit 23d09057, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Black auto fix. (#4642)



* [Misc] Black auto fix.

* sort
Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent a9f2acf3
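The changes below are the output of the Black auto-formatter (the "sort" entry in the message refers to reordering imports). As a rough, hedged sketch of the equivalent programmatic call — the 79-column line length is an assumption inferred from where this diff wraps lines:

import black

src = 'x = {"a": 1,"b": 2}'
print(black.format_str(src, mode=black.Mode(line_length=79)))  # -> x = {"a": 1, "b": 2}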
import torch.nn as nn
def GlorotOrthogonal(tensor, scale=2.0):
if tensor is not None:
nn.init.orthogonal_(tensor.data)
scale /= (tensor.size(-2) + tensor.size(-1)) * tensor.var()
tensor.data *= scale.sqrt()
\ No newline at end of file
tensor.data *= scale.sqrt()
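A minimal usage sketch for the initializer above (illustrative, not part of the diff; the layer sizes are arbitrary):

import torch.nn as nn

lin = nn.Linear(128, 64, bias=False)
GlorotOrthogonal(lin.weight, scale=2.0)  # rescales an orthogonal init by sqrt(2 / ((fan_in + fan_out) * var))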
import torch
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
from modules.residual_layer import ResidualLayer
import dgl.function as fn
from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal
class InteractionBlock(nn.Module):
def __init__(self,
emb_size,
num_radial,
num_spherical,
num_bilinear,
num_before_skip,
num_after_skip,
activation=None):
def __init__(
self,
emb_size,
num_radial,
num_spherical,
num_bilinear,
num_before_skip,
num_after_skip,
activation=None,
):
super(InteractionBlock, self).__init__()
self.activation = activation
# Transformations of Bessel and spherical basis representations
self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
self.dense_sbf = nn.Linear(num_radial * num_spherical, num_bilinear, bias=False)
self.dense_sbf = nn.Linear(
num_radial * num_spherical, num_bilinear, bias=False
)
# Dense transformations of input messages
self.dense_ji = nn.Linear(emb_size, emb_size)
self.dense_kj = nn.Linear(emb_size, emb_size)
# Bilinear layer
bilin_initializer = torch.empty((emb_size, num_bilinear, emb_size)).normal_(mean=0, std=2 / emb_size)
bilin_initializer = torch.empty(
(emb_size, num_bilinear, emb_size)
).normal_(mean=0, std=2 / emb_size)
self.W_bilin = nn.Parameter(bilin_initializer)
# Residual layers before skip connection
self.layers_before_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_before_skip)
])
self.layers_before_skip = nn.ModuleList(
[
ResidualLayer(emb_size, activation=activation)
for _ in range(num_before_skip)
]
)
self.final_before_skip = nn.Linear(emb_size, emb_size)
# Residual layers after skip connection
self.layers_after_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_after_skip)
])
self.layers_after_skip = nn.ModuleList(
[
ResidualLayer(emb_size, activation=activation)
for _ in range(num_after_skip)
]
)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.dense_sbf.weight)
@@ -47,50 +60,52 @@ class InteractionBlock(nn.Module):
def edge_transfer(self, edges):
# Transform from Bessel basis to dense vector
rbf = self.dense_rbf(edges.data['rbf'])
rbf = self.dense_rbf(edges.data["rbf"])
# Initial transformation
x_ji = self.dense_ji(edges.data['m'])
x_kj = self.dense_kj(edges.data['m'])
x_ji = self.dense_ji(edges.data["m"])
x_kj = self.dense_kj(edges.data["m"])
if self.activation is not None:
x_ji = self.activation(x_ji)
x_kj = self.activation(x_kj)
# w: W * e_RBF \bigodot \sigma(W * m + b)
return {'x_kj': x_kj * rbf, 'x_ji': x_ji}
return {"x_kj": x_kj * rbf, "x_ji": x_ji}
def msg_func(self, edges):
sbf = self.dense_sbf(edges.data['sbf'])
sbf = self.dense_sbf(edges.data["sbf"])
# Apply bilinear layer to interactions and basis function activation
# [None, 8] * [128, 8, 128] * [None, 128] -> [None, 128]
x_kj = torch.einsum("wj,wl,ijl->wi", sbf, edges.src['x_kj'], self.W_bilin)
return {'x_kj': x_kj}
x_kj = torch.einsum(
"wj,wl,ijl->wi", sbf, edges.src["x_kj"], self.W_bilin
)
return {"x_kj": x_kj}
def forward(self, g, l_g):
g.apply_edges(self.edge_transfer)
# nodes correspond to edges and edges correspond to nodes in the original graphs
# node: d, rbf, o, rbf_env, x_kj, x_ji
for k, v in g.edata.items():
l_g.ndata[k] = v
l_g.update_all(self.msg_func, fn.sum('x_kj', 'm_update'))
l_g.update_all(self.msg_func, fn.sum("x_kj", "m_update"))
for k, v in l_g.ndata.items():
g.edata[k] = v
# Transformations before skip connection
g.edata['m_update'] = g.edata['m_update'] + g.edata['x_ji']
g.edata["m_update"] = g.edata["m_update"] + g.edata["x_ji"]
for layer in self.layers_before_skip:
g.edata['m_update'] = layer(g.edata['m_update'])
g.edata['m_update'] = self.final_before_skip(g.edata['m_update'])
g.edata["m_update"] = layer(g.edata["m_update"])
g.edata["m_update"] = self.final_before_skip(g.edata["m_update"])
if self.activation is not None:
g.edata['m_update'] = self.activation(g.edata['m_update'])
g.edata["m_update"] = self.activation(g.edata["m_update"])
# Skip connection
g.edata['m'] = g.edata['m'] + g.edata['m_update']
g.edata["m"] = g.edata["m"] + g.edata["m_update"]
# Transformations after skip connection
for layer in self.layers_after_skip:
g.edata['m'] = layer(g.edata['m'])
g.edata["m"] = layer(g.edata["m"])
return g
\ No newline at end of file
return g
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
from modules.residual_layer import ResidualLayer
import dgl
import dgl.function as fn
from modules.residual_layer import ResidualLayer
from modules.initializers import GlorotOrthogonal
class InteractionPPBlock(nn.Module):
def __init__(self,
emb_size,
int_emb_size,
basis_emb_size,
num_radial,
num_spherical,
num_before_skip,
num_after_skip,
activation=None):
def __init__(
self,
emb_size,
int_emb_size,
basis_emb_size,
num_radial,
num_spherical,
num_before_skip,
num_after_skip,
activation=None,
):
super(InteractionPPBlock, self).__init__()
self.activation = activation
# Transformations of Bessel and spherical basis representations
self.dense_rbf1 = nn.Linear(num_radial, basis_emb_size, bias=False)
self.dense_rbf2 = nn.Linear(basis_emb_size, emb_size, bias=False)
self.dense_sbf1 = nn.Linear(num_radial * num_spherical, basis_emb_size, bias=False)
self.dense_sbf1 = nn.Linear(
num_radial * num_spherical, basis_emb_size, bias=False
)
self.dense_sbf2 = nn.Linear(basis_emb_size, int_emb_size, bias=False)
# Dense transformations of input messages
self.dense_ji = nn.Linear(emb_size, emb_size)
@@ -30,17 +35,23 @@ class InteractionPPBlock(nn.Module):
self.down_projection = nn.Linear(emb_size, int_emb_size, bias=False)
self.up_projection = nn.Linear(int_emb_size, emb_size, bias=False)
# Residual layers before skip connection
self.layers_before_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_before_skip)
])
self.layers_before_skip = nn.ModuleList(
[
ResidualLayer(emb_size, activation=activation)
for _ in range(num_before_skip)
]
)
self.final_before_skip = nn.Linear(emb_size, emb_size)
# Residual layers after skip connection
self.layers_after_skip = nn.ModuleList([
ResidualLayer(emb_size, activation=activation) for _ in range(num_after_skip)
])
self.layers_after_skip = nn.ModuleList(
[
ResidualLayer(emb_size, activation=activation)
for _ in range(num_after_skip)
]
)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf1.weight)
GlorotOrthogonal(self.dense_rbf2.weight)
@@ -55,11 +66,11 @@ class InteractionPPBlock(nn.Module):
def edge_transfer(self, edges):
# Transform from Bessel basis to dense vector
rbf = self.dense_rbf1(edges.data['rbf'])
rbf = self.dense_rbf1(edges.data["rbf"])
rbf = self.dense_rbf2(rbf)
# Initial transformation
x_ji = self.dense_ji(edges.data['m'])
x_kj = self.dense_kj(edges.data['m'])
x_ji = self.dense_ji(edges.data["m"])
x_kj = self.dense_kj(edges.data["m"])
if self.activation is not None:
x_ji = self.activation(x_ji)
x_kj = self.activation(x_kj)
@@ -67,41 +78,41 @@ class InteractionPPBlock(nn.Module):
x_kj = self.down_projection(x_kj * rbf)
if self.activation is not None:
x_kj = self.activation(x_kj)
return {'x_kj': x_kj, 'x_ji': x_ji}
return {"x_kj": x_kj, "x_ji": x_ji}
def msg_func(self, edges):
sbf = self.dense_sbf1(edges.data['sbf'])
sbf = self.dense_sbf1(edges.data["sbf"])
sbf = self.dense_sbf2(sbf)
x_kj = edges.src['x_kj'] * sbf
return {'x_kj': x_kj}
x_kj = edges.src["x_kj"] * sbf
return {"x_kj": x_kj}
def forward(self, g, l_g):
g.apply_edges(self.edge_transfer)
# nodes correspond to edges and edges correspond to nodes in the original graphs
# node: d, rbf, o, rbf_env, x_kj, x_ji
for k, v in g.edata.items():
l_g.ndata[k] = v
l_g_reverse = dgl.reverse(l_g, copy_edata=True)
l_g_reverse.update_all(self.msg_func, fn.sum('x_kj', 'm_update'))
l_g_reverse.update_all(self.msg_func, fn.sum("x_kj", "m_update"))
g.edata['m_update'] = self.up_projection(l_g_reverse.ndata['m_update'])
g.edata["m_update"] = self.up_projection(l_g_reverse.ndata["m_update"])
if self.activation is not None:
g.edata['m_update'] = self.activation(g.edata['m_update'])
g.edata["m_update"] = self.activation(g.edata["m_update"])
# Transformations before skip connection
g.edata['m_update'] = g.edata['m_update'] + g.edata['x_ji']
g.edata["m_update"] = g.edata["m_update"] + g.edata["x_ji"]
for layer in self.layers_before_skip:
g.edata['m_update'] = layer(g.edata['m_update'])
g.edata['m_update'] = self.final_before_skip(g.edata['m_update'])
g.edata["m_update"] = layer(g.edata["m_update"])
g.edata["m_update"] = self.final_before_skip(g.edata["m_update"])
if self.activation is not None:
g.edata['m_update'] = self.activation(g.edata['m_update'])
g.edata["m_update"] = self.activation(g.edata["m_update"])
# Skip connection
g.edata['m'] = g.edata['m'] + g.edata['m_update']
g.edata["m"] = g.edata["m"] + g.edata["m_update"]
# Transformations after skip connection
for layer in self.layers_after_skip:
g.edata['m'] = layer(g.edata['m'])
g.edata["m"] = layer(g.edata["m"])
return g
\ No newline at end of file
return g
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
import dgl
import dgl.function as fn
from modules.initializers import GlorotOrthogonal
class OutputBlock(nn.Module):
def __init__(self,
emb_size,
num_radial,
num_dense,
num_targets,
activation=None,
output_init=nn.init.zeros_):
def __init__(
self,
emb_size,
num_radial,
num_dense,
num_targets,
activation=None,
output_init=nn.init.zeros_,
):
super(OutputBlock, self).__init__()
self.activation = activation
self.output_init = output_init
self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
self.dense_layers = nn.ModuleList([
nn.Linear(emb_size, emb_size) for _ in range(num_dense)
])
self.dense_layers = nn.ModuleList(
[nn.Linear(emb_size, emb_size) for _ in range(num_dense)]
)
self.dense_final = nn.Linear(emb_size, num_targets, bias=False)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf.weight)
for layer in self.dense_layers:
@@ -31,12 +34,12 @@ class OutputBlock(nn.Module):
def forward(self, g):
with g.local_scope():
g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
g.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
g.edata["tmp"] = g.edata["m"] * self.dense_rbf(g.edata["rbf"])
g.update_all(fn.copy_e("tmp", "x"), fn.sum("x", "t"))
for layer in self.dense_layers:
g.ndata['t'] = layer(g.ndata['t'])
g.ndata["t"] = layer(g.ndata["t"])
if self.activation is not None:
g.ndata['t'] = self.activation(g.ndata['t'])
g.ndata['t'] = self.dense_final(g.ndata['t'])
return dgl.readout_nodes(g, 't')
\ No newline at end of file
g.ndata["t"] = self.activation(g.ndata["t"])
g.ndata["t"] = self.dense_final(g.ndata["t"])
return dgl.readout_nodes(g, "t")
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
import dgl
import dgl.function as fn
from modules.initializers import GlorotOrthogonal
class OutputPPBlock(nn.Module):
def __init__(self,
emb_size,
out_emb_size,
num_radial,
num_dense,
num_targets,
activation=None,
output_init=nn.init.zeros_,
extensive=True):
def __init__(
self,
emb_size,
out_emb_size,
num_radial,
num_dense,
num_targets,
activation=None,
output_init=nn.init.zeros_,
extensive=True,
):
super(OutputPPBlock, self).__init__()
self.activation = activation
@@ -21,12 +24,12 @@ class OutputPPBlock(nn.Module):
self.extensive = extensive
self.dense_rbf = nn.Linear(num_radial, emb_size, bias=False)
self.up_projection = nn.Linear(emb_size, out_emb_size, bias=False)
self.dense_layers = nn.ModuleList([
nn.Linear(out_emb_size, out_emb_size) for _ in range(num_dense)
])
self.dense_layers = nn.ModuleList(
[nn.Linear(out_emb_size, out_emb_size) for _ in range(num_dense)]
)
self.dense_final = nn.Linear(out_emb_size, num_targets, bias=False)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_rbf.weight)
GlorotOrthogonal(self.up_projection.weight)
@@ -36,14 +39,16 @@ class OutputPPBlock(nn.Module):
def forward(self, g):
with g.local_scope():
g.edata['tmp'] = g.edata['m'] * self.dense_rbf(g.edata['rbf'])
g.edata["tmp"] = g.edata["m"] * self.dense_rbf(g.edata["rbf"])
g_reverse = dgl.reverse(g, copy_edata=True)
g_reverse.update_all(fn.copy_e('tmp', 'x'), fn.sum('x', 't'))
g.ndata['t'] = self.up_projection(g_reverse.ndata['t'])
g_reverse.update_all(fn.copy_e("tmp", "x"), fn.sum("x", "t"))
g.ndata["t"] = self.up_projection(g_reverse.ndata["t"])
for layer in self.dense_layers:
g.ndata['t'] = layer(g.ndata['t'])
g.ndata["t"] = layer(g.ndata["t"])
if self.activation is not None:
g.ndata['t'] = self.activation(g.ndata['t'])
g.ndata['t'] = self.dense_final(g.ndata['t'])
return dgl.readout_nodes(g, 't', op='sum' if self.extensive else 'mean')
\ No newline at end of file
g.ndata["t"] = self.activation(g.ndata["t"])
g.ndata["t"] = self.dense_final(g.ndata["t"])
return dgl.readout_nodes(
g, "t", op="sum" if self.extensive else "mean"
)
import torch.nn as nn
from modules.initializers import GlorotOrthogonal
class ResidualLayer(nn.Module):
def __init__(self, units, activation=None):
super(ResidualLayer, self).__init__()
@@ -9,9 +9,9 @@ class ResidualLayer(nn.Module):
self.activation = activation
self.dense_1 = nn.Linear(units, units)
self.dense_2 = nn.Linear(units, units)
self.reset_params()
def reset_params(self):
GlorotOrthogonal(self.dense_1.weight)
nn.init.zeros_(self.dense_1.bias)
@@ -25,4 +25,4 @@ class ResidualLayer(nn.Module):
x = self.dense_2(x)
if self.activation is not None:
x = self.activation(x)
return inputs + x
\ No newline at end of file
return inputs + x
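A quick shape sketch for ResidualLayer (illustrative; the activation choice is an assumption, any elementwise module works):

import torch
import torch.nn as nn

layer = ResidualLayer(64, activation=nn.SiLU())
out = layer(torch.randn(5, 64))
assert out.shape == (5, 64)  # two Linear(units, units) blocks plus the skip keep the feature size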
import sympy as sym
import torch
import torch.nn as nn
from modules.basis_utils import bessel_basis, real_sph_harm
from modules.envelope import Envelope
class SphericalBasisLayer(nn.Module):
def __init__(self,
num_spherical,
num_radial,
cutoff,
envelope_exponent=5):
def __init__(self, num_spherical, num_radial, cutoff, envelope_exponent=5):
super(SphericalBasisLayer, self).__init__()
assert num_radial <= 64
@@ -20,26 +16,38 @@ class SphericalBasisLayer(nn.Module):
self.envelope = Envelope(envelope_exponent)
# retrieve formulas
self.bessel_formulas = bessel_basis(num_spherical, num_radial) # x, [num_spherical, num_radial] sympy functions
self.sph_harm_formulas = real_sph_harm(num_spherical) # theta, [num_spherical, ] sympy functions
self.bessel_formulas = bessel_basis(
num_spherical, num_radial
) # x, [num_spherical, num_radial] sympy functions
self.sph_harm_formulas = real_sph_harm(
num_spherical
) # theta, [num_spherical, ] sympy functions
self.sph_funcs = []
self.bessel_funcs = []
# convert to torch functions
x = sym.symbols('x')
theta = sym.symbols('theta')
modules = {'sin': torch.sin, 'cos': torch.cos}
x = sym.symbols("x")
theta = sym.symbols("theta")
modules = {"sin": torch.sin, "cos": torch.cos}
for i in range(num_spherical):
if i == 0:
first_sph = sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)(0)
self.sph_funcs.append(lambda tensor: torch.zeros_like(tensor) + first_sph)
first_sph = sym.lambdify(
[theta], self.sph_harm_formulas[i][0], modules
)(0)
self.sph_funcs.append(
lambda tensor: torch.zeros_like(tensor) + first_sph
)
else:
self.sph_funcs.append(sym.lambdify([theta], self.sph_harm_formulas[i][0], modules))
self.sph_funcs.append(
sym.lambdify([theta], self.sph_harm_formulas[i][0], modules)
)
for j in range(num_radial):
self.bessel_funcs.append(sym.lambdify([x], self.bessel_formulas[i][j], modules))
self.bessel_funcs.append(
sym.lambdify([x], self.bessel_formulas[i][j], modules)
)
def get_bessel_funcs(self):
return self.bessel_funcs
def get_sph_funcs(self):
return self.sph_funcs
\ No newline at end of file
return self.sph_funcs
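For reference, a small sketch (not from the commit) of the sympy-to-torch conversion the constructor above performs with lambdify and the {"sin": torch.sin, "cos": torch.cos} module map:

import sympy as sym
import torch

x = sym.symbols("x")
f = sym.lambdify([x], sym.sin(x), {"sin": torch.sin, "cos": torch.cos})
f(torch.linspace(0.0, 3.14, 5))  # element-wise sin evaluated as a torch function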
"""QM9 dataset for graph property prediction (regression)."""
import os
import numpy as np
import scipy.sparse as sp
import torch
import dgl
from tqdm import trange
import dgl
from dgl.convert import graph as dgl_graph
from dgl.data import QM9Dataset
from dgl.data.utils import load_graphs, save_graphs
from dgl.convert import graph as dgl_graph
class QM9(QM9Dataset):
r"""QM9 dataset for graph property prediction (regression)
@@ -16,11 +18,11 @@ class QM9(QM9Dataset):
This dataset consists of 130,831 molecules with 12 regression targets.
Nodes correspond to atoms and edges correspond to bonds.
Reference:
Reference:
- `"Quantum-Machine.org" <http://quantum-machine.org/datasets/>`_
- `"Directional Message Passing for Molecular Graphs" <https://arxiv.org/abs/2003.03123>`_
Statistics:
- Number of graphs: 130,831
@@ -53,7 +55,7 @@ class QM9(QM9Dataset):
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
| Cv | :math:`c_{\textrm{v}}` | Heat capacity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` |
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+
Parameters
----------
label_keys: list
@@ -75,12 +77,12 @@ class QM9(QM9Dataset):
----------
num_labels : int
Number of labels for each graph, i.e. number of prediction tasks
Raises
------
UserWarning
If the raw data is changed in the remote server by the author.
Examples
--------
>>> data = QM9Dataset(label_keys=['mu', 'gap'], cutoff=5.0)
@@ -95,70 +97,94 @@ class QM9(QM9Dataset):
>>>
"""
def __init__(self,
label_keys,
edge_funcs=None,
cutoff=5.0,
raw_dir=None,
force_reload=False,
verbose=False):
def __init__(
self,
label_keys,
edge_funcs=None,
cutoff=5.0,
raw_dir=None,
force_reload=False,
verbose=False,
):
self.edge_funcs = edge_funcs
self._keys = ['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv']
super(QM9, self).__init__(label_keys=label_keys,
cutoff=cutoff,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose)
self._keys = [
"mu",
"alpha",
"homo",
"lumo",
"gap",
"r2",
"zpve",
"U0",
"U",
"H",
"G",
"Cv",
]
super(QM9, self).__init__(
label_keys=label_keys,
cutoff=cutoff,
raw_dir=raw_dir,
force_reload=force_reload,
verbose=verbose,
)
def has_cache(self):
""" step 1, if True, goto step 5; else goto download(step 2), then step 3"""
graph_path = f'{self.save_path}/dgl_graph.bin'
line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
"""step 1, if True, goto step 5; else goto download(step 2), then step 3"""
graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
return os.path.exists(graph_path) and os.path.exists(line_graph_path)
def process(self):
""" step 3 """
npz_path = f'{self.raw_dir}/qm9_eV.npz'
"""step 3"""
npz_path = f"{self.raw_dir}/qm9_eV.npz"
data_dict = np.load(npz_path, allow_pickle=True)
# data_dict['N'] contains the number of atoms in each molecule,
# data_dict['R'] consists of the atomic coordinates,
# data_dict['Z'] consists of the atomic numbers.
# Atomic properties (Z and R) of all molecules are concatenated as single tensors,
# so you need this value to select the correct atoms for each molecule.
self.N = data_dict['N']
self.R = data_dict['R']
self.Z = data_dict['Z']
self.N = data_dict["N"]
self.R = data_dict["R"]
self.Z = data_dict["Z"]
self.N_cumsum = np.concatenate([[0], np.cumsum(self.N)])
# graph labels
self.label_dict = {}
for k in self._keys:
self.label_dict[k] = torch.tensor(data_dict[k], dtype=torch.float32)
self.label = torch.stack([self.label_dict[key] for key in self.label_keys], dim=1)
self.label = torch.stack(
[self.label_dict[key] for key in self.label_keys], dim=1
)
# graphs & features
self.graphs, self.line_graphs = self._load_graph()
def _load_graph(self):
num_graphs = self.label.shape[0]
graphs = []
line_graphs = []
for idx in trange(num_graphs):
n_atoms = self.N[idx]
# get all the atomic coordinates of the idx-th molecular graph
R = self.R[self.N_cumsum[idx]:self.N_cumsum[idx + 1]]
R = self.R[self.N_cumsum[idx] : self.N_cumsum[idx + 1]]
# calculate the distance between all atoms
dist = np.linalg.norm(R[:, None, :] - R[None, :, :], axis=-1)
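# Note (added for clarity): dist is an [n_atoms, n_atoms] matrix of pairwise
# Euclidean distances; subtracting sp.eye below is what drops the self-loops
# from the boolean cutoff mask before it is converted to an edge list.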
# keep all edges that don't exceed the cutoff and delete self-loops
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(n_atoms, dtype=np.bool)
adj = sp.csr_matrix(dist <= self.cutoff) - sp.eye(
n_atoms, dtype=np.bool
)
adj = adj.tocoo()
u, v = torch.tensor(adj.row), torch.tensor(adj.col)
g = dgl_graph((u, v))
g.ndata['R'] = torch.tensor(R, dtype=torch.float32)
g.ndata['Z'] = torch.tensor(self.Z[self.N_cumsum[idx]:self.N_cumsum[idx + 1]], dtype=torch.long)
g.ndata["R"] = torch.tensor(R, dtype=torch.float32)
g.ndata["Z"] = torch.tensor(
self.Z[self.N_cumsum[idx] : self.N_cumsum[idx + 1]],
dtype=torch.long,
)
# add user-defined features
if self.edge_funcs is not None:
for func in self.edge_funcs:
@@ -167,32 +193,34 @@ class QM9(QM9Dataset):
graphs.append(g)
l_g = dgl.line_graph(g, backtracking=False)
line_graphs.append(l_g)
return graphs, line_graphs
def save(self):
""" step 4 """
graph_path = f'{self.save_path}/dgl_graph.bin'
line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
"""step 4"""
graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
save_graphs(str(graph_path), self.graphs, self.label_dict)
save_graphs(str(line_graph_path), self.line_graphs)
def load(self):
""" step 5 """
graph_path = f'{self.save_path}/dgl_graph.bin'
line_graph_path = f'{self.save_path}/dgl_line_graph.bin'
"""step 5"""
graph_path = f"{self.save_path}/dgl_graph.bin"
line_graph_path = f"{self.save_path}/dgl_line_graph.bin"
self.graphs, label_dict = load_graphs(graph_path)
self.line_graphs, _ = load_graphs(line_graph_path)
self.label = torch.stack([label_dict[key] for key in self.label_keys], dim=1)
self.label = torch.stack(
[label_dict[key] for key in self.label_keys], dim=1
)
def __getitem__(self, idx):
r""" Get graph and label by index
r"""Get graph and label by index
Parameters
----------
idx : int
Item index
Returns
-------
dgl.DGLGraph
......
import os
import ssl
from six.moves import urllib
import torch
import numpy as np
import torch
from six.moves import urllib
from torch.utils.data import DataLoader, Dataset
import dgl
from torch.utils.data import Dataset, DataLoader
def download_file(dataset):
print("Start Downloading data: {}".format(dataset))
url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
dataset)
dataset
)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
@@ -20,13 +23,13 @@ def download_file(dataset):
class SnapShotDataset(Dataset):
def __init__(self, path, npz_file):
if not os.path.exists(path+'/'+npz_file):
if not os.path.exists(path + "/" + npz_file):
if not os.path.exists(path):
os.mkdir(path)
download_file(npz_file)
zipfile = np.load(path+'/'+npz_file)
self.x = zipfile['x']
self.y = zipfile['y']
zipfile = np.load(path + "/" + npz_file)
self.x = zipfile["x"]
self.y = zipfile["y"]
def __len__(self):
return len(self.x)
@@ -39,54 +42,52 @@ class SnapShotDataset(Dataset):
def METR_LAGraphDataset():
if not os.path.exists('data/graph_la.bin'):
if not os.path.exists('data'):
os.mkdir('data')
download_file('graph_la.bin')
g, _ = dgl.load_graphs('data/graph_la.bin')
if not os.path.exists("data/graph_la.bin"):
if not os.path.exists("data"):
os.mkdir("data")
download_file("graph_la.bin")
g, _ = dgl.load_graphs("data/graph_la.bin")
return g[0]
class METR_LATrainDataset(SnapShotDataset):
def __init__(self):
super(METR_LATrainDataset, self).__init__('data', 'metr_la_train.npz')
super(METR_LATrainDataset, self).__init__("data", "metr_la_train.npz")
self.mean = self.x[..., 0].mean()
self.std = self.x[..., 0].std()
class METR_LATestDataset(SnapShotDataset):
def __init__(self):
super(METR_LATestDataset, self).__init__('data', 'metr_la_test.npz')
super(METR_LATestDataset, self).__init__("data", "metr_la_test.npz")
class METR_LAValidDataset(SnapShotDataset):
def __init__(self):
super(METR_LAValidDataset, self).__init__('data', 'metr_la_valid.npz')
super(METR_LAValidDataset, self).__init__("data", "metr_la_valid.npz")
def PEMS_BAYGraphDataset():
if not os.path.exists('data/graph_bay.bin'):
if not os.path.exists('data'):
os.mkdir('data')
download_file('graph_bay.bin')
g, _ = dgl.load_graphs('data/graph_bay.bin')
if not os.path.exists("data/graph_bay.bin"):
if not os.path.exists("data"):
os.mkdir("data")
download_file("graph_bay.bin")
g, _ = dgl.load_graphs("data/graph_bay.bin")
return g[0]
class PEMS_BAYTrainDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYTrainDataset, self).__init__(
'data', 'pems_bay_train.npz')
super(PEMS_BAYTrainDataset, self).__init__("data", "pems_bay_train.npz")
self.mean = self.x[..., 0].mean()
self.std = self.x[..., 0].std()
class PEMS_BAYTestDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYTestDataset, self).__init__('data', 'pems_bay_test.npz')
super(PEMS_BAYTestDataset, self).__init__("data", "pems_bay_test.npz")
class PEMS_BAYValidDataset(SnapShotDataset):
def __init__(self):
super(PEMS_BAYValidDataset, self).__init__(
'data', 'pems_bay_valid.npz')
super(PEMS_BAYValidDataset, self).__init__("data", "pems_bay_valid.npz")
@@ -2,13 +2,14 @@ import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
from dgl.base import DGLError
import dgl.function as fn
from dgl.base import DGLError
class DiffConv(nn.Module):
'''DiffConv is the implementation of diffusion convolution from paper DCRNN
"""DiffConv is the implementation of diffusion convolution from paper DCRNN
It will compute multiple diffusion matrices and perform a diffusion convolution on each;
this layer can be used for traffic prediction and epidemic models.
Parameter
@@ -25,20 +26,23 @@ class DiffConv(nn.Module):
dir : str [both/in/out]
direction of diffusion convolution
From paper default both direction
'''
"""
def __init__(self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir='both'):
def __init__(
self, in_feats, out_feats, k, in_graph_list, out_graph_list, dir="both"
):
super(DiffConv, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
self.k = k
self.dir = dir
self.num_graphs = self.k-1 if self.dir == 'both' else 2*self.k-2
self.num_graphs = self.k - 1 if self.dir == "both" else 2 * self.k - 2
self.project_fcs = nn.ModuleList()
for i in range(self.num_graphs):
self.project_fcs.append(
nn.Linear(self.in_feats, self.out_feats, bias=False))
self.merger = nn.Parameter(torch.randn(self.num_graphs+1))
nn.Linear(self.in_feats, self.out_feats, bias=False)
)
self.merger = nn.Parameter(torch.randn(self.num_graphs + 1))
self.in_graph_list = in_graph_list
self.out_graph_list = out_graph_list
@@ -48,62 +52,68 @@ class DiffConv(nn.Module):
out_graph_list = []
in_graph_list = []
wadj, ind, outd = DiffConv.get_weight_matrix(g)
adj = sparse.coo_matrix(wadj/outd.cpu().numpy())
outg = dgl.from_scipy(adj, eweight_name='weight').to(device)
outg.edata['weight'] = outg.edata['weight'].float().to(device)
adj = sparse.coo_matrix(wadj / outd.cpu().numpy())
outg = dgl.from_scipy(adj, eweight_name="weight").to(device)
outg.edata["weight"] = outg.edata["weight"].float().to(device)
out_graph_list.append(outg)
for i in range(k-1):
out_graph_list.append(DiffConv.diffuse(
out_graph_list[-1], wadj, outd))
adj = sparse.coo_matrix(wadj.T/ind.cpu().numpy())
ing = dgl.from_scipy(adj, eweight_name='weight').to(device)
ing.edata['weight'] = ing.edata['weight'].float().to(device)
for i in range(k - 1):
out_graph_list.append(
DiffConv.diffuse(out_graph_list[-1], wadj, outd)
)
adj = sparse.coo_matrix(wadj.T / ind.cpu().numpy())
ing = dgl.from_scipy(adj, eweight_name="weight").to(device)
ing.edata["weight"] = ing.edata["weight"].float().to(device)
in_graph_list.append(ing)
for i in range(k-1):
in_graph_list.append(DiffConv.diffuse(
in_graph_list[-1], wadj.T, ind))
for i in range(k - 1):
in_graph_list.append(
DiffConv.diffuse(in_graph_list[-1], wadj.T, ind)
)
return out_graph_list, in_graph_list
@staticmethod
def get_weight_matrix(g):
adj = g.adj(scipy_fmt='coo')
adj = g.adj(scipy_fmt="coo")
ind = g.in_degrees()
outd = g.out_degrees()
weight = g.edata['weight']
weight = g.edata["weight"]
adj.data = weight.cpu().numpy()
return adj, ind, outd
@staticmethod
def diffuse(progress_g, weighted_adj, degree):
device = progress_g.device
progress_adj = progress_g.adj(scipy_fmt='coo')
progress_adj.data = progress_g.edata['weight'].cpu().numpy()
ret_adj = sparse.coo_matrix(progress_adj@(
weighted_adj/degree.cpu().numpy()))
ret_graph = dgl.from_scipy(ret_adj, eweight_name='weight').to(device)
ret_graph.edata['weight'] = ret_graph.edata['weight'].float().to(
device)
progress_adj = progress_g.adj(scipy_fmt="coo")
progress_adj.data = progress_g.edata["weight"].cpu().numpy()
ret_adj = sparse.coo_matrix(
progress_adj @ (weighted_adj / degree.cpu().numpy())
)
ret_graph = dgl.from_scipy(ret_adj, eweight_name="weight").to(device)
ret_graph.edata["weight"] = ret_graph.edata["weight"].float().to(device)
return ret_graph
def forward(self, g, x):
feat_list = []
if self.dir == 'both':
graph_list = self.in_graph_list+self.out_graph_list
elif self.dir == 'in':
if self.dir == "both":
graph_list = self.in_graph_list + self.out_graph_list
elif self.dir == "in":
graph_list = self.in_graph_list
elif self.dir == 'out':
elif self.dir == "out":
graph_list = self.out_graph_list
for i in range(self.num_graphs):
g = graph_list[i]
with g.local_scope():
g.ndata['n'] = self.project_fcs[i](x)
g.update_all(fn.u_mul_e('n', 'weight', 'e'),
fn.sum('e', 'feat'))
feat_list.append(g.ndata['feat'])
g.ndata["n"] = self.project_fcs[i](x)
g.update_all(
fn.u_mul_e("n", "weight", "e"), fn.sum("e", "feat")
)
feat_list.append(g.ndata["feat"])
# Each feat has shape [N,q_feats]
feat_list.append(self.project_fcs[-1](x))
feat_list = torch.cat(feat_list).view(
len(feat_list), -1, self.out_feats)
ret = (self.merger*feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
len(feat_list), -1, self.out_feats
)
ret = (
(self.merger * feat_list.permute(1, 2, 0)).permute(2, 0, 1).mean(0)
)
return ret
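# Usage sketch (hedged, not part of the commit): with a weighted DGL graph g
# carrying g.edata["weight"], the diffusion graphs are built once and reused:
#     out_gs, in_gs = DiffConv.attach_graph(g, k)
#     conv = DiffConv(in_feats, out_feats, k, in_graph_list=in_gs, out_graph_list=out_gs)
#     h = conv(g, x)  # x: [num_nodes, in_feats] -> h: [num_nodes, out_feats]
# This mirrors how train.py below wires DiffConv into GraphRNN via functools.partial.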
import numpy as np
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.base import DGLError
import dgl.function as fn
from dgl.nn.functional import edge_softmax
class WeightedGATConv(dglnn.GATConv):
'''
"""
This model inherits from DGL's GATConv for the traffic prediction task;
it adds edge weights when aggregating the node features.
'''
"""
def forward(self, graph, feat, get_attention=False):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
if isinstance(feat, tuple):
h_src = self.feat_drop(feat[0])
h_dst = self.feat_drop(feat[1])
if not hasattr(self, 'fc_src'):
feat_src = self.fc(
h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc(
h_dst).view(-1, self._num_heads, self._out_feats)
if not hasattr(self, "fc_src"):
feat_src = self.fc(h_src).view(
-1, self._num_heads, self._out_feats
)
feat_dst = self.fc(h_dst).view(
-1, self._num_heads, self._out_feats
)
else:
feat_src = self.fc_src(
h_src).view(-1, self._num_heads, self._out_feats)
feat_dst = self.fc_dst(
h_dst).view(-1, self._num_heads, self._out_feats)
feat_src = self.fc_src(h_src).view(
-1, self._num_heads, self._out_feats
)
feat_dst = self.fc_dst(h_dst).view(
-1, self._num_heads, self._out_feats
)
else:
h_src = h_dst = self.feat_drop(feat)
feat_src = feat_dst = self.fc(h_src).view(
-1, self._num_heads, self._out_feats)
-1, self._num_heads, self._out_feats
)
if graph.is_block:
feat_dst = feat_src[:graph.number_of_dst_nodes()]
feat_dst = feat_src[: graph.number_of_dst_nodes()]
# NOTE: GAT paper uses "first concatenation then linear projection"
# to compute attention scores, while ours is "first projection then
# addition", the two approaches are mathematically equivalent:
@@ -59,37 +67,38 @@ class WeightedGATConv(dglnn.GATConv):
# which further speeds up computation and saves memory footprint.
el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.srcdata.update({'ft': feat_src, 'el': el})
graph.dstdata.update({'er': er})
graph.srcdata.update({"ft": feat_src, "el": el})
graph.dstdata.update({"er": er})
# compute edge attention, el and er are a_l Wh_i and a_r Wh_j respectively.
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
e = self.leaky_relu(graph.edata.pop('e'))
graph.apply_edges(fn.u_add_v("el", "er", "e"))
e = self.leaky_relu(graph.edata.pop("e"))
# compute softmax
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
# compute weighted attention
graph.edata['a'] = (graph.edata['a'].permute(
1, 2, 0)*graph.edata['weight']).permute(2, 0, 1)
graph.edata["a"] = (
graph.edata["a"].permute(1, 2, 0) * graph.edata["weight"]
).permute(2, 0, 1)
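# Shape note (added for clarity): the attention "a" has shape
# [E, num_heads, 1]; permuting to [num_heads, 1, E] lets the per-edge scalar
# weight of shape [E] broadcast across heads before permuting back.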
# message passing
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
fn.sum('m', 'ft'))
rst = graph.dstdata['ft']
graph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats)
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
class GatedGAT(nn.Module):
'''Gated Graph Attention module, it is a general purpose
"""Gated Graph Attention module, it is a general purpose
graph attention module proposed in the GaAN paper. The paper uses
it for the traffic prediction task.
Parameter
@@ -105,7 +114,7 @@ class GatedGAT(nn.Module):
num_heads : int
number of head for multihead attention
'''
"""
def __init__(self, in_feats, out_feats, map_feats, num_heads):
super(GatedGAT, self).__init__()
@@ -113,28 +122,32 @@ class GatedGAT(nn.Module):
self.out_feats = out_feats
self.map_feats = map_feats
self.num_heads = num_heads
self.gatlayer = WeightedGATConv(self.in_feats,
self.out_feats,
self.num_heads)
self.gatlayer = WeightedGATConv(
self.in_feats, self.out_feats, self.num_heads
)
self.gate_fn = nn.Linear(
2*self.in_feats+self.map_feats, self.num_heads)
2 * self.in_feats + self.map_feats, self.num_heads
)
self.gate_m = nn.Linear(self.in_feats, self.map_feats)
self.merger_layer = nn.Linear(
self.in_feats+self.out_feats, self.out_feats)
self.in_feats + self.out_feats, self.out_feats
)
def forward(self, g, x):
with g.local_scope():
g.ndata['x'] = x
g.ndata['z'] = self.gate_m(x)
g.update_all(fn.copy_u('x', 'x'), fn.mean('x', 'mean_z'))
g.update_all(fn.copy_u('z', 'z'), fn.max('z', 'max_z'))
nft = torch.cat([g.ndata['x'], g.ndata['max_z'],
g.ndata['mean_z']], dim=1)
g.ndata["x"] = x
g.ndata["z"] = self.gate_m(x)
g.update_all(fn.copy_u("x", "x"), fn.mean("x", "mean_z"))
g.update_all(fn.copy_u("z", "z"), fn.max("z", "max_z"))
nft = torch.cat(
[g.ndata["x"], g.ndata["max_z"], g.ndata["mean_z"]], dim=1
)
gate = self.gate_fn(nft).sigmoid()
attn_out = self.gatlayer(g, x)
node_num = g.num_nodes()
gated_out = ((gate.view(-1)*attn_out.view(-1, self.out_feats).T).T).view(
node_num, self.num_heads, self.out_feats)
gated_out = (
(gate.view(-1) * attn_out.view(-1, self.out_feats).T).T
).view(node_num, self.num_heads, self.out_feats)
gated_out = gated_out.mean(1)
merge = self.merger_layer(torch.cat([x, gated_out], dim=1))
return merge
......@@ -2,15 +2,16 @@ import numpy as np
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
import dgl.function as fn
import dgl.nn as dglnn
from dgl.base import DGLError
import dgl.function as fn
from dgl.nn.functional import edge_softmax
class GraphGRUCell(nn.Module):
'''Graph GRU unit which can use any message passing
"""Graph GRU unit which can use any message passing
net to replace the linear layer in the original GRU
Parameter
==========
@@ -22,7 +23,7 @@ class GraphGRUCell(nn.Module):
net : torch.nn.Module
message passing network
'''
"""
def __init__(self, in_feats, out_feats, net):
super(GraphGRUCell, self).__init__()
@@ -30,28 +31,27 @@ class GraphGRUCell(nn.Module):
self.out_feats = out_feats
self.dir = dir
# net can be any GNN model
self.r_net = net(in_feats+out_feats, out_feats)
self.u_net = net(in_feats+out_feats, out_feats)
self.c_net = net(in_feats+out_feats, out_feats)
self.r_net = net(in_feats + out_feats, out_feats)
self.u_net = net(in_feats + out_feats, out_feats)
self.c_net = net(in_feats + out_feats, out_feats)
# Manually add bias terms
self.r_bias = nn.Parameter(torch.rand(out_feats))
self.u_bias = nn.Parameter(torch.rand(out_feats))
self.c_bias = nn.Parameter(torch.rand(out_feats))
def forward(self, g, x, h):
r = torch.sigmoid(self.r_net(
g, torch.cat([x, h], dim=1)) + self.r_bias)
u = torch.sigmoid(self.u_net(
g, torch.cat([x, h], dim=1)) + self.u_bias)
h_ = r*h
c = torch.sigmoid(self.c_net(
g, torch.cat([x, h_], dim=1)) + self.c_bias)
new_h = u*h + (1-u)*c
r = torch.sigmoid(self.r_net(g, torch.cat([x, h], dim=1)) + self.r_bias)
u = torch.sigmoid(self.u_net(g, torch.cat([x, h], dim=1)) + self.u_bias)
h_ = r * h
c = torch.sigmoid(
self.c_net(g, torch.cat([x, h_], dim=1)) + self.c_bias
)
new_h = u * h + (1 - u) * c
return new_h
class StackedEncoder(nn.Module):
'''One step encoder unit for hidden representation generation
"""One step encoder unit for hidden representation generation
it can stack multiple vertical layers to increase the depth.
Parameter
@@ -67,7 +67,7 @@ class StackedEncoder(nn.Module):
net : torch.nn.Module
message passing network for graph computation
'''
"""
def __init__(self, in_feats, out_feats, num_layers, net):
super(StackedEncoder, self).__init__()
@@ -78,11 +78,13 @@ class StackedEncoder(nn.Module):
self.layers = nn.ModuleList()
if self.num_layers <= 0:
raise DGLError("Layer Number must be greater than 0! ")
self.layers.append(GraphGRUCell(
self.in_feats, self.out_feats, self.net))
for _ in range(self.num_layers-1):
self.layers.append(GraphGRUCell(
self.out_feats, self.out_feats, self.net))
self.layers.append(
GraphGRUCell(self.in_feats, self.out_feats, self.net)
)
for _ in range(self.num_layers - 1):
self.layers.append(
GraphGRUCell(self.out_feats, self.out_feats, self.net)
)
# hidden_states should be a list with one entry per layer
def forward(self, g, x, hidden_states):
@@ -94,7 +96,7 @@ class StackedEncoder(nn.Module):
class StackedDecoder(nn.Module):
'''One step decoder unit for hidden representation generation
"""One step decoder unit for hidden representation generation
it can stack multiple vertical layers to increase the depth.
Parameter
@@ -113,7 +115,7 @@ class StackedDecoder(nn.Module):
net : torch.nn.Module
message passing network for graph computation
'''
"""
def __init__(self, in_feats, hid_feats, out_feats, num_layers, net):
super(StackedDecoder, self).__init__()
@@ -127,9 +129,10 @@ class StackedDecoder(nn.Module):
if self.num_layers <= 0:
raise DGLError("Layer Number must be greater than 0!")
self.layers.append(GraphGRUCell(self.in_feats, self.hid_feats, net))
for _ in range(self.num_layers-1):
self.layers.append(GraphGRUCell(
self.hid_feats, self.hid_feats, net))
for _ in range(self.num_layers - 1):
self.layers.append(
GraphGRUCell(self.hid_feats, self.hid_feats, net)
)
def forward(self, g, x, hidden_states):
hiddens = []
@@ -141,7 +144,7 @@ class StackedDecoder(nn.Module):
class GraphRNN(nn.Module):
'''Graph Sequence to sequence prediction framework
"""Graph Sequence to sequence prediction framework
Support multiple backbone GNN. Mainly used for traffic prediction.
Parameter
@@ -163,15 +166,11 @@ class GraphRNN(nn.Module):
decay_steps : int
number of steps for the teacher forcing probability to decay
'''
def __init__(self,
in_feats,
out_feats,
seq_len,
num_layers,
net,
decay_steps):
"""
def __init__(
self, in_feats, out_feats, seq_len, num_layers, net, decay_steps
):
super(GraphRNN, self).__init__()
self.in_feats = in_feats
self.out_feats = out_feats
@@ -180,24 +179,30 @@ class GraphRNN(nn.Module):
self.net = net
self.decay_steps = decay_steps
self.encoder = StackedEncoder(self.in_feats,
self.out_feats,
self.num_layers,
self.net)
self.encoder = StackedEncoder(
self.in_feats, self.out_feats, self.num_layers, self.net
)
self.decoder = StackedDecoder(
self.in_feats,
self.out_feats,
self.in_feats,
self.num_layers,
self.net,
)
self.decoder = StackedDecoder(self.in_feats,
self.out_feats,
self.in_feats,
self.num_layers,
self.net)
# Threshold For Teacher Forcing
def compute_thresh(self, batch_cnt):
return self.decay_steps/(self.decay_steps + np.exp(batch_cnt / self.decay_steps))
return self.decay_steps / (
self.decay_steps + np.exp(batch_cnt / self.decay_steps)
)
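# Note (added for clarity): this is an inverse-sigmoid schedule. With
# decay_steps=2000 the threshold is ~0.9995 at batch 0 and falls to ~0.08
# after 20,000 batches, so teacher forcing is phased out as training proceeds.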
def encode(self, g, inputs, device):
hidden_states = [torch.zeros(g.num_nodes(), self.out_feats).to(
device) for _ in range(self.num_layers)]
hidden_states = [
torch.zeros(g.num_nodes(), self.out_feats).to(device)
for _ in range(self.num_layers)
]
for i in range(self.seq_len):
_, hidden_states = self.encoder(g, inputs[i], hidden_states)
@@ -207,9 +212,13 @@ class GraphRNN(nn.Module):
outputs = []
inputs = torch.zeros(g.num_nodes(), self.in_feats).to(device)
for i in range(self.seq_len):
if np.random.random() < self.compute_thresh(batch_cnt) and self.training:
if (
np.random.random() < self.compute_thresh(batch_cnt)
and self.training
):
inputs, hidden_states = self.decoder(
g, teacher_states[i], hidden_states)
g, teacher_states[i], hidden_states
)
else:
inputs, hidden_states = self.decoder(g, inputs, hidden_states)
outputs.append(inputs)
......
from functools import partial
import argparse
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import dgl
from model import GraphRNN
from dataloading import (METR_LAGraphDataset, METR_LATestDataset,
METR_LATrainDataset, METR_LAValidDataset,
PEMS_BAYGraphDataset, PEMS_BAYTestDataset,
PEMS_BAYTrainDataset, PEMS_BAYValidDataset)
from dcrnn import DiffConv
from gaan import GatedGAT
from dataloading import METR_LAGraphDataset, METR_LATrainDataset,\
METR_LATestDataset, METR_LAValidDataset,\
PEMS_BAYGraphDataset, PEMS_BAYTrainDataset,\
PEMS_BAYValidDataset, PEMS_BAYTestDataset
from utils import NormalizationLayer, masked_mae_loss, get_learning_rate
from model import GraphRNN
from torch.utils.data import DataLoader
from utils import NormalizationLayer, get_learning_rate, masked_mae_loss
import dgl
batch_cnt = [0]
def train(model, graph, dataloader, optimizer, scheduler, normalizer, loss_fn, device, args):
def train(
model,
graph,
dataloader,
optimizer,
scheduler,
normalizer,
loss_fn,
device,
args,
):
total_loss = []
graph = graph.to(device)
model.train()
@@ -27,29 +39,37 @@ def train(model, graph, dataloader, optimizer, scheduler, normalizer, loss_fn, d
# Padding: since the diffusion graph is precomputed, we need to pad the batch so that
# each batch has the same batch size
if x.shape[0] != batch_size:
x_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
x_buff[:x.shape[0], :, :, :] = x
x_buff[x.shape[0]:, :, :,
:] = x[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
y_buff[:x.shape[0], :, :, :] = y
y_buff[x.shape[0]:, :, :,
:] = y[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
x_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
x_buff[: x.shape[0], :, :, :] = x
x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(
batch_size - x.shape[0], 1, 1, 1
)
y_buff[: x.shape[0], :, :, :] = y
y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(
batch_size - x.shape[0], 1, 1, 1
)
x = x_buff
y = y_buff
# Permute the dimension for shaping
x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3)
x_norm = normalizer.normalize(x).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
y_norm = normalizer.normalize(y).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
x_norm = (
normalizer.normalize(x)
.reshape(x.shape[0], -1, x.shape[3])
.float()
.to(device)
)
y_norm = (
normalizer.normalize(y)
.reshape(x.shape[0], -1, x.shape[3])
.float()
.to(device)
)
y = y.reshape(y.shape[0], -1, y.shape[3]).float().to(device)
batch_graph = dgl.batch([graph]*batch_size)
batch_graph = dgl.batch([graph] * batch_size)
output = model(batch_graph, x_norm, y_norm, batch_cnt[0], device)
# Denormalization for loss compute
y_pred = normalizer.denormalize(output)
@@ -74,29 +94,37 @@ def eval(model, graph, dataloader, normalizer, loss_fn, device, args):
# Padding: since the diffusion graph is precomputed, we need to pad the batch so that
# each batch has the same batch size
if x.shape[0] != batch_size:
x_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros(
batch_size, x.shape[1], x.shape[2], x.shape[3])
x_buff[:x.shape[0], :, :, :] = x
x_buff[x.shape[0]:, :, :,
:] = x[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
y_buff[:x.shape[0], :, :, :] = y
y_buff[x.shape[0]:, :, :,
:] = y[-1].repeat(batch_size-x.shape[0], 1, 1, 1)
x_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
y_buff = torch.zeros(batch_size, x.shape[1], x.shape[2], x.shape[3])
x_buff[: x.shape[0], :, :, :] = x
x_buff[x.shape[0] :, :, :, :] = x[-1].repeat(
batch_size - x.shape[0], 1, 1, 1
)
y_buff[: x.shape[0], :, :, :] = y
y_buff[x.shape[0] :, :, :, :] = y[-1].repeat(
batch_size - x.shape[0], 1, 1, 1
)
x = x_buff
y = y_buff
# Permute the order of dimension
x = x.permute(1, 0, 2, 3)
y = y.permute(1, 0, 2, 3)
x_norm = normalizer.normalize(x).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
y_norm = normalizer.normalize(y).reshape(
x.shape[0], -1, x.shape[3]).float().to(device)
x_norm = (
normalizer.normalize(x)
.reshape(x.shape[0], -1, x.shape[3])
.float()
.to(device)
)
y_norm = (
normalizer.normalize(y)
.reshape(x.shape[0], -1, x.shape[3])
.float()
.to(device)
)
y = y.reshape(x.shape[0], -1, x.shape[3]).to(device)
batch_graph = dgl.batch([graph]*batch_size)
batch_graph = dgl.batch([graph] * batch_size)
output = model(batch_graph, x_norm, y_norm, i, device)
y_pred = normalizer.denormalize(output)
loss = loss_fn(y_pred, y)
@@ -107,70 +135,124 @@ def eval(model, graph, dataloader, normalizer, loss_fn, device, args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Define the arguments
parser.add_argument('--batch_size', type=int, default=64,
help="Size of batch for minibatch Training")
parser.add_argument('--num_workers', type=int, default=0,
help="Number of workers for parallel dataloading")
parser.add_argument('--model', type=str, default='dcrnn',
help="WHich model to use DCRNN vs GaAN")
parser.add_argument('--gpu', type=int, default=-1,
help="GPU indexm -1 for CPU training")
parser.add_argument('--diffsteps', type=int, default=2,
help="Step of constructing the diffusiob matrix")
parser.add_argument('--num_heads', type=int, default=2,
help="Number of multiattention head")
parser.add_argument('--decay_steps', type=int, default=2000,
help="Teacher forcing probability decay ratio")
parser.add_argument('--lr', type=float, default=0.01,
help="Initial learning rate")
parser.add_argument('--minimum_lr', type=float, default=2e-6,
help="Lower bound of learning rate")
parser.add_argument('--dataset', type=str, default='LA',
help="dataset LA for METR_LA; BAY for PEMS_BAY")
parser.add_argument('--epochs', type=int, default=100,
help="Number of epoches for training")
parser.add_argument('--max_grad_norm', type=float, default=5.0,
help="Maximum gradient norm for update parameters")
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="Size of batch for minibatch Training",
)
parser.add_argument(
"--num_workers",
type=int,
default=0,
help="Number of workers for parallel dataloading",
)
parser.add_argument(
"--model",
type=str,
default="dcrnn",
help="WHich model to use DCRNN vs GaAN",
)
parser.add_argument(
"--gpu", type=int, default=-1, help="GPU indexm -1 for CPU training"
)
parser.add_argument(
"--diffsteps",
type=int,
default=2,
help="Step of constructing the diffusiob matrix",
)
parser.add_argument(
"--num_heads", type=int, default=2, help="Number of multiattention head"
)
parser.add_argument(
"--decay_steps",
type=int,
default=2000,
help="Teacher forcing probability decay ratio",
)
parser.add_argument(
"--lr", type=float, default=0.01, help="Initial learning rate"
)
parser.add_argument(
"--minimum_lr",
type=float,
default=2e-6,
help="Lower bound of learning rate",
)
parser.add_argument(
"--dataset",
type=str,
default="LA",
help="dataset LA for METR_LA; BAY for PEMS_BAY",
)
parser.add_argument(
"--epochs", type=int, default=100, help="Number of epoches for training"
)
parser.add_argument(
"--max_grad_norm",
type=float,
default=5.0,
help="Maximum gradient norm for update parameters",
)
args = parser.parse_args()
# Load the datasets
if args.dataset == 'LA':
if args.dataset == "LA":
g = METR_LAGraphDataset()
train_data = METR_LATrainDataset()
test_data = METR_LATestDataset()
valid_data = METR_LAValidDataset()
elif args.dataset == 'BAY':
elif args.dataset == "BAY":
g = PEMS_BAYGraphDataset()
train_data = PEMS_BAYTrainDataset()
test_data = PEMS_BAYTestDataset()
valid_data = PEMS_BAYValidDataset()
if args.gpu == -1:
device = torch.device('cpu')
device = torch.device("cpu")
else:
device = torch.device('cuda:{}'.format(args.gpu))
device = torch.device("cuda:{}".format(args.gpu))
train_loader = DataLoader(
train_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
train_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
)
valid_loader = DataLoader(
valid_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
valid_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
)
test_loader = DataLoader(
test_data, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=True)
test_data,
batch_size=args.batch_size,
num_workers=args.num_workers,
shuffle=True,
)
normalizer = NormalizationLayer(train_data.mean, train_data.std)
if args.model == 'dcrnn':
batch_g = dgl.batch([g]*args.batch_size).to(device)
if args.model == "dcrnn":
batch_g = dgl.batch([g] * args.batch_size).to(device)
out_gs, in_gs = DiffConv.attach_graph(batch_g, args.diffsteps)
net = partial(DiffConv, k=args.diffsteps,
in_graph_list=in_gs, out_graph_list=out_gs)
elif args.model == 'gaan':
net = partial(
DiffConv,
k=args.diffsteps,
in_graph_list=in_gs,
out_graph_list=out_gs,
)
elif args.model == "gaan":
net = partial(GatedGAT, map_feats=64, num_heads=args.num_heads)
dcrnn = GraphRNN(in_feats=2,
out_feats=64,
seq_len=12,
num_layers=2,
net=net,
decay_steps=args.decay_steps).to(device)
dcrnn = GraphRNN(
in_feats=2,
out_feats=64,
seq_len=12,
num_layers=2,
net=net,
decay_steps=args.decay_steps,
).to(device)
optimizer = torch.optim.Adam(dcrnn.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
@@ -178,13 +260,25 @@ if __name__ == "__main__":
loss_fn = masked_mae_loss
for e in range(args.epochs):
train_loss = train(dcrnn, g, train_loader, optimizer, scheduler,
normalizer, loss_fn, device, args)
valid_loss = eval(dcrnn, g, valid_loader,
normalizer, loss_fn, device, args)
test_loss = eval(dcrnn, g, test_loader,
normalizer, loss_fn, device, args)
print("Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(e,
train_loss,
valid_loss,
test_loss))
train_loss = train(
dcrnn,
g,
train_loader,
optimizer,
scheduler,
normalizer,
loss_fn,
device,
args,
)
valid_loss = eval(
dcrnn, g, valid_loader, normalizer, loss_fn, device, args
)
test_loss = eval(
dcrnn, g, test_loader, normalizer, loss_fn, device, args
)
print(
"Epoch: {} Train Loss: {} Valid Loss: {} Test Loss: {}".format(
e, train_loss, valid_loss, test_loss
)
)
import dgl
import scipy.sparse as sparse
import numpy as np
import torch.nn as nn
import scipy.sparse as sparse
import torch
import torch.nn as nn
import dgl
class NormalizationLayer(nn.Module):
@@ -12,10 +13,10 @@ class NormalizationLayer(nn.Module):
# Here we expect mean and std to be scalars
def normalize(self, x):
return (x-self.mean)/self.std
return (x - self.mean) / self.std
def denormalize(self, x):
return x*self.std + self.mean
return x * self.std + self.mean
def masked_mae_loss(y_pred, y_true):
@@ -30,4 +31,4 @@ def masked_mae_loss(y_pred, y_true):
def get_learning_rate(optimizer):
for param in optimizer.param_groups:
return param['lr']
return param["lr"]
import torch
import numpy as np
import math
from itertools import product
import numpy as np
import pandas as pd
import torch
import dgl
from dgl.data import DGLDataset
from itertools import product
class EEGGraphDataset(DGLDataset):
""" Build graph, treat all nodes as the same type
Parameters
----------
x: edge weights of 8-node complete graph
There are 1 x 64 edges
y: labels (diseased/healthy)
num_nodes: the number of nodes of the graph. In our case, it is 8.
indices: Patient level indices. They are used to generate edge weights.
Output
------
a complete 8-node DGLGraph with node features and edge weights
"""Build graph, treat all nodes as the same type
Parameters
----------
x: edge weights of 8-node complete graph
There are 1 x 64 edges
y: labels (diseased/healthy)
num_nodes: the number of nodes of the graph. In our case, it is 8.
indices: Patient level indices. They are used to generate edge weights.
Output
------
a complete 8-node DGLGraph with node features and edge weights
"""
def __init__(self, x, y, num_nodes, indices):
# CAUTION - x and labels are memory-mapped, used as if they are in RAM.
self.x = x
@@ -37,27 +40,24 @@ class EEGGraphDataset(DGLDataset):
"P7-P3",
"P8-P4",
"O1-P3",
"O2-P4"
"O2-P4",
]
# reference electrodes in the 10-10 system, lying between the two 10-20 electrodes of each ch_names pair; used for calculating edge weights
# Note: "01" is for "P03", and "02" is for "P04."
self.ref_names = [
"F5",
"F6",
"C5",
"C6",
"P5",
"P6",
"O1",
"O2"
]
self.ref_names = ["F5", "F6", "C5", "C6", "P5", "P6", "O1", "O2"]
# edge indices source to target - 2 x E = 2 x 64
# fully connected undirected graph so 8*8=64 edges
self.node_ids = range(len(self.ch_names))
self.edge_index = torch.tensor([[a, b] for a, b in product(self.node_ids, self.node_ids)],
dtype=torch.long).t().contiguous()
self.edge_index = (
torch.tensor(
[[a, b] for a, b in product(self.node_ids, self.node_ids)],
dtype=torch.long,
)
.t()
.contiguous()
)
# edge attributes - E x 1
# only the spatial distance between electrodes for now - standardize between 0 and 1
@@ -68,18 +68,22 @@ class EEGGraphDataset(DGLDataset):
# sensor distances don't depend on window ID
def get_sensor_distances(self):
coords_1010 = pd.read_csv("standard_1010.tsv.txt", sep='\t')
coords_1010 = pd.read_csv("standard_1010.tsv.txt", sep="\t")
num_edges = self.edge_index.shape[1]
distances = []
for edge_idx in range(num_edges):
sensor1_idx = self.edge_index[0, edge_idx]
sensor2_idx = self.edge_index[1, edge_idx]
dist = self.get_geodesic_distance(sensor1_idx, sensor2_idx, coords_1010)
dist = self.get_geodesic_distance(
sensor1_idx, sensor2_idx, coords_1010
)
distances.append(dist)
assert len(distances) == num_edges
return distances
def get_geodesic_distance(self, montage_sensor1_idx, montage_sensor2_idx, coords_1010):
def get_geodesic_distance(
self, montage_sensor1_idx, montage_sensor2_idx, coords_1010
):
# get the reference sensor in the 10-10 system for the current montage pair in 10-20 system
ref_sensor1 = self.ref_names[montage_sensor1_idx]
@@ -96,7 +100,9 @@ class EEGGraphDataset(DGLDataset):
# https://math.stackexchange.com/questions/1304169/distance-between-two-points-on-a-sphere
r = 1 # since coords are on unit sphere
# rounding is for numerical stability, domain is [-1, 1]
dist = r * math.acos(round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r ** 2), 2))
dist = r * math.acos(
round(((x1 * x2) + (y1 * y2) + (z1 * z2)) / (r**2), 2)
)
return dist
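# For reference, a self-contained sketch of the same great-circle computation,
# assuming unit-sphere (x, y, z) coordinates; the dot product is clipped to
# [-1, 1] for numerical stability instead of being rounded.
import math

def great_circle_distance(p1, p2, r=1.0):
    # p1, p2: (x, y, z) points on a sphere of radius r
    dot = (p1[0] * p2[0] + p1[1] * p2[1] + p1[2] * p2[2]) / (r**2)
    dot = max(-1.0, min(1.0, dot))
    return r * math.acos(dot)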
# returns size of dataset = number of indices
......@@ -123,19 +129,23 @@ class EEGGraphDataset(DGLDataset):
edge_weights = torch.tensor(edge_weights)  # truncated to integer
# create 8-node complete graph
src = [[0 for i in range(self.num_nodes)] for j in range(self.num_nodes)]
src = [
[0 for i in range(self.num_nodes)] for j in range(self.num_nodes)
]
for i in range(len(src)):
for j in range(len(src[i])):
src[i][j] = i
src = np.array(src).flatten()
det = [[i for i in range(self.num_nodes)] for j in range(self.num_nodes)]
det = [
[i for i in range(self.num_nodes)] for j in range(self.num_nodes)
]
det = np.array(det).flatten()
u, v = (torch.tensor(src), torch.tensor(det))
g = dgl.graph((u, v))
# add node features and edge features
g.ndata['x'] = node_features
g.edata['edge_weights'] = edge_weights
g.ndata["x"] = node_features
g.edata["edge_weights"] = edge_weights
return g, torch.tensor(idx), torch.tensor(self.labels[idx])
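# The nested src/det loops above enumerate all 8 x 8 = 64 ordered node pairs
# (including self-loops). An equivalent, more compact construction, shown here
# only as a sketch with the same semantics:
from itertools import product
import torch
import dgl

num_nodes = 8
src, dst = zip(*product(range(num_nodes), repeat=2))  # 64 ordered pairs
g = dgl.graph((torch.tensor(src), torch.tensor(dst)))  # complete graph with self-loops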
import torch.nn as nn
import torch.nn.functional as function
from dgl.nn import GraphConv, SumPooling
from torch.nn import BatchNorm1d
from dgl.nn import GraphConv, SumPooling
class EEGGraphConvNet(nn.Module):
""" EEGGraph Convolution Net
Parameters
----------
num_feats: the number of features per node. In our case, it is 6.
"""EEGGraph Convolution Net
Parameters
----------
num_feats: the number of features per node. In our case, it is 6.
"""
def __init__(self, num_feats):
super(EEGGraphConvNet, self).__init__()
......@@ -17,7 +19,9 @@ class EEGGraphConvNet(nn.Module):
self.conv2 = GraphConv(16, 32)
self.conv3 = GraphConv(32, 64)
self.conv4 = GraphConv(64, 50)
self.conv4_bn = BatchNorm1d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.conv4_bn = BatchNorm1d(
50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
)
self.fc_block1 = nn.Linear(50, 30)
self.fc_block2 = nn.Linear(30, 10)
......@@ -31,8 +35,8 @@ class EEGGraphConvNet(nn.Module):
self.sumpool = SumPooling()
def forward(self, g, return_graph_embedding=False):
x = g.ndata['x']
edge_weight = g.edata['edge_weights']
x = g.ndata["x"]
edge_weight = g.edata["edge_weights"]
x = self.conv1(g, x, edge_weight=edge_weight)
x = function.leaky_relu(x, negative_slope=0.01)
......
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from joblib import load
from EEGGraphDataset import EEGGraphDataset
from dgl.dataloading import GraphDataLoader
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import balanced_accuracy_score
from joblib import load
from sklearn import preprocessing
from sklearn.metrics import balanced_accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.utils.data import WeightedRandomSampler
from dgl.dataloading import GraphDataLoader
if __name__ == "__main__":
# argparse commandline args
parser = argparse.ArgumentParser(description='Execute training pipeline on a given train/val subjects')
parser.add_argument('--num_feats', type=int, default=6, help='Number of features per node for the graph')
parser.add_argument('--num_nodes', type=int, default=8, help='Number of nodes in the graph')
parser.add_argument('--gpu_idx', type=int, default=0,
help='index of GPU device that should be used for this run, defaults to 0.')
parser.add_argument('--num_epochs', type=int, default=40, help='Number of epochs used to train')
parser.add_argument('--exp_name', type=str, default='default', help='Name for the test.')
parser.add_argument('--batch_size', type=int, default=512, help='Batch Size. Default is 512.')
parser.add_argument('--model', type=str, default='shallow',
help='type shallow to use shallow_EEGGraphDataset; '
'type deep to use deep_EEGGraphDataset. Default is shallow')
parser = argparse.ArgumentParser(
description="Execute training pipeline on a given train/val subjects"
)
parser.add_argument(
"--num_feats",
type=int,
default=6,
help="Number of features per node for the graph",
)
parser.add_argument(
"--num_nodes", type=int, default=8, help="Number of nodes in the graph"
)
parser.add_argument(
"--gpu_idx",
type=int,
default=0,
help="index of GPU device that should be used for this run, defaults to 0.",
)
parser.add_argument(
"--num_epochs",
type=int,
default=40,
help="Number of epochs used to train",
)
parser.add_argument(
"--exp_name", type=str, default="default", help="Name for the test."
)
parser.add_argument(
"--batch_size",
type=int,
default=512,
help="Batch Size. Default is 512.",
)
parser.add_argument(
"--model",
type=str,
default="shallow",
help="type shallow to use shallow_EEGGraphDataset; "
"type deep to use deep_EEGGraphDataset. Default is shallow",
)
args = parser.parse_args()
# choose model
if args.model == 'shallow':
if args.model == "shallow":
from shallow_EEGGraphConvNet import EEGGraphConvNet
if args.model == 'deep':
if args.model == "deep":
from deep_EEGGraphConvNet import EEGGraphConvNet
# set the random seed so that we can reproduce the results
......@@ -41,9 +70,11 @@ if __name__ == "__main__":
# use GPU when available
_GPU_IDX = args.gpu_idx
_DEVICE = torch.device(f'cuda:{_GPU_IDX}' if torch.cuda.is_available() else 'cpu')
_DEVICE = torch.device(
f"cuda:{_GPU_IDX}" if torch.cuda.is_available() else "cpu"
)
torch.cuda.set_device(_DEVICE)
print(f' Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}')
print(f" Using device: {_DEVICE} {torch.cuda.get_device_name(_DEVICE)}")
# load patient level indices
_DATASET_INDEX = pd.read_csv("master_metadata_index.csv")
......@@ -58,10 +89,10 @@ if __name__ == "__main__":
num_feats = args.num_feats
# set up input and targets from files
memmap_x = f'psd_features_data_X'
memmap_y = f'labels_y'
x = load(memmap_x, mmap_mode='r')
y = load(memmap_y, mmap_mode='r')
memmap_x = f"psd_features_data_X"
memmap_y = f"labels_y"
x = load(memmap_x, mmap_mode="r")
y = load(memmap_y, mmap_mode="r")
# normalize psd features data
normd_x = []
......@@ -79,22 +110,33 @@ if __name__ == "__main__":
print(f"Unique labels 0/1 mapping: {label_mapping}")
# split the subjects into train and held-out test sets; the test fraction is 0.3.
train_and_val_subjects, heldout_subjects = train_test_split(all_subjects, test_size=0.3, random_state=42)
train_and_val_subjects, heldout_subjects = train_test_split(
all_subjects, test_size=0.3, random_state=42
)
# split the dataset using patient indices
train_window_indices = _DATASET_INDEX.index[
_DATASET_INDEX["patient_ID"].astype("str").isin(train_and_val_subjects)].tolist()
_DATASET_INDEX["patient_ID"].astype("str").isin(train_and_val_subjects)
].tolist()
heldout_test_window_indices = _DATASET_INDEX.index[
_DATASET_INDEX["patient_ID"].astype("str").isin(heldout_subjects)].tolist()
_DATASET_INDEX["patient_ID"].astype("str").isin(heldout_subjects)
].tolist()
# define model, optimizer, scheduler
model = EEGGraphConvNet(num_feats)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[i * 10 for i in range(1, 26)], gamma=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, milestones=[i * 10 for i in range(1, 26)], gamma=0.1
)
model = model.to(_DEVICE).double()
num_trainable_params = np.sum([np.prod(p.size()) if p.requires_grad else 0 for p in model.parameters()])
num_trainable_params = np.sum(
[
np.prod(p.size()) if p.requires_grad else 0
for p in model.parameters()
]
)
# Dataloader========================================================================================================
......@@ -109,7 +151,8 @@ if __name__ == "__main__":
# the sampler must draw as many samples as there are training windows
weighted_sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(train_window_indices), replacement=True
num_samples=len(train_window_indices),
replacement=True,
)
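# The per-sample weights are computed in an elided hunk above; one common scheme,
# shown here only as a sketch, is inverse class frequency over the training windows
# (this assumes the window labels are already 0/1 encoded).
train_labels = np.asarray(y)[train_window_indices].astype(int)
class_counts = np.bincount(train_labels)  # [count_class_0, count_class_1]
sample_weights = torch.tensor(
    [1.0 / class_counts[lbl] for lbl in train_labels], dtype=torch.double
)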
# train data loader
......@@ -118,17 +161,20 @@ if __name__ == "__main__":
)
train_loader = GraphDataLoader(
dataset=train_dataset, batch_size=_BATCH_SIZE,
dataset=train_dataset,
batch_size=_BATCH_SIZE,
sampler=weighted_sampler,
num_workers=NUM_WORKERS,
pin_memory=True
pin_memory=True,
)
# this loader is used without weighted sampling, to evaluate metrics on full training set after each epoch
train_metrics_loader = GraphDataLoader(
dataset=train_dataset, batch_size=_BATCH_SIZE,
shuffle=False, num_workers=NUM_WORKERS,
pin_memory=True
dataset=train_dataset,
batch_size=_BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=True,
)
# test data loader
......@@ -137,9 +183,11 @@ if __name__ == "__main__":
)
test_loader = GraphDataLoader(
dataset=test_dataset, batch_size=_BATCH_SIZE,
shuffle=False, num_workers=NUM_WORKERS,
pin_memory=True
dataset=test_dataset,
batch_size=_BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
pin_memory=True,
)
auroc_train_history = []
......@@ -173,7 +221,7 @@ if __name__ == "__main__":
# update learning rate
scheduler.step()
# evaluate model after each epoch for train-metric data============================================================
# evaluate model after each epoch for train-metric data============================================================
model.eval()
with torch.no_grad():
y_probs_train = torch.empty(0, 2).to(_DEVICE)
......@@ -194,10 +242,12 @@ if __name__ == "__main__":
y_true_train += y_batch.cpu().numpy().tolist()
# return a probability distribution over target classes by taking softmax over dimension 1
y_probs_train = nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()
y_probs_train = (
nn.functional.softmax(y_probs_train, dim=1).cpu().numpy()
)
y_true_train = np.array(y_true_train)
# evaluate model after each epoch for validation data ==============================================================
# evaluate model after each epoch for validation data ==============================================================
y_probs_test = torch.empty(0, 2).to(_DEVICE)
y_true_test, minibatch_loss, y_pred_test = [], [], []
......@@ -217,32 +267,54 @@ if __name__ == "__main__":
y_true_test += y_batch.cpu().numpy().tolist()
# return a probability distribution over target classes by taking softmax over dimension 1
y_probs_test = torch.nn.functional.softmax(y_probs_test, dim=1).cpu().numpy()
y_probs_test = (
torch.nn.functional.softmax(y_probs_test, dim=1).cpu().numpy()
)
y_true_test = np.array(y_true_test)
# record training auroc and testing auroc
auroc_train_history.append(roc_auc_score(y_true_train, y_probs_train[:, 1]))
auroc_test_history.append(roc_auc_score(y_true_test, y_probs_test[:, 1]))
auroc_train_history.append(
roc_auc_score(y_true_train, y_probs_train[:, 1])
)
auroc_test_history.append(
roc_auc_score(y_true_test, y_probs_test[:, 1])
)
# record training balanced accuracy and testing balanced accuracy
balACC_train_history.append(balanced_accuracy_score(y_true_train, y_pred_train))
balACC_test_history.append(balanced_accuracy_score(y_true_test, y_pred_test))
balACC_train_history.append(
balanced_accuracy_score(y_true_train, y_pred_train)
)
balACC_test_history.append(
balanced_accuracy_score(y_true_test, y_pred_test)
)
# LOSS - epoch loss is defined as mean of minibatch losses within epoch
loss_train_history.append(np.mean(train_loss))
loss_test_history.append(np.mean(minibatch_loss))
# print the metrics
print("Train loss: {}, test loss: {}".format(loss_train_history[-1], loss_test_history[-1]))
print("Train AUC: {}, test AUC: {}".format(auroc_train_history[-1], auroc_test_history[-1]))
print("Train Bal.ACC: {}, test Bal.ACC: {}".format(balACC_train_history[-1], balACC_test_history[-1]))
print(
"Train loss: {}, test loss: {}".format(
loss_train_history[-1], loss_test_history[-1]
)
)
print(
"Train AUC: {}, test AUC: {}".format(
auroc_train_history[-1], auroc_test_history[-1]
)
)
print(
"Train Bal.ACC: {}, test Bal.ACC: {}".format(
balACC_train_history[-1], balACC_test_history[-1]
)
)
# save model from each epoch====================================================================================
state = {
'epochs': _NUM_EPOCHS,
'experiment_name': _EXPERIMENT_NAME,
'model_description': str(model),
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict()
"epochs": _NUM_EPOCHS,
"experiment_name": _EXPERIMENT_NAME,
"model_description": str(model),
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(state, f"{_EXPERIMENT_NAME}_Epoch_{epoch}.ckpt")
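# A sketch of restoring one of these checkpoints later, assuming the same
# EEGGraphConvNet definition and num_feats; the dict keys match the state saved above.
checkpoint = torch.load(f"{_EXPERIMENT_NAME}_Epoch_{epoch}.ckpt", map_location=_DEVICE)
restored_model = EEGGraphConvNet(num_feats).to(_DEVICE).double()
restored_model.load_state_dict(checkpoint["state_dict"])
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.01)
restored_optimizer.load_state_dict(checkpoint["optimizer"])
restored_model.eval()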
import torch.nn as nn
import torch.nn.functional as function
from dgl.nn import GraphConv, SumPooling
class EEGGraphConvNet(nn.Module):
""" EEGGraph Convolution Net
Parameters
----------
num_feats: the number of features per node. In our case, it is 6.
"""EEGGraph Convolution Net
Parameters
----------
num_feats: the number of features per node. In our case, it is 6.
"""
def __init__(self, num_feats):
super(EEGGraphConvNet, self).__init__()
self.conv1 = GraphConv(num_feats, 32)
self.conv2 = GraphConv(32, 20)
self.conv2_bn = nn.BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
self.conv2_bn = nn.BatchNorm1d(
20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
)
self.fc_block1 = nn.Linear(20, 10)
self.fc_block2 = nn.Linear(10, 2)
......@@ -23,11 +27,13 @@ class EEGGraphConvNet(nn.Module):
self.fc_block2.apply(lambda x: nn.init.xavier_normal_(x.weight, gain=1))
def forward(self, g, return_graph_embedding=False):
x = g.ndata['x']
edge_weight = g.edata['edge_weights']
x = g.ndata["x"]
edge_weight = g.edata["edge_weights"]
x = function.leaky_relu(self.conv1(g, x, edge_weight=edge_weight))
x = function.leaky_relu(self.conv2_bn(self.conv2(g, x, edge_weight=edge_weight)))
x = function.leaky_relu(
self.conv2_bn(self.conv2(g, x, edge_weight=edge_weight))
)
# NOTE: this takes node-level features/"embeddings"
# and aggregates to graph-level - use for graph-level classification
......
import dgl
import torch as th
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn import metrics
import utils
from model import EGES
from sampler import Sampler
from sklearn import metrics
from torch.utils.data import DataLoader
import dgl
def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):
sampler = Sampler(
train_g,
args.walk_length,
args.num_walks,
args.window_size,
args.num_negative
train_g,
args.walk_length,
args.num_walks,
args.window_size,
args.num_negative,
)
# for each node in the graph, we sample pos and neg
# pairs for it, and feed these sampled pairs into the model.
......@@ -25,7 +25,7 @@ def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):
# this is the batch_size of input nodes
batch_size=args.batch_size,
shuffle=True,
collate_fn=lambda x: sampler.sample(x, sku_info)
collate_fn=lambda x: sampler.sample(x, sku_info),
)
model = EGES(args.dim, num_skus, num_brands, num_shops, num_cates)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
......@@ -45,8 +45,11 @@ def train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates):
epoch_total_loss += loss.item()
if step % args.log_every == 0:
print('Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}'.format(
epoch, step, loss.item(), epoch_total_loss / (step + 1)))
print(
"Epoch {:05d} | Step {:05d} | Step Loss {:.4f} | Epoch Avg Loss: {:.4f}".format(
epoch, step, loss.item(), epoch_total_loss / (step + 1)
)
)
eval(model, test_g, sku_info)
......@@ -77,15 +80,14 @@ if __name__ == "__main__":
valid_sku_raw_ids = utils.get_valid_sku_set(args.item_info_data)
g, sku_encoder, sku_decoder = utils.construct_graph(
args.action_data,
args.session_interval_sec,
valid_sku_raw_ids
args.action_data, args.session_interval_sec, valid_sku_raw_ids
)
train_g, test_g = utils.split_train_test_graph(g)
sku_info_encoder, sku_info_decoder, sku_info = \
utils.encode_sku_fields(args.item_info_data, sku_encoder, sku_decoder)
sku_info_encoder, sku_info_decoder, sku_info = utils.encode_sku_fields(
args.item_info_data, sku_encoder, sku_decoder
)
num_skus = len(sku_encoder)
num_brands = len(sku_info_encoder["brand"])
......@@ -93,8 +95,11 @@ if __name__ == "__main__":
num_cates = len(sku_info_encoder["cate"])
print(
"Num skus: {}, num brands: {}, num shops: {}, num cates: {}".\
format(num_skus, num_brands, num_shops, num_cates)
"Num skus: {}, num brands: {}, num shops: {}, num cates: {}".format(
num_skus, num_brands, num_shops, num_cates
)
)
model = train(args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates)
model = train(
args, train_g, sku_info, num_skus, num_brands, num_shops, num_cates
)
......@@ -18,12 +18,12 @@ class EGES(th.nn.Module):
# srcs: sku_id, brand_id, shop_id, cate_id
srcs = self.query_node_embed(srcs)
dsts = self.query_node_embed(dsts)
return srcs, dsts
def query_node_embed(self, nodes):
"""
@nodes: tensor of shape (batch_size, num_side_info)
@nodes: tensor of shape (batch_size, num_side_info)
"""
batch_size = nodes.shape[0]
# query side info weights, (batch_size, 4)
......@@ -33,21 +33,31 @@ class EGES(th.nn.Module):
side_info_weights_sum = []
for i in range(4):
# weights for i-th side info, (batch_size, ) -> (batch_size, 1)
i_th_side_info_weights = side_info_weights[:, i].view((batch_size, 1))
i_th_side_info_weights = side_info_weights[:, i].view(
(batch_size, 1)
)
# batch of i-th side info embedding * its weight, (batch_size, dim)
side_info_weighted_embeds_sum.append(i_th_side_info_weights * self.embeds[i](nodes[:, i]))
side_info_weighted_embeds_sum.append(
i_th_side_info_weights * self.embeds[i](nodes[:, i])
)
side_info_weights_sum.append(i_th_side_info_weights)
# stack: (batch_size, 4, dim), sum: (batch_size, dim)
side_info_weighted_embeds_sum = th.sum(th.stack(side_info_weighted_embeds_sum, axis=1), axis=1)
side_info_weighted_embeds_sum = th.sum(
th.stack(side_info_weighted_embeds_sum, axis=1), axis=1
)
# stack: (batch_size, 4), sum: (batch_size, )
side_info_weights_sum = th.sum(th.stack(side_info_weights_sum, axis=1), axis=1)
side_info_weights_sum = th.sum(
th.stack(side_info_weights_sum, axis=1), axis=1
)
# (batch_size, dim)
H = side_info_weighted_embeds_sum / side_info_weights_sum
return H
return H
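# The weighted average above can also be written as one batched contraction, a
# sketch assuming the weights and embeddings are first stacked into tensors of
# shape (batch_size, 4) and (batch_size, 4, dim):
stacked_embeds = th.stack(
    [self.embeds[i](nodes[:, i]) for i in range(4)], dim=1
)  # (batch_size, 4, dim)
H = th.einsum("bk,bkd->bd", side_info_weights, stacked_embeds) / side_info_weights.sum(
    dim=1, keepdim=True
)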
def loss(self, srcs, dsts, labels):
dots = th.sigmoid(th.sum(srcs * dsts, axis=1))
dots = th.clamp(dots, min=1e-7, max=1 - 1e-7)
return th.mean(- (labels * th.log(dots) + (1 - labels) * th.log(1 - dots)))
return th.mean(
-(labels * th.log(dots) + (1 - labels) * th.log(1 - dots))
)
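# The clamp-and-log expression above is binary cross-entropy on sigmoid outputs;
# an equivalent and numerically more stable sketch (with a hypothetical helper
# name) works on the logits directly.
def loss_with_logits(self, srcs, dsts, labels):
    logits = th.sum(srcs * dsts, axis=1)
    return th.nn.functional.binary_cross_entropy_with_logits(logits, labels)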