Unverified Commit 2e0949e2 authored by Beat Buesser's avatar Beat Buesser Committed by GitHub
Browse files

Allow gradient backpropagation through GeneralizedRCNNTransform to inputs (#4327)



* Allow gradient backpropagation through GeneralizedRCNNTransform to inputs
Signed-off-by: default avatarBeat Buesser <beat.buesser@ie.ibm.com>

* Add unit tests for gradient backpropagation to inputs
Signed-off-by: default avatarBeat Buesser <beat.buesser@ie.ibm.com>

* Update torchvision/models/detection/transform.py
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>

* Update _check_input_backprop
Signed-off-by: default avatarBeat Buesser <beat.buesser@ie.ibm.com>

* Account for tests requiring cuda
Signed-off-by: default avatarBeat Buesser <beat.buesser@ie.ibm.com>
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 526a69e8
...@@ -148,6 +148,35 @@ def _check_fx_compatible(model, inputs): ...@@ -148,6 +148,35 @@ def _check_fx_compatible(model, inputs):
torch.testing.assert_close(out, out_fx) torch.testing.assert_close(out, out_fx)
def _check_input_backprop(model, inputs):
if isinstance(inputs, list):
requires_grad = list()
for inp in inputs:
requires_grad.append(inp.requires_grad)
inp.requires_grad_(True)
else:
requires_grad = inputs.requires_grad
inputs.requires_grad_(True)
out = model(inputs)
if isinstance(out, dict):
out["out"].sum().backward()
else:
if isinstance(out[0], dict):
out[0]["scores"].sum().backward()
else:
out[0].sum().backward()
if isinstance(inputs, list):
for i, inp in enumerate(inputs):
assert inputs[i].grad is not None
inp.requires_grad_(requires_grad[i])
else:
assert inputs.grad is not None
inputs.requires_grad_(requires_grad)
# If 'unwrapper' is provided it will be called with the script model outputs # If 'unwrapper' is provided it will be called with the script model outputs
# before they are compared to the eager model outputs. This is useful if the # before they are compared to the eager model outputs. This is useful if the
# model outputs are different between TorchScript / Eager mode # model outputs are different between TorchScript / Eager mode
...@@ -263,6 +292,9 @@ def test_memory_efficient_densenet(model_name): ...@@ -263,6 +292,9 @@ def test_memory_efficient_densenet(model_name):
assert num_params == num_grad assert num_params == num_grad
torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5) torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
_check_input_backprop(model1, x)
_check_input_backprop(model2, x)
@pytest.mark.parametrize('dilate_layer_2', (True, False)) @pytest.mark.parametrize('dilate_layer_2', (True, False))
@pytest.mark.parametrize('dilate_layer_3', (True, False)) @pytest.mark.parametrize('dilate_layer_3', (True, False))
...@@ -312,6 +344,7 @@ def test_inception_v3_eval(): ...@@ -312,6 +344,7 @@ def test_inception_v3_eval():
model = model.eval() model = model.eval()
x = torch.rand(1, 3, 299, 299) x = torch.rand(1, 3, 299, 299)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_input_backprop(model, x)
def test_fasterrcnn_double(): def test_fasterrcnn_double():
...@@ -327,6 +360,7 @@ def test_fasterrcnn_double(): ...@@ -327,6 +360,7 @@ def test_fasterrcnn_double():
assert "boxes" in out[0] assert "boxes" in out[0]
assert "scores" in out[0] assert "scores" in out[0]
assert "labels" in out[0] assert "labels" in out[0]
_check_input_backprop(model, model_input)
def test_googlenet_eval(): def test_googlenet_eval():
...@@ -343,6 +377,7 @@ def test_googlenet_eval(): ...@@ -343,6 +377,7 @@ def test_googlenet_eval():
model = model.eval() model = model.eval()
x = torch.rand(1, 3, 224, 224) x = torch.rand(1, 3, 224, 224)
_check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None)) _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
_check_input_backprop(model, x)
@needs_cuda @needs_cuda
...@@ -369,6 +404,8 @@ def test_fasterrcnn_switch_devices(): ...@@ -369,6 +404,8 @@ def test_fasterrcnn_switch_devices():
checkOut(out) checkOut(out)
_check_input_backprop(model, model_input)
# now switch to cpu and make sure it works # now switch to cpu and make sure it works
model.cpu() model.cpu()
x = x.cpu() x = x.cpu()
...@@ -376,6 +413,8 @@ def test_fasterrcnn_switch_devices(): ...@@ -376,6 +413,8 @@ def test_fasterrcnn_switch_devices():
checkOut(out_cpu) checkOut(out_cpu)
_check_input_backprop(model, [x])
def test_generalizedrcnn_transform_repr(): def test_generalizedrcnn_transform_repr():
...@@ -426,6 +465,8 @@ def test_classification_model(model_name, dev): ...@@ -426,6 +465,8 @@ def test_classification_model(model_name, dev):
_assert_expected(out.cpu(), model_name, prec=0.1) _assert_expected(out.cpu(), model_name, prec=0.1)
assert out.shape[-1] == 50 assert out.shape[-1] == 50
_check_input_backprop(model, x)
@pytest.mark.parametrize('model_name', get_available_segmentation_models()) @pytest.mark.parametrize('model_name', get_available_segmentation_models())
@pytest.mark.parametrize('dev', cpu_and_gpu()) @pytest.mark.parametrize('dev', cpu_and_gpu())
...@@ -483,6 +524,8 @@ def test_segmentation_model(model_name, dev): ...@@ -483,6 +524,8 @@ def test_segmentation_model(model_name, dev):
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
pytest.skip(msg) pytest.skip(msg)
_check_input_backprop(model, x)
@pytest.mark.parametrize('model_name', get_available_detection_models()) @pytest.mark.parametrize('model_name', get_available_detection_models())
@pytest.mark.parametrize('dev', cpu_and_gpu()) @pytest.mark.parametrize('dev', cpu_and_gpu())
...@@ -574,6 +617,8 @@ def test_detection_model(model_name, dev): ...@@ -574,6 +617,8 @@ def test_detection_model(model_name, dev):
warnings.warn(msg, RuntimeWarning) warnings.warn(msg, RuntimeWarning)
pytest.skip(msg) pytest.skip(msg)
_check_input_backprop(model, model_input)
@pytest.mark.parametrize('model_name', get_available_detection_models()) @pytest.mark.parametrize('model_name', get_available_detection_models())
def test_detection_model_validation(model_name): def test_detection_model_validation(model_name):
...@@ -625,6 +670,8 @@ def test_video_model(model_name, dev): ...@@ -625,6 +670,8 @@ def test_video_model(model_name, dev):
out = model(x) out = model(x)
assert out.shape[-1] == 50 assert out.shape[-1] == 50
_check_input_backprop(model, x)
@pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and @pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
'qnnpack' in torch.backends.quantized.supported_engines), 'qnnpack' in torch.backends.quantized.supported_engines),
......
...@@ -214,8 +214,9 @@ class GeneralizedRCNNTransform(nn.Module): ...@@ -214,8 +214,9 @@ class GeneralizedRCNNTransform(nn.Module):
batch_shape = [len(images)] + max_size batch_shape = [len(images)] + max_size
batched_imgs = images[0].new_full(batch_shape, 0) batched_imgs = images[0].new_full(batch_shape, 0)
for img, pad_img in zip(images, batched_imgs): for i in range(batched_imgs.shape[0]):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) img = images[i]
batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
return batched_imgs return batched_imgs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment