Unverified Commit 25c9221b authored by Jeremy Goh's avatar Jeremy Goh Committed by GitHub
Browse files

Add check for aggregator_type enum in SAGEConv init (#3691)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 9c8c162a
......@@ -6,6 +6,7 @@ from mxnet import nd
from mxnet.gluon import nn
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Block):
......@@ -101,6 +102,12 @@ class SAGEConv(nn.Block):
norm=None,
activation=None):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
......
......@@ -5,6 +5,7 @@ from torch import nn
from torch.nn import functional as F
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape, dgl_warning
......@@ -106,6 +107,12 @@ class SAGEConv(nn.Module):
norm=None,
activation=None):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
......
......@@ -4,6 +4,7 @@ import tensorflow as tf
from tensorflow.keras import layers
from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape
......@@ -100,6 +101,12 @@ class SAGEConv(layers.Layer):
norm=None,
activation=None):
super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment