Unverified commit dce89919 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
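For anyone skimming the diff below: the edits are mechanical style normalizations of the kind black produces (an assumption based on the patterns; the commit message only names "auto-reformat" and "lintrunner"). A minimal, self-contained before/after sketch built from lines in the first file:

from collections import namedtuple

vocab = {"good": 0, "bad": 1}  # toy stand-in for trainset.vocab

# Before the reformat: single quotes, one-line comprehension plus comment.
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
inv_vocab = {v: k for k, v in vocab.items()}  # inverted vocabulary dict: id -> word

# After the reformat: double quotes everywhere; the comprehension is
# wrapped because the line plus its comment exceeded the length limit.
SSTBatch = namedtuple("SSTBatch", ["graph", "mask", "wordid", "label"])
inv_vocab = {
    v: k for k, v in vocab.items()
}  # inverted vocabulary dict: id -> word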
@@ -17,6 +17,8 @@ Tree-LSTM in DGL
"""
import os
##############################################################################
#
# In this tutorial, you learn to use Tree-LSTM networks for sentiment analysis.
@@ -58,30 +60,32 @@ Tree-LSTM in DGL
from collections import namedtuple
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
from dgl.data.tree import SSTDataset
SSTBatch = namedtuple('SSTBatch', ['graph', 'mask', 'wordid', 'label'])
SSTBatch = namedtuple("SSTBatch", ["graph", "mask", "wordid", "label"])
# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode='tiny') # the "tiny" set has only five trees
trainset = SSTDataset(mode="tiny") # the "tiny" set has only five trees
tiny_sst = [tr for tr in trainset]
num_vocabs = trainset.vocab_size
num_classes = trainset.num_classes
vocab = trainset.vocab # vocabulary dict: key -> id
inv_vocab = {v: k for k, v in vocab.items()} # inverted vocabulary dict: id -> word
inv_vocab = {
v: k for k, v in vocab.items()
} # inverted vocabulary dict: id -> word
a_tree = tiny_sst[0]
for token in a_tree.ndata['x'].tolist():
for token in a_tree.ndata["x"].tolist():
if token != trainset.PAD_WORD:
print(inv_vocab[token], end=" ")
import matplotlib.pyplot as plt
##############################################################################
# Step 1: Batching
@@ -92,16 +96,24 @@ for token in a_tree.ndata['x'].tolist():
#
import networkx as nx
import matplotlib.pyplot as plt
graph = dgl.batch(tiny_sst)
def plot_tree(g):
# this plot requires pygraphviz package
pos = nx.nx_agraph.graphviz_layout(g, prog='dot')
nx.draw(g, pos, with_labels=False, node_size=10,
node_color=[[.5, .5, .5]], arrowsize=4)
pos = nx.nx_agraph.graphviz_layout(g, prog="dot")
nx.draw(
g,
pos,
with_labels=False,
node_size=10,
node_color=[[0.5, 0.5, 0.5]],
arrowsize=4,
)
plt.show()
plot_tree(graph.to_networkx())
#################################################################################
@@ -173,6 +185,7 @@ plot_tree(graph.to_networkx())
import torch as th
import torch.nn as nn
class TreeLSTMCell(nn.Module):
def __init__(self, x_size, h_size):
super(TreeLSTMCell, self).__init__()
@@ -182,27 +195,28 @@ class TreeLSTMCell(nn.Module):
self.U_f = nn.Linear(2 * h_size, 2 * h_size)
def message_func(self, edges):
return {'h': edges.src['h'], 'c': edges.src['c']}
return {"h": edges.src["h"], "c": edges.src["c"]}
def reduce_func(self, nodes):
# concatenate h_jl for equation (1), (2), (3), (4)
h_cat = nodes.mailbox['h'].view(nodes.mailbox['h'].size(0), -1)
h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
# equation (2)
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox['h'].size())
f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
# second term of equation (5)
c = th.sum(f * nodes.mailbox['c'], 1)
return {'iou': self.U_iou(h_cat), 'c': c}
c = th.sum(f * nodes.mailbox["c"], 1)
return {"iou": self.U_iou(h_cat), "c": c}
def apply_node_func(self, nodes):
# equation (1), (3), (4)
iou = nodes.data['iou'] + self.b_iou
iou = nodes.data["iou"] + self.b_iou
i, o, u = th.chunk(iou, 3, 1)
i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
# equation (5)
c = i * u + nodes.data['c']
c = i * u + nodes.data["c"]
# equation (6)
h = o * th.tanh(c)
return {'h' : h, 'c' : c}
return {"h": h, "c": c}
##############################################################################
# Step 3: Define traversal
@@ -228,12 +242,12 @@ class TreeLSTMCell(nn.Module):
# to heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print('Traversing one tree:')
print("Traversing one tree:")
print(dgl.topological_nodes_generator(trv_a_tree))
# to heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print('Traversing many trees at the same time:')
print("Traversing many trees at the same time:")
print(dgl.topological_nodes_generator(trv_graph))
##############################################################################
@@ -242,11 +256,13 @@ print(dgl.topological_nodes_generator(trv_graph))
import dgl.function as fn
import torch as th
trv_graph.ndata['a'] = th.ones(graph.number_of_nodes(), 1)
trv_graph.ndata["a"] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(traversal_order,
message_func=fn.copy_u('a', 'a'),
reduce_func=fn.sum('a', 'a'))
trv_graph.prop_nodes(
traversal_order,
message_func=fn.copy_u("a", "a"),
reduce_func=fn.sum("a", "a"),
)
# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph)
@@ -265,19 +281,22 @@ trv_graph.prop_nodes(traversal_order,
# Here is the complete code that specifies the ``Tree-LSTM`` class.
#
class TreeLSTM(nn.Module):
def __init__(self,
def __init__(
self,
num_vocabs,
x_size,
h_size,
num_classes,
dropout,
pretrained_emb=None):
pretrained_emb=None,
):
super(TreeLSTM, self).__init__()
self.x_size = x_size
self.embedding = nn.Embedding(num_vocabs, x_size)
if pretrained_emb is not None:
print('Using glove')
print("Using glove")
self.embedding.weight.data.copy_(pretrained_emb)
self.embedding.weight.requires_grad = True
self.dropout = nn.Dropout(dropout)
@@ -306,19 +325,26 @@ class TreeLSTM(nn.Module):
g = dgl.graph(g.edges())
# feed embedding
embeds = self.embedding(batch.wordid * batch.mask)
g.ndata['iou'] = self.cell.W_iou(self.dropout(embeds)) * batch.mask.float().unsqueeze(-1)
g.ndata['h'] = h
g.ndata['c'] = c
g.ndata["iou"] = self.cell.W_iou(
self.dropout(embeds)
) * batch.mask.float().unsqueeze(-1)
g.ndata["h"] = h
g.ndata["c"] = c
# propagate
dgl.prop_nodes_topo(g,
dgl.prop_nodes_topo(
g,
message_func=self.cell.message_func,
reduce_func=self.cell.reduce_func,
apply_node_func=self.cell.apply_node_func)
apply_node_func=self.cell.apply_node_func,
)
# compute logits
h = self.dropout(g.ndata.pop('h'))
h = self.dropout(g.ndata.pop("h"))
logits = self.linear(h)
return logits
import torch.nn.functional as F
##############################################################################
# Main Loop
# ---------
@@ -327,9 +353,8 @@ class TreeLSTM(nn.Module):
#
from torch.utils.data import DataLoader
import torch.nn.functional as F
device = th.device('cpu')
device = th.device("cpu")
# hyperparameters
x_size = 256
h_size = 256
@@ -339,32 +364,37 @@ weight_decay = 1e-4
epochs = 10
# create the model
model = TreeLSTM(trainset.vocab_size,
x_size,
h_size,
trainset.num_classes,
dropout)
model = TreeLSTM(
trainset.vocab_size, x_size, h_size, trainset.num_classes, dropout
)
print(model)
# create the optimizer
optimizer = th.optim.Adagrad(model.parameters(),
lr=lr,
weight_decay=weight_decay)
optimizer = th.optim.Adagrad(
model.parameters(), lr=lr, weight_decay=weight_decay
)
def batcher(dev):
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].to(device),
wordid=batch_trees.ndata['x'].to(device),
label=batch_trees.ndata['y'].to(device))
return SSTBatch(
graph=batch_trees,
mask=batch_trees.ndata["mask"].to(device),
wordid=batch_trees.ndata["x"].to(device),
label=batch_trees.ndata["y"].to(device),
)
return batcher_dev
train_loader = DataLoader(dataset=tiny_sst,
train_loader = DataLoader(
dataset=tiny_sst,
batch_size=5,
collate_fn=batcher(device),
shuffle=False,
num_workers=0)
num_workers=0,
)
# training loop
for epoch in range(epochs):
@@ -375,15 +405,17 @@ for epoch in range(epochs):
c = th.zeros((n, h_size))
logits = model(batch, h, c)
logp = F.log_softmax(logits, 1)
loss = F.nll_loss(logp, batch.label, reduction='sum')
loss = F.nll_loss(logp, batch.label, reduction="sum")
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred = th.argmax(logits, 1)
acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
print("Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
epoch, step, loss.item(), acc))
print(
"Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
epoch, step, loss.item(), acc
)
)
##############################################################################
# To train the model on a full dataset with different settings (such as CPU or GPU),
# refer to the `PyTorch example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm>`__.
......
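A pattern that recurs throughout the file above: whenever the formatter explodes a call across lines, it leaves a trailing comma after the last argument. black documents this as the "magic trailing comma"; on later runs it keeps the call in the one-argument-per-line layout even when it would fit on one line again. A runnable sketch with a toy stand-in for nx.draw:

def draw(g, pos, with_labels=False, node_size=10, node_color=None, arrowsize=4):
    """Toy stand-in for nx.draw so the sketch runs without networkx."""
    return node_size + arrowsize

# Collapsed form: fits within the line limit, so no trailing comma is added.
draw(None, None, with_labels=False, node_size=10, arrowsize=4)

# Exploded form, as produced throughout the diff; the comma after
# arrowsize=4 pins this layout in place on any future formatter run.
draw(
    None,
    None,
    with_labels=False,
    node_size=10,
    node_color=[[0.5, 0.5, 0.5]],
    arrowsize=4,
)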
@@ -51,7 +51,8 @@ Generative Models of Graphs
#
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
g = dgl.DGLGraph()
@@ -116,6 +117,7 @@ g.add_edges([2, 0], [0, 2]) # Add edges (2, 0), (0, 2)
# with DGMG is implemented in DGL.
#
def forward_inference(self):
stop = self.add_node_and_update()
while (not stop) and (self.g.number_of_nodes() < self.v_max + 1):
@@ -126,9 +128,9 @@ def forward_inference(self):
num_trials += 1
to_add_edge = self.add_edge_or_not()
stop = self.add_node_and_update()
return self.g
#######################################################################################
# Assume you have a pre-trained model for generating cycles with 10-20 nodes.
# How does it generate a cycle on-the-fly during inference? Use the code below
@@ -203,6 +205,7 @@ def forward_inference(self):
# -\log p(a_{1},\cdots,a_{T})=-\sum_{t=1}^{T}\log p(a_{t}|a_{1},\cdots, a_{t-1}).\\
#
def forward_train(self, actions):
"""
- actions: list
@@ -225,9 +228,9 @@ def forward_train(self, actions):
self.choose_dest_and_update(a=actions[self.action_step])
to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
stop = self.add_node_and_update(a=actions[self.action_step])
return self.get_log_prob()
#######################################################################################
# The key difference between ``forward_train`` and ``forward_inference`` is
# that the training process takes oracle actions as input and returns log
@@ -295,6 +298,7 @@ class DGMGSkeleton(nn.Module):
else:
return self.forward_inference()
#######################################################################################
# Encoding a dynamic graph
# ``````````````````````````
@@ -338,20 +342,20 @@ class GraphEmbed(nn.Module):
# Embed graphs
self.node_gating = nn.Sequential(
nn.Linear(node_hidden_size, 1),
nn.Sigmoid()
nn.Linear(node_hidden_size, 1), nn.Sigmoid()
)
self.node_to_graph = nn.Linear(node_hidden_size,
self.graph_hidden_size)
self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)
def forward(self, g):
if g.number_of_nodes() == 0:
return torch.zeros(1, self.graph_hidden_size)
else:
# Node features are stored as hv in ndata.
hvs = g.ndata['hv']
return (self.node_gating(hvs) *
self.node_to_graph(hvs)).sum(0, keepdim=True)
hvs = g.ndata["hv"]
return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(
0, keepdim=True
)
#######################################################################################
# Update node embeddings via graph propagation
@@ -395,6 +399,7 @@ class GraphEmbed(nn.Module):
from functools import partial
class GraphProp(nn.Module):
def __init__(self, num_prop_rounds, node_hidden_size):
super(GraphProp, self).__init__()
@@ -410,39 +415,43 @@ class GraphProp(nn.Module):
for t in range(num_prop_rounds):
# input being [hv, hu, xuv]
message_funcs.append(nn.Linear(2 * node_hidden_size + 1,
self.node_activation_hidden_size))
message_funcs.append(
nn.Linear(
2 * node_hidden_size + 1, self.node_activation_hidden_size
)
)
self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
node_update_funcs.append(
nn.GRUCell(self.node_activation_hidden_size,
node_hidden_size))
nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)
)
self.message_funcs = nn.ModuleList(message_funcs)
self.node_update_funcs = nn.ModuleList(node_update_funcs)
def dgmg_msg(self, edges):
"""For an edge u->v, return concat([h_u, x_uv])"""
return {'m': torch.cat([edges.src['hv'],
edges.data['he']],
dim=1)}
return {"m": torch.cat([edges.src["hv"], edges.data["he"]], dim=1)}
def dgmg_reduce(self, nodes, round):
hv_old = nodes.data['hv']
m = nodes.mailbox['m']
message = torch.cat([
hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
hv_old = nodes.data["hv"]
m = nodes.mailbox["m"]
message = torch.cat(
[hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2
)
node_activation = (self.message_funcs[round](message)).sum(1)
return {'a': node_activation}
return {"a": node_activation}
def forward(self, g):
if g.number_of_edges() > 0:
for t in range(self.num_prop_rounds):
g.update_all(message_func=self.dgmg_msg,
reduce_func=self.reduce_funcs[t])
g.ndata['hv'] = self.node_update_funcs[t](
g.ndata['a'], g.ndata['hv'])
g.update_all(
message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]
)
g.ndata["hv"] = self.node_update_funcs[t](
g.ndata["a"], g.ndata["hv"]
)
#######################################################################################
# Actions
@@ -475,6 +484,7 @@ class GraphProp(nn.Module):
import torch.nn.functional as F
from torch.distributions import Bernoulli
def bernoulli_action_log_prob(logit, action):
"""Calculate the log p of an action with respect to a Bernoulli
distribution. Use logit rather than prob for numerical stability."""
@@ -483,20 +493,22 @@ def bernoulli_action_log_prob(logit, action):
else:
return F.logsigmoid(logit)
class AddNode(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddNode, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.graph_op = {"embed": graph_embed_func}
self.stop = 1
self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)
# If a node is to be added, initialize its hv
self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \
graph_embed_func.graph_hidden_size,
node_hidden_size)
self.initialize_hv = nn.Linear(
node_hidden_size + graph_embed_func.graph_hidden_size,
node_hidden_size,
)
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
@@ -504,17 +516,22 @@ class AddNode(nn.Module):
"""Whenver a node is added, initialize its representation."""
num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv(
torch.cat([
torch.cat(
[
self.node_type_embed(torch.LongTensor([node_type])),
graph_embed], dim=1))
g.nodes[num_nodes - 1].data['hv'] = hv_init
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation
graph_embed,
],
dim=1,
)
)
g.nodes[num_nodes - 1].data["hv"] = hv_init
g.nodes[num_nodes - 1].data["a"] = self.init_node_activation
def prepare_training(self):
self.log_prob = []
def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g)
graph_embed = self.graph_op["embed"](g)
logit = self.add_node(graph_embed)
prob = torch.sigmoid(logit)
@@ -526,14 +543,13 @@ class AddNode(nn.Module):
if not stop:
g.add_nodes(1)
self._initialize_node_repr(g, action, graph_embed)
if self.training:
sample_log_prob = bernoulli_action_log_prob(logit, action)
self.log_prob.append(sample_log_prob)
return stop
#######################################################################################
# Action 2: Add edges
# ''''''''''''''''''''''''''
@@ -550,23 +566,24 @@ class AddNode(nn.Module):
# whether to add a new edge starting from :math:`v`.
#
class AddEdge(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddEdge, self).__init__()
self.graph_op = {'embed': graph_embed_func}
self.add_edge = nn.Linear(graph_embed_func.graph_hidden_size + \
node_hidden_size, 1)
self.graph_op = {"embed": graph_embed_func}
self.add_edge = nn.Linear(
graph_embed_func.graph_hidden_size + node_hidden_size, 1
)
def prepare_training(self):
self.log_prob = []
def forward(self, g, action=None):
graph_embed = self.graph_op['embed'](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data['hv']
graph_embed = self.graph_op["embed"](g)
src_embed = g.nodes[g.number_of_nodes() - 1].data["hv"]
logit = self.add_edge(torch.cat(
[graph_embed, src_embed], dim=1))
logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
prob = torch.sigmoid(logit)
if self.training:
@@ -574,10 +591,10 @@ class AddEdge(nn.Module):
self.log_prob.append(sample_log_prob)
else:
action = Bernoulli(prob).sample().item()
to_add_edge = bool(action == 0)
return to_add_edge
#######################################################################################
# Action 3: Choose a destination
# '''''''''''''''''''''''''''''''''
@@ -595,11 +612,12 @@ class AddEdge(nn.Module):
from torch.distributions import Categorical
class ChooseDestAndUpdate(nn.Module):
def __init__(self, graph_prop_func, node_hidden_size):
super(ChooseDestAndUpdate, self).__init__()
self.graph_op = {'prop': graph_prop_func}
self.graph_op = {"prop": graph_prop_func}
self.choose_dest = nn.Linear(2 * node_hidden_size, 1)
def _initialize_edge_repr(self, g, src_list, dest_list):
@@ -607,7 +625,7 @@ class ChooseDestAndUpdate(nn.Module):
# For multiple edge types, use a one-hot representation
# or an embedding module.
edge_repr = torch.ones(len(src_list), 1)
g.edges[src_list, dest_list].data['he'] = edge_repr
g.edges[src_list, dest_list].data["he"] = edge_repr
def prepare_training(self):
self.log_prob = []
@@ -616,17 +634,16 @@ class ChooseDestAndUpdate(nn.Module):
src = g.number_of_nodes() - 1
possible_dests = range(src)
src_embed_expand = g.nodes[src].data['hv'].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data['hv']
src_embed_expand = g.nodes[src].data["hv"].expand(src, -1)
possible_dests_embed = g.nodes[possible_dests].data["hv"]
dests_scores = self.choose_dest(
torch.cat([possible_dests_embed,
src_embed_expand], dim=1)).view(1, -1)
torch.cat([possible_dests_embed, src_embed_expand], dim=1)
).view(1, -1)
dests_probs = F.softmax(dests_scores, dim=1)
if not self.training:
dest = Categorical(dests_probs).sample().item()
if not g.has_edges_between(src, dest):
# For undirected graphs, add edges for both directions
# so that you can perform graph propagation.
@@ -636,12 +653,13 @@ class ChooseDestAndUpdate(nn.Module):
g.add_edges(src_list, dest_list)
self._initialize_edge_repr(g, src_list, dest_list)
self.graph_op['prop'](g)
self.graph_op["prop"](g)
if self.training:
if dests_probs.nelement() > 1:
self.log_prob.append(
F.log_softmax(dests_scores, dim=1)[:, dest: dest + 1])
F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]
)
#######################################################################################
# Putting it together
@@ -650,25 +668,23 @@ class ChooseDestAndUpdate(nn.Module):
# You are now ready to have a complete implementation of the model class.
#
class DGMG(DGMGSkeleton):
def __init__(self, v_max, node_hidden_size,
num_prop_rounds):
def __init__(self, v_max, node_hidden_size, num_prop_rounds):
super(DGMG, self).__init__(v_max)
# Graph embedding module
self.graph_embed = GraphEmbed(node_hidden_size)
# Graph propagation module
self.graph_prop = GraphProp(num_prop_rounds,
node_hidden_size)
self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)
# Actions
self.add_node_agent = AddNode(
self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge(
self.graph_embed, node_hidden_size)
self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
self.choose_dest_agent = ChooseDestAndUpdate(
self.graph_prop, node_hidden_size)
self.graph_prop, node_hidden_size
)
# Forward functions
self.forward_train = partial(forward_train, self=self)
@@ -711,6 +727,7 @@ class DGMG(DGMGSkeleton):
choose_dest_log_p = torch.cat(self.choose_dest_agent.log_prob).sum()
return add_node_log_p + add_edge_log_p + choose_dest_log_p
#######################################################################################
# Below is an animation where a graph is generated on the fly
# after every 10 batches of training for the first 400 batches. You
@@ -725,11 +742,14 @@ class DGMG(DGMGSkeleton):
import torch.utils.model_zoo as model_zoo
# Download a pre-trained model state dict for generating cycles with 10-20 nodes.
state_dict = model_zoo.load_url('https://data.dgl.ai/model/dgmg_cycles-5a0c40be.pth')
state_dict = model_zoo.load_url(
"https://data.dgl.ai/model/dgmg_cycles-5a0c40be.pth"
)
model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
model.load_state_dict(state_dict)
model.eval()
def is_valid(g):
# Check if g is a cycle having 10-20 nodes.
def _get_previous(i, v_max):
@@ -748,28 +768,24 @@ def is_valid(g):
if size < 10 or size > 20:
return False
for node in range(size):
neighbors = g.successors(node)
if len(neighbors) != 2:
return False
if _get_previous(node, size - 1) not in neighbors:
return False
if _get_next(node, size - 1) not in neighbors:
return False
return True
num_valid = 0
for i in range(100):
g = model()
num_valid += is_valid(g)
del model
print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
print("Among 100 graphs generated, {}% are valid.".format(num_valid))
#######################################################################################
# For the complete implementation, see the `DGL DGMG example
......
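One subtle change in the ChooseDestAndUpdate hunk above is ``[:, dest: dest + 1]`` becoming ``[:, dest : dest + 1]``. PEP 8 treats the slice colon as a binary operator and asks for symmetric spacing when a bound is an expression rather than a bare name or number, and the formatter enforces that. The two spellings are semantically identical:

import torch as th  # the tutorials above already import torch as th

dests_scores = th.randn(1, 5)
dest = 2

# Simple bound: no spaces around the slice colon.
first_three = dests_scores[:, :3]

# Expression bound: symmetric spaces around the colon, same semantics.
kept = dests_scores[:, dest : dest + 1]
assert kept.shape == (1, 1)
assert th.equal(kept, dests_scores[:, dest:dest + 1])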
@@ -68,15 +68,15 @@ offers a different perspective. The tutorial describes how to implement a Capsul
# Here's how we set up the graph and initialize node and edge features.
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl
def init_graph(in_nodes, out_nodes, f_size):
u = np.repeat(np.arange(in_nodes), out_nodes)
@@ -186,7 +186,6 @@ for i in range(10):
entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
entropy_list.append(entropy.data.numpy())
dist_list.append(dist_matrix.data.numpy())
stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
......
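The capsule-network hunk above also reorders imports: note ``import dgl`` appearing twice, once deleted from its old position and once re-inserted in sorted order. This looks like isort-style grouping (an assumption; the diff does not name the tool): third-party imports alphabetized in one block, while the ``DGLBACKEND`` environment variable must still be set before ``dgl`` is imported, so that line stays above the block. The resulting preamble:

import os

# Setting the backend must precede the dgl import, so it stays up top.
os.environ["DGLBACKEND"] = "pytorch"

# Third-party imports, alphabetized (dgl < matplotlib < numpy < torch).
import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch as th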
@@ -71,7 +71,8 @@ process ID, which should be an integer from `0` to `world_size - 1`.
"""
import os
os.environ['DGLBACKEND'] = 'pytorch'
os.environ["DGLBACKEND"] = "pytorch"
import torch.distributed as dist
......
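Finally, to reproduce a reformat like this one locally, the commit message points at lintrunner, the lint-orchestration CLI used in the PyTorch ecosystem. A minimal sketch, assuming the DGL checkout ships a .lintrunner.toml that wires up the formatters (check the repository for the actual configuration):

# Run from the repository root. "init" installs the linters declared in
# .lintrunner.toml; "-a" applies their fixes in place instead of only
# reporting them. Which tools actually run is entirely config-dependent.
import subprocess

subprocess.run(["lintrunner", "init"], check=True)
subprocess.run(["lintrunner", "-a"], check=True)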