Unverified Commit 25c9221b authored by Jeremy Goh's avatar Jeremy Goh Committed by GitHub
Browse files

Add check for aggregator_type enum in SAGEConv init (#3691)


Co-authored-by: default avatarQuan (Andy) Gan <coin2028@hotmail.com>
Co-authored-by: default avatarMufei Li <mufeili1996@gmail.com>
parent 9c8c162a
...@@ -6,6 +6,7 @@ from mxnet import nd ...@@ -6,6 +6,7 @@ from mxnet import nd
from mxnet.gluon import nn from mxnet.gluon import nn
from .... import function as fn from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape from ....utils import expand_as_pair, check_eq_shape
class SAGEConv(nn.Block): class SAGEConv(nn.Block):
...@@ -101,6 +102,12 @@ class SAGEConv(nn.Block): ...@@ -101,6 +102,12 @@ class SAGEConv(nn.Block):
norm=None, norm=None,
activation=None): activation=None):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
......
...@@ -5,6 +5,7 @@ from torch import nn ...@@ -5,6 +5,7 @@ from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .... import function as fn from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape, dgl_warning from ....utils import expand_as_pair, check_eq_shape, dgl_warning
...@@ -106,6 +107,12 @@ class SAGEConv(nn.Module): ...@@ -106,6 +107,12 @@ class SAGEConv(nn.Module):
norm=None, norm=None,
activation=None): activation=None):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
......
...@@ -4,6 +4,7 @@ import tensorflow as tf ...@@ -4,6 +4,7 @@ import tensorflow as tf
from tensorflow.keras import layers from tensorflow.keras import layers
from .... import function as fn from .... import function as fn
from ....base import DGLError
from ....utils import expand_as_pair, check_eq_shape from ....utils import expand_as_pair, check_eq_shape
...@@ -100,6 +101,12 @@ class SAGEConv(layers.Layer): ...@@ -100,6 +101,12 @@ class SAGEConv(layers.Layer):
norm=None, norm=None,
activation=None): activation=None):
super(SAGEConv, self).__init__() super(SAGEConv, self).__init__()
valid_aggre_types = {'mean', 'gcn', 'pool', 'lstm'}
if aggregator_type not in valid_aggre_types:
raise DGLError(
'Invalid aggregator_type. Must be one of {}. '
'But got {!r} instead.'.format(valid_aggre_types, aggregator_type)
)
self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats) self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
self._out_feats = out_feats self._out_feats = out_feats
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment