Commit 335525a7 authored by rusty1s's avatar rusty1s
Browse files

fix nan values

parent 68f4609c
...@@ -97,6 +97,7 @@ public: ...@@ -97,6 +97,7 @@ public:
auto dim = ctx->saved_data["dim"].toInt(); auto dim = ctx->saved_data["dim"].toInt();
auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList());
auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src); auto grad_in = torch::gather(grad_out * out, dim, index, false).div_(src);
grad_in.masked_fill_(grad_in.isnan(), 0);
return {grad_in, Variable(), Variable(), Variable(), Variable()}; return {grad_in, Variable(), Variable(), Variable(), Variable()};
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment