Unverified commit 4d3c01d6, authored by Mufei Li, committed by GitHub

[Bug Fix] Fix the case when reverse_edge is False for citation graphs (#3840)



* Update citation_graph.py

* Update

* Update

* Update
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
parent 71157b05
@@ -61,7 +61,7 @@ class CitationGraphDataset(DGLBuiltinDataset):
    }

    def __init__(self, name, raw_dir=None, force_reload=False,
                 verbose=True, reverse_edge=True, transform=None,
                 reorder=False):
        assert name.lower() in ['cora', 'citeseer', 'pubmed']
@@ -122,8 +122,12 @@ class CitationGraphDataset(DGLBuiltinDataset):
        if self.reverse_edge:
            graph = nx.DiGraph(nx.from_dict_of_lists(graph))
+           g = from_networkx(graph)
        else:
            graph = nx.Graph(nx.from_dict_of_lists(graph))
+           edges = list(graph.edges())
+           u, v = map(list, zip(*edges))
+           g = dgl_graph((u, v))
        onehot_labels = np.vstack((ally, ty))
        onehot_labels[test_idx_reorder, :] = onehot_labels[test_idx_range, :]
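A quick check of the fixed behavior (a usage sketch; this assumes ``CoraGraphDataset``, which subclasses this dataset, forwards ``reverse_edge`` to this base class, and the dataset downloads on first use):

    from dgl.data import CoraGraphDataset

    # With reverse_edge=True (the default), each citation link is stored in
    # both directions; with reverse_edge=False, only the original directed
    # edges are kept, so the two edge counts should differ.
    g_rev = CoraGraphDataset(reverse_edge=True)[0]
    g_dir = CoraGraphDataset(reverse_edge=False)[0]
    print(g_rev.num_edges(), g_dir.num_edges())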
@@ -137,9 +141,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
        val_mask = generate_mask_tensor(_sample_mask(idx_val, labels.shape[0]))
        test_mask = generate_mask_tensor(_sample_mask(idx_test, labels.shape[0]))
-       self._graph = graph
-       g = from_networkx(graph)
        g.ndata['train_mask'] = train_mask
        g.ndata['val_mask'] = val_mask
        g.ndata['test_mask'] = test_mask
@@ -204,7 +205,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
        graph.ndata.pop('feat')
        graph.ndata.pop('label')
        graph = to_networkx(graph)
-       self._graph = nx.DiGraph(graph)
        self._num_classes = info['num_classes']
        self._g.ndata['train_mask'] = generate_mask_tensor(F.asnumpy(self._g.ndata['train_mask']))
@@ -250,10 +250,6 @@ class CitationGraphDataset(DGLBuiltinDataset):
    """ Citation graph is used in many examples.
        We preserve these properties for compatibility.
    """
-   @property
-   def graph(self):
-       deprecate_property('dataset.graph', 'dataset[0]')
-       return self._graph

    @property
    def train_mask(self):
@@ -17,24 +17,24 @@ Line Graph Neural Network
"""
###########################################################################################
#
# In this tutorial, you learn how to solve community detection tasks by implementing a line
# graph neural network (LGNN). Community detection, or graph clustering, consists of partitioning
# the vertices in a graph into clusters in which nodes are more similar to
# one another than to nodes in other clusters.
#
# In the :doc:`Graph convolutional network tutorial <1_gcn>`, you learned how to classify the nodes of an input
# graph in a semi-supervised setting. You used a graph convolutional neural network (GCN)
# as an embedding mechanism for graph features.
#
# To generalize a graph neural network (GNN) to supervised community detection, a line-graph based
# variation of GNN is introduced in the research paper
# `Supervised Community Detection with Line Graph Neural Networks <https://arxiv.org/abs/1705.08415>`__.
# One of the highlights of the model is
# to augment the straightforward GNN architecture so that it operates on
# a line graph of edge adjacencies, defined with a non-backtracking operator.
#
# A line graph neural network (LGNN) shows how DGL can implement an advanced graph algorithm by
# mixing basic tensor operations, sparse-matrix multiplication, and message-
# passing APIs.
#
@@ -65,13 +65,13 @@ Line Graph Neural Network
#
# Cora dataset
# ~~~~~~~~~~~~
# To be consistent with the GCN tutorial,
# you use the `Cora dataset <https://linqs.soe.ucsc.edu/data>`__
# to illustrate a simple community detection task. Cora is a scientific publication dataset
# with 2708 papers belonging to seven
# different machine learning fields. Here, you formulate Cora as a
# directed graph, with each node being a paper and each edge being a
# citation link (A->B means A cites B). Here is a visualization of the whole
# Cora dataset.
#
# .. figure:: https://i.imgur.com/X404Byc.png
@@ -96,7 +96,7 @@ from dgl.data import citation_graph as citegrh

data = citegrh.load_cora()
-G = dgl.DGLGraph(data.graph)
+G = data[0]
labels = th.tensor(data.labels)

# find all the nodes labeled with class 0
@@ -113,7 +113,7 @@ print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels)))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Without loss of generality, in this tutorial you limit the scope of the
# task to binary community detection.
#
# .. note::
#
#    To create a practice binary-community dataset from Cora, first extract
@@ -177,7 +177,7 @@ visualize(label1, nx_G1)
# community assignment :math:`\{A, A, A, B\}`, with each node's label
# :math:`l \in \{0,1\}`, the group of all possible permutations is
# :math:`S_c = \{\{0,0,0,1\}, \{1,1,1,0\}\}`.
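#
# One way to train under this equivalence is to take the minimum cross-entropy
# over all permutations of the community labels. A minimal sketch (the helper
# name is illustrative, not from the paper's code):
#
# ::
#
#    import itertools
#    import torch as th
#    import torch.nn.functional as F
#
#    def permutation_invariant_loss(logits, labels, n_classes=2):
#        # Score every relabeling of the communities and keep the smallest
#        # loss, so {0,0,0,1} and {1,1,1,0} are penalized identically.
#        losses = [F.cross_entropy(logits, th.tensor(perm)[labels])
#                  for perm in itertools.permutations(range(n_classes))]
#        return th.stack(losses).min()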
#
# Line graph neural network key ideas
# ------------------------------------
# A key innovation in this topic is the use of a line graph.
@@ -193,7 +193,7 @@ visualize(label1, nx_G1)
# Specifically, a line-graph :math:`L(G)` turns an edge of the original graph `G`
# into a node. This is illustrated with the graph below (taken from the
# research paper).
#
# .. figure:: https://i.imgur.com/4WO5jEm.png
#    :alt: lg
#    :align: center
@@ -206,11 +206,11 @@ visualize(label1, nx_G1)
# connect two edges? Here, we use the following connection rule:
#
# Two nodes :math:`v^{l}_{A}`, :math:`v^{l}_{B}` in `lg` are connected if
# the corresponding two edges :math:`e_{A}, e_{B}` in `g` share one and only
# one node: :math:`e_{A}`'s destination node is :math:`e_{B}`'s source node
# (:math:`j`).
#
# .. note::
#
#    Mathematically, this definition corresponds to a notion called non-backtracking
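#
# DGL can build this non-backtracking line graph directly. A small sketch on a
# toy directed graph (the graph itself is illustrative):
#
# ::
#
#    import dgl
#    import torch as th
#
#    g = dgl.graph((th.tensor([0, 1, 1, 2]), th.tensor([1, 0, 2, 0])))
#    lg = g.line_graph(backtracking=False)
#    # One line-graph node per original edge; edge (i->j) connects to
#    # (j->k) only when k != i.
#    print(lg.num_nodes(), lg.num_edges())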
@@ -228,7 +228,7 @@ visualize(label1, nx_G1)
# LGNN chains together a series of line graph neural network layers. The graph
# representation :math:`x` and its line graph companion :math:`y` evolve with
# the dataflow as follows.
#
# .. figure:: https://i.imgur.com/bZGGIGp.png
#    :alt: alg
#    :align: center
@@ -265,9 +265,9 @@ visualize(label1, nx_G1)
#
# Implement LGNN in DGL
# ---------------------
# Even though the equations in the previous section might seem intimidating,
# it helps to understand the following points before you implement the LGNN.
#
# The two equations are symmetric and can be implemented as two instances
# of the same class with different parameters.
# The first equation operates on graph representation :math:`x`,
@@ -295,7 +295,7 @@ visualize(label1, nx_G1)
# Each of the terms is computed again with different
# parameters, and without the nonlinearity after the sum.
# Therefore, :math:`f` could be written as:
#
# .. math::
#    \begin{split}
#    f(x^{(k)},y^{(k)}) = {}\rho[&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) + \text{radius}(x^{(k-1)})
@@ -304,18 +304,18 @@ visualize(label1, nx_G1)
#    \end{split}
#
# The two equations are chained up in the following order:
#
# .. math::
#    \begin{split}
#    x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
#    y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
#    \end{split}
#
# Keep in mind the listed observations in this overview and proceed to the implementation.
# An important point is that different strategies are used for the noted terms.
#
# .. note::
#
#    You can understand :math:`\{Pm, Pd\}` more thoroughly with this explanation.
#    Roughly speaking, there is a relationship between how :math:`g` and
#    :math:`lg` (the line graph) work together with loopy belief propagation.
#    Here, you implement :math:`\{Pm, Pd\}` as a SciPy COO sparse matrix in the dataset,
@@ -329,21 +329,21 @@ visualize(label1, nx_G1)
# multiplication. Write them as PyTorch tensor operations.
#
# In ``__init__``, you define the projection variables.
#
# ::
#
#    self.linear_prev = nn.Linear(in_feats, out_feats)
#    self.linear_deg = nn.Linear(in_feats, out_feats)
#
#
# In ``forward()``, :math:`\text{prev}` and :math:`\text{deg}` are the same
# as any other PyTorch tensor operations.
#
# ::
#
#    prev_proj = self.linear_prev(feat_a)
#    deg_proj = self.linear_deg(deg * feat_a)
#
#
# Implementing :math:`\text{radius}` as message passing in DGL
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# As discussed in the GCN tutorial, you can formulate one adjacency operator as
@@ -355,14 +355,14 @@ visualize(label1, nx_G1)
#
# In ``__init__``, define the projection variables used in each of the
# :math:`2^j` steps of message passing.
#
# ::
#
#    self.linear_radius = nn.ModuleList(
#        [nn.Linear(in_feats, out_feats) for i in range(radius)])
#
# In ``forward()``, use the following function ``aggregate_radius()`` to
# gather data from multiple hops, as shown in the sketch below.
# Note that ``update_all`` is called multiple times, and a list of the
# features gathered at each radius is returned.
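##############################################################################
# The body of ``aggregate_radius()`` is elided from this diff. Below is a
# sketch consistent with the description above.

import dgl.function as fn

def aggregate_radius(radius, g, z):
    # Collect the aggregated features after each power-of-two number of hops.
    z_list = []
    g.ndata['z'] = z
    # Pull messages from the 1-hop neighborhood.
    g.update_all(fn.copy_u('z', 'm'), fn.sum('m', 'z'))
    z_list.append(g.ndata['z'])
    for i in range(radius - 1):
        for j in range(2 ** i):
            # Pull messages from the 2^j-hop neighborhood.
            g.update_all(fn.copy_u('z', 'm'), fn.sum('m', 'z'))
        z_list.append(g.ndata['z'])
    return z_list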
@@ -389,32 +389,32 @@ def aggregate_radius(radius, g, z):
# and implement :math:`\text{fuse}` as a sparse matrix multiplication.
#
# In ``forward()``:
#
# ::
#
#    fuse = self.linear_fuse(th.mm(pm_pd, feat_b))
#
# Completing :math:`f(x, y)`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# Finally, the following shows how to sum up all the terms, pass the result
# through a skip connection, and apply batch normalization.
#
# ::
#
#    result = prev_proj + deg_proj + radius_proj + fuse
#
# Pass the result to the skip connection.
#
# ::
#
#    result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
#
# Then pass the result to batch norm.
#
# ::
#
#    result = self.bn(result)  # batch normalization
#
#
# Here is the complete code for one LGNN layer's abstraction :math:`f(x,y)`.
class LGNNCore(nn.Module):
@@ -460,7 +460,7 @@ class LGNNCore(nn.Module):
# Chain-up LGNN abstractions as an LGNN layer
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# To implement:
#
# .. math::
#    \begin{split}
#    x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
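##############################################################################
# A sketch of chaining two ``LGNNCore`` instances into one layer. The
# ``LGNNCore`` call signature here is assumed from the description above,
# since its full definition is elided from this diff; ``th`` and ``nn`` are
# the tutorial's torch imports.

class LGNNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
        super(LGNNLayer, self).__init__()
        # One core updates the graph representation x ...
        self.g_layer = LGNNCore(in_feats, out_feats, radius)
        # ... and a second, with its own parameters, updates the line-graph
        # representation y.
        self.lg_layer = LGNNCore(in_feats, out_feats, radius)

    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
        next_x = self.g_layer(g, x, y, deg_g, pm_pd)
        # The second call fuses in the opposite direction, so transpose Pm, Pd.
        pm_pd_y = th.transpose(pm_pd, 0, 1)
        next_y = self.lg_layer(lg, y, next_x, deg_lg, pm_pd_y)
        return next_x, next_y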
@@ -518,7 +518,7 @@ training_loader = DataLoader(train_set,
# array in ``numpy.ndarray``. Generate the line graph by using this command:
#
# ::
#
#    lg = g.line_graph(backtracking=False)
#
# Note that ``backtracking=False`` is required to correctly simulate non-backtracking
@@ -547,7 +547,7 @@ for i in range(20):
    # Create torch tensors
    pmpd = sparse2th(pmpd)
    label = th.from_numpy(label)

    # Forward
    z = model(g, lg, pmpd)
@@ -594,7 +594,7 @@ visualize(label1, nx_G1)
#########################################
# Here is an animation to better understand the process (40 epochs).
#
# .. figure:: https://i.imgur.com/KDUyE1S.gif
#    :alt: lgnn-anim
#
# Batching graphs for parallelism
@@ -619,5 +619,5 @@ def collate_fn(batch):
    return batched_graphs, batched_pmpds, batched_labels
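##############################################################################
# The body of ``collate_fn`` is elided above (it ends in the ``return``
# shown). Below is a sketch assuming each sample is a ``(graph, pm_pd,
# label)`` triple.

import dgl
import numpy as np
import scipy.sparse as sp

def collate_fn(batch):
    graphs, pmpds, labels = zip(*batch)
    # Merge the graphs into one disconnected graph for parallel computation.
    batched_graphs = dgl.batch(graphs)
    # Stack the incidence matrices block-diagonally to match the batched graph.
    batched_pmpds = sp.block_diag(pmpds)
    batched_labels = np.concatenate(labels, axis=0)
    return batched_graphs, batched_pmpds, batched_labels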
######################################################################################
# You can find the complete code on GitHub at
# `Community Detection with Graph Neural Networks (CDGNN) <https://github.com/dmlc/dgl/tree/master/examples/pytorch/line_graph>`_.
@@ -16,8 +16,8 @@ Understand Graph Attention Network
efficiency. For a recommended implementation, please refer to the `official
examples <https://github.com/dmlc/dgl/tree/master/examples>`_.

In this tutorial, you learn about a graph attention network (GAT) and how it can be
implemented in PyTorch. You can also learn to visualize and understand what the attention
mechanism has learned.

The research described in the paper `Graph Convolutional Network (GCN) <https://arxiv.org/abs/1609.02907>`_,
@@ -93,7 +93,7 @@ structure-free normalization, in the style of attention.
# scaled by the attention scores.
#
# There are other details from the paper, such as dropout and skip connections.
# For simplicity, those details are left out of this tutorial. To see more details,
# download the `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
# In essence, GAT is just a different aggregation function with attention
# over features of neighbors, instead of a simple mean aggregation.
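#
# A minimal usage sketch of DGL's built-in ``GATConv`` on a toy graph (the
# graph and feature sizes are illustrative; assumes a DGL version that
# provides ``dgl.graph``):
#
# ::
#
#    import torch
#    import dgl
#    from dgl.nn.pytorch import GATConv
#
#    g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 0])))
#    conv = GATConv(in_feats=5, out_feats=2, num_heads=1)
#    h = torch.randn(3, 5)
#    out = conv(g, h)
#    print(out.shape)  # (3, 1, 2): (num_nodes, num_heads, out_feats)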
@@ -111,7 +111,7 @@ from dgl.nn.pytorch import GATConv
# jump to the `Put everything together`_ section for training and visualization results.
#
# To begin, you can get an overall impression of how a ``GATLayer`` module is
# implemented in DGL. In this section, the four equations above are broken down
# one at a time.
#
# .. note::
@@ -306,7 +306,7 @@ def load_cora_data():
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
-   g = DGLGraph(data.graph)
+   g = data[0]
    return g, features, labels, mask

##############################################################################
@@ -355,7 +355,7 @@ for epoch in range(30):
# ^^^^
#
# The following table summarizes the model performance on Cora that is reported in
# `the GAT paper <https://arxiv.org/pdf/1710.10903.pdf>`_ and obtained with DGL
# implementations.
#
# .. list-table::
@@ -460,15 +460,15 @@ for epoch in range(30):
# .. note::
#
#    Below is the calculation process of the micro F1 score:
#
#    .. math::
#
#       precision=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t}+FP_{t})}
#
#       recall=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t}+FN_{t})}
#
#       F1_{micro}=2\frac{precision*recall}{precision+recall}
#
#    * :math:`TP_{t}` is the number of nodes that both have and are predicted to have label :math:`t`
#    * :math:`FP_{t}` is the number of nodes that do not have but are predicted to have label :math:`t`
#    * :math:`FN_{t}` is the number of nodes that have label :math:`t` but are predicted to have other labels
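#
#    A NumPy sketch of this computation, assuming binary indicator matrices
#    ``y_true`` and ``y_pred`` of shape ``(num_nodes, num_labels)`` (the
#    names are illustrative):
#
#    ::
#
#       import numpy as np
#
#       def micro_f1(y_true, y_pred):
#           # Pool true/false positives and false negatives over all labels.
#           tp = np.logical_and(y_pred == 1, y_true == 1).sum()
#           fp = np.logical_and(y_pred == 1, y_true == 0).sum()
#           fn = np.logical_and(y_pred == 0, y_true == 1).sum()
#           precision = tp / (tp + fp)
#           recall = tp / (tp + fn)
#           return 2 * precision * recall / (precision + recall)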
@@ -498,7 +498,7 @@ for epoch in range(30):
#
# |image7|
#
# Again, comparing with the uniform distribution:
#
# .. image:: https://data.dgl.ai/tutorial/gat/ppi-uniform-hist.png
#    :width: 250px