Unverified commit dce89919, authored by Hongzhi (Steve) Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
@@ -17,6 +17,8 @@ Tree-LSTM in DGL
"""
import os

##############################################################################
#
# In this tutorial, you learn to use Tree-LSTM networks for sentiment analysis.
@@ -58,30 +60,32 @@ Tree-LSTM in DGL
from collections import namedtuple

os.environ["DGLBACKEND"] = "pytorch"
import dgl
from dgl.data.tree import SSTDataset

SSTBatch = namedtuple("SSTBatch", ["graph", "mask", "wordid", "label"])

# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode="tiny")  # the "tiny" set has only five trees
tiny_sst = [tr for tr in trainset]
num_vocabs = trainset.vocab_size
num_classes = trainset.num_classes

vocab = trainset.vocab  # vocabulary dict: key -> id
inv_vocab = {
    v: k for k, v in vocab.items()
}  # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata["x"].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")
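
# A quick structural check (a sketch using only the fields shown above):
# leaf nodes carry real word ids while internal nodes carry PAD_WORD, so
# the two counts partition the tree's nodes.
tokens = a_tree.ndata["x"].tolist()
num_leaves = sum(t != trainset.PAD_WORD for t in tokens)
num_internal = sum(t == trainset.PAD_WORD for t in tokens)
assert num_leaves + num_internal == a_tree.number_of_nodes()
print("\n{} leaves, {} internal nodes".format(num_leaves, num_internal))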
import matplotlib.pyplot as plt

##############################################################################
# Step 1: Batching
@@ -92,16 +96,24 @@ for token in a_tree.ndata['x'].tolist():
#
import networkx as nx

graph = dgl.batch(tiny_sst)


def plot_tree(g):
    # this plot requires pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog="dot")
    nx.draw(
        g,
        pos,
        with_labels=False,
        node_size=10,
        node_color=[[0.5, 0.5, 0.5]],
        arrowsize=4,
    )
    plt.show()


plot_tree(graph.to_networkx())
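
# Sanity sketch: dgl.batch merges the five trees into a single graph whose
# node and edge counts are the sums over the individual trees.
assert graph.number_of_nodes() == sum(t.number_of_nodes() for t in tiny_sst)
assert graph.number_of_edges() == sum(t.number_of_edges() for t in tiny_sst)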
#################################################################################
@@ -173,6 +185,7 @@ plot_tree(graph.to_networkx())
import torch as th
import torch.nn as nn


class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
@@ -182,27 +195,28 @@ class TreeLSTMCell(nn.Module):
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox["c"], 1)
        return {"iou": self.U_iou(h_cat), "c": c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data["iou"] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data["c"]
        # equation (6)
        h = o * th.tanh(c)
        return {"h": h, "c": c}
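
# A shape sketch for the cell above (hypothetical sizes; assumes each node
# has two children, as in a binarized constituency tree): U_f yields one
# forget gate per child, U_iou the stacked i/o/u pre-activations.
_cell = TreeLSTMCell(x_size=8, h_size=8)
_h_cat = th.randn(3, 2 * 8)  # 3 nodes, 2 children each with h_size = 8
print(_cell.U_f(_h_cat).shape)  # -> torch.Size([3, 16]), 2 * h_size
print(_cell.U_iou(_h_cat).shape)  # -> torch.Size([3, 24]), 3 * h_size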
##############################################################################
# Step 3: Define traversal
@@ -228,12 +242,12 @@ class TreeLSTMCell(nn.Module):
# to heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print("Traversing one tree:")
print(dgl.topological_nodes_generator(trv_a_tree))

# to heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print("Traversing many trees at the same time:")
print(dgl.topological_nodes_generator(trv_graph))
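
# Sanity sketch: the frontiers partition the nodes, so their sizes must sum
# to the total node count of the batched graph.
frontiers = list(dgl.topological_nodes_generator(trv_graph))
assert sum(len(f) for f in frontiers) == trv_graph.number_of_nodes()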
##############################################################################
@@ -242,11 +256,13 @@ print(dgl.topological_nodes_generator(trv_graph))
import dgl.function as fn
import torch as th

trv_graph.ndata["a"] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(
    traversal_order,
    message_func=fn.copy_u("a", "a"),
    reduce_func=fn.sum("a", "a"),
)

# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph)
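
# After the propagation above, each node's "a" holds the number of leaves
# in its subtree (a sketch): leaves keep their initial 1, and every reduce
# step overwrites a parent's "a" with the sum over its children.
roots = list(dgl.topological_nodes_generator(trv_graph))[-1]
print(trv_graph.ndata["a"][roots])  # one leaf count per tree in the batch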
@@ -265,19 +281,22 @@ trv_graph.prop_nodes(traversal_order,
# Here is the complete code that specifies the ``Tree-LSTM`` class.
#


class TreeLSTM(nn.Module):
    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        pretrained_emb=None,
    ):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
@@ -306,19 +325,26 @@ class TreeLSTM(nn.Module):
        g = dgl.graph(g.edges())
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata["iou"] = self.cell.W_iou(
            self.dropout(embeds)
        ) * batch.mask.float().unsqueeze(-1)
        g.ndata["h"] = h
        g.ndata["c"] = c
        # propagate
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )
        # compute logits
        h = self.dropout(g.ndata.pop("h"))
        logits = self.linear(h)
        return logits
import torch.nn.functional as F
##############################################################################
# Main Loop
# ---------
@@ -327,9 +353,8 @@ class TreeLSTM(nn.Module):
#
from torch.utils.data import DataLoader

device = th.device("cpu")
# hyper parameters
x_size = 256
h_size = 256
@@ -339,32 +364,37 @@ weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(
    trainset.vocab_size, x_size, h_size, trainset.num_classes, dropout
)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(
    model.parameters(), lr=lr, weight_decay=weight_decay
)


def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].to(device),
            wordid=batch_trees.ndata["x"].to(device),
            label=batch_trees.ndata["y"].to(device),
        )

    return batcher_dev


train_loader = DataLoader(
    dataset=tiny_sst,
    batch_size=5,
    collate_fn=batcher(device),
    shuffle=False,
    num_workers=0,
)

# training loop
for epoch in range(epochs):
@@ -375,15 +405,17 @@ for epoch in range(epochs):
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction="sum")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print(
            "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
                epoch, step, loss.item(), acc
            )
        )
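
# A hedged inference sketch: score the tiny set once more in eval mode,
# reusing the same zero-initialized states as in the training loop.
model.eval()
with th.no_grad():
    for batch in train_loader:
        n = batch.graph.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        pred = th.argmax(model(batch, h, c), 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print("Eval accuracy {:.4f}".format(acc))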
##############################################################################
# To train the model on a full dataset with different settings (such as CPU or GPU),
# refer to the `PyTorch example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm>`__.
...
@@ -51,7 +51,8 @@ Generative Models of Graphs
#
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl

g = dgl.DGLGraph()
@@ -116,6 +117,7 @@ g.add_edges([2, 0], [0, 2])  # Add edges (2, 0), (0, 2)
# with DGMG is implemented in DGL.
#


def forward_inference(self):
    stop = self.add_node_and_update()
    while (not stop) and (self.g.number_of_nodes() < self.v_max + 1):
@@ -126,9 +128,9 @@ def forward_inference(self):
            num_trials += 1
            to_add_edge = self.add_edge_or_not()
        stop = self.add_node_and_update()
    return self.g
#######################################################################################
# Assume you have a pre-trained model for generating cycles of 10-20 nodes.
# How does it generate a cycle on-the-fly during inference? Use the code below
@@ -203,6 +205,7 @@ def forward_inference(self):
# -\log p(a_{1},\cdots,a_{T})=-\sum_{t=1}^{T}\log p(a_{t}|a_{1},\cdots, a_{t-1}).\\
#


def forward_train(self, actions):
    """
    - actions: list
@@ -225,9 +228,9 @@ def forward_train(self, actions):
            self.choose_dest_and_update(a=actions[self.action_step])
            to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
        stop = self.add_node_and_update(a=actions[self.action_step])
    return self.get_log_prob()
#######################################################################################
# The key difference between ``forward_train`` and ``forward_inference`` is
# that the training process takes oracle actions as input and returns log
@@ -295,6 +298,7 @@ class DGMGSkeleton(nn.Module):
        else:
            return self.forward_inference()


#######################################################################################
# Encoding a dynamic graph
# ``````````````````````````
@@ -338,20 +342,20 @@ class GraphEmbed(nn.Module):
        # Embed graphs
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1), nn.Sigmoid()
        )
        self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)

    def forward(self, g):
        if g.number_of_nodes() == 0:
            return torch.zeros(1, self.graph_hidden_size)
        else:
            # Node features are stored as hv in ndata.
            hvs = g.ndata["hv"]
            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(
                0, keepdim=True
            )
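
# A usage sketch of the gated readout (assumes the constructor signature
# GraphEmbed(node_hidden_size), matching its use in DGMG.__init__ below;
# the "hv" feature size must equal node_hidden_size):
_embed = GraphEmbed(node_hidden_size=4)
_eg = dgl.DGLGraph()
_eg.add_nodes(3)
_eg.ndata["hv"] = torch.randn(3, 4)
print(_embed(_eg).shape)  # -> (1, _embed.graph_hidden_size)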
#######################################################################################
# Update node embeddings via graph propagation
@@ -395,6 +399,7 @@ class GraphEmbed(nn.Module):
from functools import partial


class GraphProp(nn.Module):
    def __init__(self, num_prop_rounds, node_hidden_size):
        super(GraphProp, self).__init__()
@@ -410,39 +415,43 @@ class GraphProp(nn.Module):
        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(
                nn.Linear(
                    2 * node_hidden_size + 1, self.node_activation_hidden_size
                )
            )

            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)
            )

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """For an edge u->v, return concat([h_u, x_uv])"""
        return {"m": torch.cat([edges.src["hv"], edges.data["he"]], dim=1)}

    def dgmg_reduce(self, nodes, round):
        hv_old = nodes.data["hv"]
        m = nodes.mailbox["m"]
        message = torch.cat(
            [hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2
        )
        node_activation = (self.message_funcs[round](message)).sum(1)
        return {"a": node_activation}

    def forward(self, g):
        if g.number_of_edges() > 0:
            for t in range(self.num_prop_rounds):
                g.update_all(
                    message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]
                )
                g.ndata["hv"] = self.node_update_funcs[t](
                    g.ndata["a"], g.ndata["hv"]
                )
#######################################################################################
# Actions
@@ -475,6 +484,7 @@ class GraphProp(nn.Module):
import torch.nn.functional as F
from torch.distributions import Bernoulli


def bernoulli_action_log_prob(logit, action):
    """Calculate the log p of an action with respect to a Bernoulli
    distribution. Use logit rather than prob for numerical stability."""
@@ -483,20 +493,22 @@ def bernoulli_action_log_prob(logit, action):
    else:
        return F.logsigmoid(logit)
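
# Sanity sketch: the helper should agree with torch.distributions.Bernoulli
# parameterized by the same logit.
_logit = torch.tensor([[0.3]])
for _action in (0, 1):
    _ref = Bernoulli(logits=_logit).log_prob(torch.tensor([[float(_action)]]))
    print(bernoulli_action_log_prob(_logit, _action), _ref)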


class AddNode(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddNode, self).__init__()

        self.graph_op = {"embed": graph_embed_func}

        self.stop = 1
        self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)

        # If a node is to be added, initialize its hv
        self.node_type_embed = nn.Embedding(1, node_hidden_size)
        self.initialize_hv = nn.Linear(
            node_hidden_size + graph_embed_func.graph_hidden_size,
            node_hidden_size,
        )

        self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
@@ -504,17 +516,22 @@ class AddNode(nn.Module):
        """Whenever a node is added, initialize its representation."""
        num_nodes = g.number_of_nodes()
        hv_init = self.initialize_hv(
            torch.cat(
                [
                    self.node_type_embed(torch.LongTensor([node_type])),
                    graph_embed,
                ],
                dim=1,
            )
        )
        g.nodes[num_nodes - 1].data["hv"] = hv_init
        g.nodes[num_nodes - 1].data["a"] = self.init_node_activation

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, action=None):
        graph_embed = self.graph_op["embed"](g)

        logit = self.add_node(graph_embed)
        prob = torch.sigmoid(logit)
@@ -526,14 +543,13 @@ class AddNode(nn.Module):
        if not stop:
            g.add_nodes(1)
            self._initialize_node_repr(g, action, graph_embed)

        if self.training:
            sample_log_prob = bernoulli_action_log_prob(logit, action)
            self.log_prob.append(sample_log_prob)

        return stop
#######################################################################################
# Action 2: Add edges
# ''''''''''''''''''''''''''
@@ -550,23 +566,24 @@ class AddNode(nn.Module):
# whether to add a new edge starting from :math:`v`.
#


class AddEdge(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddEdge, self).__init__()

        self.graph_op = {"embed": graph_embed_func}
        self.add_edge = nn.Linear(
            graph_embed_func.graph_hidden_size + node_hidden_size, 1
        )

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, action=None):
        graph_embed = self.graph_op["embed"](g)
        src_embed = g.nodes[g.number_of_nodes() - 1].data["hv"]

        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
        prob = torch.sigmoid(logit)

        if self.training:
@@ -574,10 +591,10 @@ class AddEdge(nn.Module):
            self.log_prob.append(sample_log_prob)
        else:
            action = Bernoulli(prob).sample().item()

        to_add_edge = bool(action == 0)
        return to_add_edge
#######################################################################################
# Action 3: Choose a destination
# '''''''''''''''''''''''''''''''''
@@ -595,11 +612,12 @@ class AddEdge(nn.Module):
from torch.distributions import Categorical


class ChooseDestAndUpdate(nn.Module):
    def __init__(self, graph_prop_func, node_hidden_size):
        super(ChooseDestAndUpdate, self).__init__()

        self.graph_op = {"prop": graph_prop_func}
        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

    def _initialize_edge_repr(self, g, src_list, dest_list):
@@ -607,7 +625,7 @@ class ChooseDestAndUpdate(nn.Module):
        # For multiple edge types, use a one-hot representation
        # or an embedding module.
        edge_repr = torch.ones(len(src_list), 1)
        g.edges[src_list, dest_list].data["he"] = edge_repr

    def prepare_training(self):
        self.log_prob = []
@@ -616,17 +634,16 @@ class ChooseDestAndUpdate(nn.Module):
        src = g.number_of_nodes() - 1
        possible_dests = range(src)

        src_embed_expand = g.nodes[src].data["hv"].expand(src, -1)
        possible_dests_embed = g.nodes[possible_dests].data["hv"]

        dests_scores = self.choose_dest(
            torch.cat([possible_dests_embed, src_embed_expand], dim=1)
        ).view(1, -1)
        dests_probs = F.softmax(dests_scores, dim=1)

        if not self.training:
            dest = Categorical(dests_probs).sample().item()

        if not g.has_edges_between(src, dest):
            # For undirected graphs, add edges for both directions
            # so that you can perform graph propagation.
@@ -636,12 +653,13 @@ class ChooseDestAndUpdate(nn.Module):
            g.add_edges(src_list, dest_list)
            self._initialize_edge_repr(g, src_list, dest_list)

            self.graph_op["prop"](g)

        if self.training:
            if dests_probs.nelement() > 1:
                self.log_prob.append(
                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]
                )
#######################################################################################
# Putting it together
@@ -650,25 +668,23 @@ class ChooseDestAndUpdate(nn.Module):
# You are now ready to have a complete implementation of the model class.
#


class DGMG(DGMGSkeleton):
    def __init__(self, v_max, node_hidden_size, num_prop_rounds):
        super(DGMG, self).__init__(v_max)

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)

        # Actions
        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
        self.choose_dest_agent = ChooseDestAndUpdate(
            self.graph_prop, node_hidden_size
        )

        # Forward functions
        self.forward_train = partial(forward_train, self=self)
@@ -711,6 +727,7 @@ class DGMG(DGMGSkeleton):
        choose_dest_log_p = torch.cat(self.choose_dest_agent.log_prob).sum()
        return add_node_log_p + add_edge_log_p + choose_dest_log_p
#######################################################################################
# Below is an animation where a graph is generated on the fly
# after every 10 batches of training for the first 400 batches. You
@@ -725,11 +742,14 @@ class DGMG(DGMGSkeleton):
import torch.utils.model_zoo as model_zoo

# Download a pre-trained model state dict for generating cycles with 10-20 nodes.
state_dict = model_zoo.load_url(
    "https://data.dgl.ai/model/dgmg_cycles-5a0c40be.pth"
)
model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
model.load_state_dict(state_dict)
model.eval()
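
# A generation sketch: called with no arguments in eval mode, DGMG's forward
# runs forward_inference and returns the generated graph.
g_sample = model()
print(g_sample.number_of_nodes(), g_sample.number_of_edges())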


def is_valid(g):
    # Check if g is a cycle having 10-20 nodes.
    def _get_previous(i, v_max):
@@ -748,28 +768,24 @@ def is_valid(g):
    if size < 10 or size > 20:
        return False
    for node in range(size):
        neighbors = g.successors(node)
        if len(neighbors) != 2:
            return False
        if _get_previous(node, size - 1) not in neighbors:
            return False
        if _get_next(node, size - 1) not in neighbors:
            return False
    return True


num_valid = 0
for i in range(100):
    g = model()
    num_valid += is_valid(g)

del model
print("Among 100 graphs generated, {}% are valid.".format(num_valid))
#######################################################################################
# For the complete implementation, see the `DGL DGMG example
...
@@ -68,15 +68,15 @@ offers a different perspective. The tutorial describes how to implement a Capsul
# Here's how we set up the graph and initialize node and edge features.
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


def init_graph(in_nodes, out_nodes, f_size):
    u = np.repeat(np.arange(in_nodes), out_nodes)
@@ -186,7 +186,6 @@ for i in range(10):
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())
stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
...
@@ -71,7 +71,8 @@ process ID, which should be an integer from `0` to `world_size - 1`.
"""
import os

os.environ["DGLBACKEND"] = "pytorch"
import torch.distributed as dist
...