Unverified Commit 2e0949e2 authored by Beat Buesser, committed by GitHub

Allow gradient backpropagation through GeneralizedRCNNTransform to inputs (#4327)



* Allow gradient backpropagation through GeneralizedRCNNTransform to inputs
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>

* Add unit tests for gradient backpropagation to inputs
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>

* Update torchvision/models/detection/transform.py
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>

* Update _check_input_backprop
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>

* Account for tests requiring cuda
Signed-off-by: Beat Buesser <beat.buesser@ie.ibm.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
parent 526a69e8
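For context, this is roughly the usage pattern the change enables: computing gradients of a detection output with respect to the input images. A minimal sketch, assuming the stock fasterrcnn_resnet50_fpn builder and mirroring the scores-based backward call used in the new tests; the variable names are illustrative, not part of this commit.

import torch
import torchvision

# With this change, gradients can flow through the whole detection pipeline,
# including GeneralizedRCNNTransform's normalize/resize/pad, back to the inputs.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
model.eval()

images = [torch.rand(3, 300, 400, requires_grad=True)]
outputs = model(images)

# Backpropagate a scalar derived from the detections to the input images.
outputs[0]["scores"].sum().backward()
assert images[0].grad is not None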
@@ -148,6 +148,35 @@ def _check_fx_compatible(model, inputs):
     torch.testing.assert_close(out, out_fx)


+def _check_input_backprop(model, inputs):
+    if isinstance(inputs, list):
+        requires_grad = list()
+        for inp in inputs:
+            requires_grad.append(inp.requires_grad)
+            inp.requires_grad_(True)
+    else:
+        requires_grad = inputs.requires_grad
+        inputs.requires_grad_(True)
+
+    out = model(inputs)
+
+    if isinstance(out, dict):
+        out["out"].sum().backward()
+    else:
+        if isinstance(out[0], dict):
+            out[0]["scores"].sum().backward()
+        else:
+            out[0].sum().backward()
+
+    if isinstance(inputs, list):
+        for i, inp in enumerate(inputs):
+            assert inputs[i].grad is not None
+            inp.requires_grad_(requires_grad[i])
+    else:
+        assert inputs.grad is not None
+        inputs.requires_grad_(requires_grad)
+
+
 # If 'unwrapper' is provided it will be called with the script model outputs
 # before they are compared to the eager model outputs. This is useful if the
 # model outputs are different between TorchScript / Eager mode
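As a quick illustration of the helper above: it temporarily enables requires_grad on the inputs, backpropagates a scalar taken from whichever output structure the model returns (a dict with "out" for segmentation models, a list of dicts with "scores" for detection models, a plain tensor otherwise), asserts that input gradients were produced, and restores the original flags. A hypothetical direct call, with an arbitrary classification model chosen purely for illustration:

import torch
import torchvision

# Classification models return a plain tensor, so the out[0].sum().backward()
# branch of _check_input_backprop is exercised here.
model = torchvision.models.resnet18(pretrained=False)
x = torch.rand(1, 3, 224, 224)
_check_input_backprop(model, x)
assert not x.requires_grad  # the helper restores the original requires_grad flag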
@@ -263,6 +292,9 @@ def test_memory_efficient_densenet(model_name):
     assert num_params == num_grad
     torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)

+    _check_input_backprop(model1, x)
+    _check_input_backprop(model2, x)

 @pytest.mark.parametrize('dilate_layer_2', (True, False))
 @pytest.mark.parametrize('dilate_layer_3', (True, False))
@@ -312,6 +344,7 @@ def test_inception_v3_eval():
     model = model.eval()
     x = torch.rand(1, 3, 299, 299)
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+    _check_input_backprop(model, x)

 def test_fasterrcnn_double():
@@ -327,6 +360,7 @@ def test_fasterrcnn_double():
     assert "boxes" in out[0]
     assert "scores" in out[0]
     assert "labels" in out[0]
+    _check_input_backprop(model, model_input)

 def test_googlenet_eval():
@@ -343,6 +377,7 @@ def test_googlenet_eval():
     model = model.eval()
     x = torch.rand(1, 3, 224, 224)
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+    _check_input_backprop(model, x)

 @needs_cuda
@@ -369,6 +404,8 @@ def test_fasterrcnn_switch_devices():
     checkOut(out)

+    _check_input_backprop(model, model_input)
+
     # now switch to cpu and make sure it works
     model.cpu()
     x = x.cpu()
@@ -376,6 +413,8 @@ def test_fasterrcnn_switch_devices():
     checkOut(out_cpu)

+    _check_input_backprop(model, [x])

 def test_generalizedrcnn_transform_repr():
@@ -426,6 +465,8 @@ def test_classification_model(model_name, dev):
                 _assert_expected(out.cpu(), model_name, prec=0.1)
             assert out.shape[-1] == 50

+    _check_input_backprop(model, x)

 @pytest.mark.parametrize('model_name', get_available_segmentation_models())
 @pytest.mark.parametrize('dev', cpu_and_gpu())
@@ -483,6 +524,8 @@ def test_segmentation_model(model_name, dev):
         warnings.warn(msg, RuntimeWarning)
         pytest.skip(msg)

+    _check_input_backprop(model, x)

 @pytest.mark.parametrize('model_name', get_available_detection_models())
 @pytest.mark.parametrize('dev', cpu_and_gpu())
@@ -574,6 +617,8 @@ def test_detection_model(model_name, dev):
         warnings.warn(msg, RuntimeWarning)
         pytest.skip(msg)

+    _check_input_backprop(model, model_input)

 @pytest.mark.parametrize('model_name', get_available_detection_models())
 def test_detection_model_validation(model_name):
@@ -625,6 +670,8 @@ def test_video_model(model_name, dev):
             out = model(x)
             assert out.shape[-1] == 50

+    _check_input_backprop(model, x)

 @pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
                          'qnnpack' in torch.backends.quantized.supported_engines),
@@ -214,8 +214,9 @@ class GeneralizedRCNNTransform(nn.Module):
         batch_shape = [len(images)] + max_size
         batched_imgs = images[0].new_full(batch_shape, 0)
-        for img, pad_img in zip(images, batched_imgs):
-            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        for i in range(batched_imgs.shape[0]):
+            img = images[i]
+            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

         return batched_imgs
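For reference on why this rewrite matters: the old loop iterates batched_imgs via zip, which goes through Tensor.unbind, and autograd does not appear to allow in-place copy_ into such multi-output views once the copied source requires grad; indexing with batched_imgs[i, ...] instead yields a single select view that may be modified in place, so gradients can reach the input images. A minimal standalone sketch of the difference; shapes and variable names are illustrative, not taken from torchvision.

import torch

images = [torch.rand(3, 4, 4, requires_grad=True) for _ in range(2)]

# Old pattern: zip() iterates the batch tensor via unbind(), and autograd is
# expected to reject in-place copy_ into those multi-output views when the
# source requires grad (a RuntimeError on recent PyTorch releases).
batched = images[0].new_full((2, 3, 5, 5), 0)
try:
    for img, pad_img in zip(images, batched):
        pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
except RuntimeError as err:
    print("zip/unbind pattern failed:", err)

# New pattern: indexing gives a single select() view, the in-place copy_ is
# allowed, and gradients flow back to the input images.
batched = images[0].new_full((2, 3, 5, 5), 0)
for i in range(batched.shape[0]):
    img = images[i]
    batched[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)

batched.sum().backward()
print(images[0].grad is not None)  # True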