"docs/source/vscode:/vscode.git/clone" did not exist on "b35388020d48ef027fb0375268aa03dcbe9acb3e"
Unverified commit 5798ee8d, authored by Quan (Andy) Gan and committed by GitHub
Browse files

[Performance] Add a warning for ChebConv (#3099)



* add a warning for chebconv

* fix and docstrings

* update bgnn

* fix
Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
parent 183d29de
......@@ -61,7 +61,12 @@ class GNNModelDGL(torch.nn.Module):
h = self.l1(graph, h)
h = self.l2(graph, h)
logits = self.lin2(h)
elif self.name in ['gcn', 'cheb']:
elif self.name == 'che3b':
lambda_max = dgl.laplacian_lambda_max(graph)
h = self.drop(h)
h = self.l1(graph, h, lambda_max)
logits = self.l2(graph, h, lambda_max)
elif self.name == 'gcn':
h = self.drop(h)
h = self.l1(graph, h)
logits = self.l2(graph, h)
......
......@@ -5,7 +5,8 @@ import mxnet as mx
from mxnet import nd
from mxnet.gluon import nn
from .... import laplacian_lambda_max, broadcast_nodes, function as fn
from ....base import dgl_warning
from .... import broadcast_nodes, function as fn
class ChebConv(nn.Block):
......@@ -50,7 +51,6 @@ class ChebConv(nn.Block):
>>> import numpy as np
>>> import mxnet as mx
>>> from dgl.nn import ChebConv
>>>
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = mx.nd.ones((6, 10))
>>> conv = ChebConv(10, 2, 2)
......@@ -106,8 +106,9 @@ class ChebConv(nn.Block):
A list(tensor) with length :math:`B`, stores the largest eigenvalue
of the normalized laplacian of each individual graph in ``graph``,
where :math:`B` is the batch size of the input graph. Default: None.
If None, this method would compute the list by calling
``dgl.laplacian_lambda_max``.
If None, this method would set the default value to 2.
One can use :func:`dgl.laplacian_lambda_max` to compute this value.
Returns
-------
......@@ -119,8 +120,13 @@ class ChebConv(nn.Block):
degs = graph.in_degrees().astype('float32')
norm = mx.nd.power(mx.nd.clip(degs, a_min=1, a_max=float("inf")), -0.5)
norm = norm.expand_dims(-1).as_in_context(feat.context)
if lambda_max is None:
lambda_max = laplacian_lambda_max(graph)
dgl_warning(
"lambda_max is not provided, using default value of 2. "
"Please use dgl.laplacian_lambda_max to compute the eigenvalues.")
lambda_max = [2] * graph.batch_size
if isinstance(lambda_max, list):
lambda_max = nd.array(lambda_max).as_in_context(feat.context)
if lambda_max.ndim == 1:
......
......@@ -5,7 +5,7 @@ from torch import nn
import torch.nn.functional as F
from ....base import dgl_warning
from .... import laplacian_lambda_max, broadcast_nodes, function as fn
from .... import broadcast_nodes, function as fn
class ChebConv(nn.Module):
......@@ -95,8 +95,9 @@ class ChebConv(nn.Module):
A list(tensor) with length :math:`B`, stores the largest eigenvalue
of the normalized laplacian of each individual graph in ``graph``,
where :math:`B` is the batch size of the input graph. Default: None.
If None, this method would compute the list by calling
``dgl.laplacian_lambda_max``.
If None, this method would set the default value to 2.
One can use :func:`dgl.laplacian_lambda_max` to compute this value.
Returns
-------
......@@ -115,17 +116,13 @@ class ChebConv(nn.Module):
min=1), -0.5).unsqueeze(-1).to(feat.device)
if lambda_max is None:
try:
lambda_max = laplacian_lambda_max(graph)
except BaseException:
# if the largest eigenvalue is not found
dgl_warning(
"Largest eigonvalue not found, using default value 2 for lambda_max",
RuntimeWarning)
lambda_max = th.Tensor(2).to(feat.device)
"lambda_max is not provided, using default value of 2. "
"Please use dgl.laplacian_lambda_max to compute the eigenvalues.")
lambda_max = [2] * graph.batch_size
if isinstance(lambda_max, list):
lambda_max = th.Tensor(lambda_max).to(feat.device)
lambda_max = th.Tensor(lambda_max).to(feat)
if lambda_max.dim() == 1:
lambda_max = lambda_max.unsqueeze(-1) # (B,) to (B, 1)
......
......@@ -5,7 +5,7 @@ from tensorflow.keras import layers
import numpy as np
from ....base import dgl_warning
from .... import laplacian_lambda_max, broadcast_nodes, function as fn
from .... import broadcast_nodes, function as fn
class ChebConv(layers.Layer):
......@@ -50,13 +50,12 @@ class ChebConv(layers.Layer):
>>> import numpy as np
>>> import tensorflow as tf
>>> from dgl.nn import ChebConv
>>>
>>> with tf.device("CPU:0"):
>>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
>>> feat = tf.ones((6, 10))
>>> conv = ChebConv(10, 2, 2)
>>> res = conv(g, feat)
>>> res
... g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
... feat = tf.ones((6, 10))
... conv = ChebConv(10, 2, 2)
... res = conv(g, feat)
... res
<tf.Tensor: shape=(6, 2), dtype=float32, numpy=
array([[ 0.6163, -0.1809],
[ 0.6163, -0.1809],
......@@ -97,8 +96,9 @@ class ChebConv(layers.Layer):
A list(tensor) with length :math:`B`, stores the largest eigenvalue
of the normalized laplacian of each individual graph in ``graph``,
where :math:`B` is the batch size of the input graph. Default: None.
If None, this method would compute the list by calling
``dgl.laplacian_lambda_max``.
If None, this method would set the default value to 2.
One can use :func:`dgl.laplacian_lambda_max` to compute this value.
Returns
-------
......@@ -117,15 +117,12 @@ class ChebConv(layers.Layer):
clip_value_min=1,
clip_value_max=np.inf)
D_invsqrt = tf.expand_dims(tf.pow(in_degrees, -0.5), axis=-1)
if lambda_max is None:
try:
lambda_max = laplacian_lambda_max(graph)
except BaseException:
# if the largest eigenvalue is not found
dgl_warning(
"Largest eigonvalue not found, using default value 2 for lambda_max",
RuntimeWarning)
lambda_max = tf.constant(2, dtype=tf.float32)
"lambda_max is not provided, using default value of 2. "
"Please use dgl.laplacian_lambda_max to compute the eigenvalues.")
lambda_max = [2] * graph.batch_size
if isinstance(lambda_max, list):
lambda_max = tf.constant(lambda_max, dtype=tf.float32)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment