Unverified Commit dde5cf5d authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Bug] Fix NN modules crashing with non-FP32 inputs (reopen #4262) (#4829)



* fix fp16 in nn modules

* fix lint

* remove the example
Co-authored-by: default avatarQuan Gan <coin2028@hotmail.com>
parent 9a40d208
......@@ -91,11 +91,13 @@ class APPNPConv(nn.Module):
with graph.local_scope():
if edge_weight is None:
src_norm = th.pow(
graph.out_degrees().float().clamp(min=1), -0.5
graph.out_degrees().to(feat).clamp(min=1), -0.5
)
shp = src_norm.shape + (1,) * (feat.dim() - 1)
src_norm = th.reshape(src_norm, shp).to(feat.device)
dst_norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
dst_norm = th.pow(
graph.in_degrees().to(feat).clamp(min=1), -0.5
)
shp = dst_norm.shape + (1,) * (feat.dim() - 1)
dst_norm = th.reshape(dst_norm, shp).to(feat.device)
else:
......
......@@ -285,9 +285,11 @@ class AtomicConv(nn.Module):
number of radial filters, and :math:`T` for the number of types of atomic numbers.
"""
with graph.local_scope():
radial_pooled_values = self.radial_pooling(distances) # (K, E, 1)
radial_pooled_values = self.radial_pooling(distances).to(
feat
) # (K, E, 1)
if self.features_to_use is not None:
feat = (feat == self.features_to_use).float() # (V, T)
feat = (feat == self.features_to_use).to(feat) # (V, T)
graph.ndata["hv"] = feat
graph.edata["he"] = radial_pooled_values.transpose(1, 0).squeeze(
-1
......
......@@ -102,9 +102,8 @@ class ChebConv(nn.Module):
with graph.local_scope():
D_invsqrt = (
th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
th.pow(graph.in_degrees().to(feat).clamp(min=1), -0.5)
.unsqueeze(-1)
.to(feat.device)
)
if lambda_max is None:
......
......@@ -107,7 +107,7 @@ class DenseGraphConv(nn.Module):
The output feature of shape :math:`(N, D_{out})` where :math:`D_{out}`
is size of output feature.
"""
adj = adj.float().to(feat.device)
adj = adj.to(feat)
src_degrees = adj.sum(dim=0).clamp(min=1)
dst_degrees = adj.sum(dim=1).clamp(min=1)
feat_src = feat
......
......@@ -124,7 +124,7 @@ class DenseSAGEConv(nn.Module):
feat_dst = self.feat_drop(feat[1])
else:
feat_src = feat_dst = self.feat_drop(feat)
adj = adj.float().to(feat_src.device)
adj = adj.to(feat_src)
in_degrees = adj.sum(dim=1, keepdim=True)
h_neigh = (adj @ feat_src + feat_dst) / (in_degrees + 1)
rst = self.fc(h_neigh)
......
......@@ -232,7 +232,7 @@ class GCN2Conv(nn.Module):
# normalize to get smoothed representation
if edge_weight is None:
degs = graph.in_degrees().float().clamp(min=1)
degs = graph.in_degrees().to(feat).clamp(min=1)
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
else:
......
......@@ -107,8 +107,11 @@ class EdgeWeightNorm(nn.Module):
'This leads to square root of zero or negative values.')
dev = graph.device
graph.srcdata['_src_out_w'] = th.ones((graph.number_of_src_nodes())).float().to(dev)
graph.dstdata['_dst_in_w'] = th.ones((graph.number_of_dst_nodes())).float().to(dev)
dtype = edge_weight.dtype
graph.srcdata['_src_out_w'] = th.ones(
graph.number_of_src_nodes(), dtype=dtype, device=dev)
graph.dstdata['_dst_in_w'] = th.ones(
graph.number_of_dst_nodes(), dtype=dtype, device=dev)
graph.edata['_edge_w'] = edge_weight
if self._norm == 'both':
......@@ -398,7 +401,7 @@ class GraphConv(nn.Module):
# (BarclayII) For RGCN on heterogeneous graphs we need to support GCN on bipartite.
feat_src, feat_dst = expand_as_pair(feat, graph)
if self._norm in ['left', 'both']:
degs = graph.out_degrees().float().clamp(min=1)
degs = graph.out_degrees().to(feat_src).clamp(min=1)
if self._norm == 'both':
norm = th.pow(degs, -0.5)
else:
......@@ -431,7 +434,7 @@ class GraphConv(nn.Module):
rst = th.matmul(rst, weight)
if self._norm in ['right', 'both']:
degs = graph.in_degrees().float().clamp(min=1)
degs = graph.in_degrees().to(feat_dst).clamp(min=1)
if self._norm == 'both':
norm = th.pow(degs, -0.5)
else:
......
......@@ -196,7 +196,7 @@ class SGConv(nn.Module):
else:
if edge_weight is None:
# compute normalization
degs = graph.in_degrees().float().clamp(min=1)
degs = graph.in_degrees().to(feat).clamp(min=1)
norm = th.pow(degs, -0.5)
norm = norm.to(feat.device).unsqueeze(1)
# compute (D^-1 A^k D)^k X
......
......@@ -117,7 +117,7 @@ class TAGConv(nn.Module):
with graph.local_scope():
assert graph.is_homogeneous, "Graph is not homogeneous"
if edge_weight is None:
norm = th.pow(graph.in_degrees().float().clamp(min=1), -0.5)
norm = th.pow(graph.in_degrees().to(feat).clamp(min=1), -0.5)
shp = norm.shape + (1,) * (feat.dim() - 1)
norm = th.reshape(norm, shp).to(feat.device)
......
......@@ -590,7 +590,7 @@ class TWIRLSUnfoldingAndAttention(nn.Module):
Y = X
g.edata["w"] = tc.ones(g.number_of_edges(), 1, device=g.device)
g.ndata["deg"] = g.in_degrees().float()
g.ndata["deg"] = g.in_degrees().to(X)
if self.init_att:
g = self.init_attn(g, Y, self.etas)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment