Unverified Commit db57809d authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

[hotfix] Add data type check for kernels. (#2598)

* upd

* upd
parent 460bb42d
......@@ -120,6 +120,11 @@ def _gspmm(gidx, op, reduce_op, u, e):
raise DGLError("We only support gspmm on graph with one edge type")
use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs'
if use_u and use_e:
if F.dtype(u) != F.dtype(e):
raise DGLError("The node features' data type {} doesn't match edge"
" features' data type {}, please convert them to the"
" same type.".format(F.dtype(u), F.dtype(e)))
# deal with scalar features.
expand_u, expand_e = False, False
if use_u:
......@@ -219,6 +224,10 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
raise DGLError("We only support gsddmm on graph with one edge type")
use_lhs = op != 'copy_rhs'
use_rhs = op != 'copy_lhs'
if use_lhs and use_rhs:
if F.dtype(lhs) != F.dtype(rhs):
raise DGLError("The operands data type don't match: {} and {}, please convert them"
" to the same type.".format(F.dtype(lhs), F.dtype(rhs)))
# deal with scalar features.
expand_lhs, expand_rhs = False, False
if use_lhs:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment