Unverified Commit 978ac410 authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[NN] Fix RelGraphConv num_bases error message (#1309)



* Fix RelGraphConv when num_bases is None

* Improve error message in when num_bases is wrong

* Fix integer modulo by zero exception when num_bases is 0
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 4ec8f204
......@@ -72,7 +72,7 @@ class RelGraphConv(nn.Module):
self.num_rels = num_rels
self.regularizer = regularizer
self.num_bases = num_bases
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases < 0:
if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
......@@ -91,8 +91,11 @@ class RelGraphConv(nn.Module):
# message func
self.message_func = self.basis_message_func
elif regularizer == "bdd":
if in_feat % num_bases != 0 or out_feat % num_bases != 0:
raise ValueError('Feature size must be a multiplier of num_bases.')
if in_feat % self.num_bases != 0 or out_feat % self.num_bases != 0:
raise ValueError(
'Feature size must be a multiplier of num_bases (%d).'
% self.num_bases
)
# add block diagonal weights
self.submat_in = in_feat // self.num_bases
self.submat_out = out_feat // self.num_bases
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment