Commit 3c14f46c authored by Tomas Simon's avatar Tomas Simon Committed by Facebook GitHub Bot
Browse files

Fix edge grad assert mismatch

Summary:
* The sizes for vi in the edge_grad_estimator_fwd assert were not updated after D68534639 expanded the dimension to 3
* This updates the size in the assert and adds an explicit call to edge_grad_estimator_fwd (a no-op) in the autograd implementation to make sure the sizes are checked

Reviewed By: HapeMask, phg1024

Differential Revision: D72433642

fbshipit-source-id: 49dd82e0a07fe174c2157b362eedf464984d386d
parent 85f58cf1
Pipeline #2864 canceled with stages
......@@ -63,9 +63,9 @@ torch::Tensor edge_grad_estimator_fwd(
index_img.layout() == torch::kStrided,
"edge_grad_estimator(): expected all inputs to have torch.strided layout");
TORCH_CHECK(
(v_pix.dim() == 3) && (v_pix_img.dim() == 4) && (vi.dim() == 2) && (img.dim() == 4) &&
(v_pix.dim() == 3) && (v_pix_img.dim() == 4) && (vi.dim() == 3) && (img.dim() == 4) &&
(index_img.dim() == 3),
"edge_grad_estimator(): expected v_pix.ndim == 3, v_pix_img.ndim == 4, vi.ndim == 2, img.ndim == 4, index_img.ndim == 3, "
"edge_grad_estimator(): expected v_pix.ndim == 3, v_pix_img.ndim == 4, vi.ndim == 3, img.ndim == 4, index_img.ndim == 3, "
"but got v_pix with sizes ",
v_pix.sizes(),
" and v_pix_img with sizes ",
......@@ -89,14 +89,14 @@ torch::Tensor edge_grad_estimator_fwd(
" and index_img with sizes ",
index_img.sizes());
TORCH_CHECK(
v_pix.size(2) == 3 && v_pix_img.size(1) == 3 && vi.size(1) == 3,
"edge_grad_estimator(): expected third dim of v_pix to be of size 3, and second dim of vi to be of size 3, but got ",
v_pix.size(2) == 3 && v_pix_img.size(1) == 3 && vi.size(2) == 3,
"edge_grad_estimator(): expected third dim of v_pix to be of size 3, and third dim of vi to be of size 3, but got ",
v_pix.size(2),
" in the third dim of v_pix, and ",
v_pix_img.size(1),
" in the second dim of v_pix_img, and ",
vi.size(1),
" in the second dim of vi");
vi.size(2),
" in the third dim of vi");
TORCH_CHECK(
v_pix_img.size(3) == img.size(3) && v_pix_img.size(3) == index_img.size(2) &&
v_pix_img.size(2) == img.size(2) && v_pix_img.size(2) == index_img.size(1),
......@@ -120,6 +120,8 @@ class EdgeGradEstimatorFunction : public torch::autograd::Function<EdgeGradEstim
const torch::Tensor& vi,
const torch::Tensor& img,
const torch::Tensor& index_img) {
// Call edge_grad_estimator_fwd to check the input sizes
edge_grad_estimator_fwd(v_pix, v_pix_img, vi, img, index_img);
ctx->set_materialize_grads(false);
ctx->save_for_backward({v_pix, img, index_img, vi});
ctx->saved_data["v_pix_img_requires_grad"] = v_pix_img.requires_grad();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment