Unverified Commit 60eadb06 authored by JarvisKevin, committed by GitHub

[Enhancement] Lower the restrictions of _resize method in BaseMergeCell (#1959)

* Fix the bug met when using NAS-FPN

Fix the bug met when using NAS-FPN, which is mentioned at https://github.com/open-mmlab/mmdetection/issues/5987

Relax the strong restrictions of the _resize function in BaseMergeCell:
1. When downsampling the feature map, its shape previously had to be divisible by the target size. We now pad the feature map with zeros before the max_pool2d op so that it is always divisible. (lines 102 ~ 107)
2. Because H and W may be downsampled by different scales, both shape[-2] and shape[-1] are involved in the definition of kernel_size. (line 110) A sketch of this logic is shown below.
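For illustration, a minimal standalone sketch of the pad-then-pool logic described above (`downsample_with_pad` is a hypothetical helper for this example, not part of mmcv; variable names follow the patch):

```python
import math

import torch
import torch.nn.functional as F


def downsample_with_pad(feat, target_size):
    """Zero-pad feat so H and W are divisible by target_size, then
    max-pool with a per-dimension kernel, as in the patched _resize."""
    h, w = feat.shape[-2:]
    target_h, target_w = target_size
    # Padding needed to reach the next multiple of the target size.
    pad_h = math.ceil(h / target_h) * target_h - h
    pad_w = math.ceil(w / target_w) * target_w - w
    # Split the padding evenly between the two sides of each dimension.
    pad = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
    feat = F.pad(feat, pad, mode='constant', value=0.0)
    # H and W may shrink by different factors, so kernel_size is a tuple.
    kernel_size = (feat.shape[-2] // target_h, feat.shape[-1] // target_w)
    return F.max_pool2d(feat, kernel_size=kernel_size, stride=kernel_size)


out = downsample_with_pad(torch.randn(2, 256, 32, 32), (14, 7))
assert out.shape[-2:] == (14, 7)
```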

* Update merge_cells.py

check flake8 & isort

* Update merge_cells.py

* Update merge_cells.py

yapf

* Update mmcv/ops/merge_cells.py

Rename X_pad to padding_x
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Update merge_cells.py

format the code style after renaming X_pad to padding_x

* Update test_merge_cells.py

Mainly test the downsampling resize in BaseMergeCell. The smaller target size is set to (14, 7), a typical feature map size in the last few stages of a backbone, which produces different downsampling scales along the two dimensions; a worked example follows.
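For instance, downsampling a 32x32 input to (14, 7) pads the two dimensions by different amounts and pools with a non-square kernel; a short walk-through of the patch's arithmetic:

```python
import math

h, w = 32, 32
target_h, target_w = 14, 7
pad_h = math.ceil(h / target_h) * target_h - h  # 3 * 14 - 32 = 10, so H: 32 -> 42
pad_w = math.ceil(w / target_w) * target_w - w  # 5 * 7 - 32 = 3, so W: 32 -> 35
kernel_size = ((h + pad_h) // target_h, (w + pad_w) // target_w)  # (3, 5)
```

The two kernel dimensions differ, which is exactly why kernel_size must be a tuple rather than a single integer.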

* Update test_merge_cells.py

add "# Copyright (c) OpenMMLab. All rights reserved."

* Update merge_cells.py

format the variable name

* Update test_merge_cells.py

Testing divisible and indivisible situations simultaneously

* Update mmcv/ops/merge_cells.py

fix the bug where pad_w was computed unreasonably when h is indivisible but w is divisible.
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update mmcv/ops/merge_cells.py

fix the bug where pad_h was computed unreasonably when w is indivisible but h is divisible.
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix undefined error

* Update merge_cells.py

make pad_h, pad_w more readable

* Update test_merge_cells.py

use @pytest.mark.parametrize instead of a 'for' loop; a minimal example of the pattern is shown below
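The pattern in miniature (a hypothetical test, not part of the patch): each input runs as an independent test case with its own pass/fail report, instead of one loop stopping at the first failure.

```python
import pytest
import torch


@pytest.mark.parametrize('size', [(16, 16), (14, 7)])
def test_input_sizes(size):
    # Each parametrized case runs as a separate test invocation.
    x = torch.randn(2, 256, *size)
    assert x.shape[-2:] == size
```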

* Update merge_cells.py

* Update test_merge_cells.py

isort

* Update merge_cells.py

isort
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 45fa3e44
--- a/mmcv/ops/merge_cells.py
+++ b/mmcv/ops/merge_cells.py
 # Copyright (c) OpenMMLab. All rights reserved.
+import math
 from abc import abstractmethod
 
 import torch
 ...
@@ -95,8 +96,18 @@ class BaseMergeCell(nn.Module):
         elif x.shape[-2:] < size:
             return F.interpolate(x, size=size, mode=self.upsample_mode)
         else:
-            assert x.shape[-2] % size[-2] == 0 and x.shape[-1] % size[-1] == 0
-            kernel_size = x.shape[-1] // size[-1]
+            if x.shape[-2] % size[-2] != 0 or x.shape[-1] % size[-1] != 0:
+                h, w = x.shape[-2:]
+                target_h, target_w = size
+                pad_h = math.ceil(h / target_h) * target_h - h
+                pad_w = math.ceil(w / target_w) * target_w - w
+                pad_l = pad_w // 2
+                pad_r = pad_w - pad_l
+                pad_t = pad_h // 2
+                pad_b = pad_h - pad_t
+                pad = (pad_l, pad_r, pad_t, pad_b)
+                x = F.pad(x, pad, mode='constant', value=0.0)
+            kernel_size = (x.shape[-2] // size[-2], x.shape[-1] // size[-1])
             x = F.max_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
             return x
 ...
--- a/tests/test_merge_cells.py
+++ b/tests/test_merge_cells.py
@@ -3,6 +3,9 @@
 CommandLine:
     pytest tests/test_merge_cells.py
 """
+import math
+
+import pytest
 import torch
 import torch.nn.functional as F
@@ -10,33 +13,41 @@
 from mmcv.ops.merge_cells import (BaseMergeCell, ConcatCell, GlobalPoolingCell,
                                   SumCell)
 
 
-def test_sum_cell():
-    inputs_x = torch.randn([2, 256, 32, 32])
-    inputs_y = torch.randn([2, 256, 16, 16])
+# All size (14, 7) below is to test the situation that
+# the input size can't be divisible by the target size.
+@pytest.mark.parametrize(
+    'inputs_x, inputs_y',
+    [(torch.randn([2, 256, 16, 16]), torch.randn([2, 256, 32, 32])),
+     (torch.randn([2, 256, 14, 7]), torch.randn([2, 256, 32, 32]))])
+def test_sum_cell(inputs_x, inputs_y):
     sum_cell = SumCell(256, 256)
     output = sum_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
     assert output.size() == inputs_x.size()
     output = sum_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
     assert output.size() == inputs_y.size()
     output = sum_cell(inputs_x, inputs_y)
-    assert output.size() == inputs_x.size()
+    assert output.size() == inputs_y.size()
 
 
-def test_concat_cell():
-    inputs_x = torch.randn([2, 256, 32, 32])
-    inputs_y = torch.randn([2, 256, 16, 16])
+@pytest.mark.parametrize(
+    'inputs_x, inputs_y',
+    [(torch.randn([2, 256, 16, 16]), torch.randn([2, 256, 32, 32])),
+     (torch.randn([2, 256, 14, 7]), torch.randn([2, 256, 32, 32]))])
+def test_concat_cell(inputs_x, inputs_y):
     concat_cell = ConcatCell(256, 256)
     output = concat_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
     assert output.size() == inputs_x.size()
     output = concat_cell(inputs_x, inputs_y, out_size=inputs_y.shape[-2:])
     assert output.size() == inputs_y.size()
     output = concat_cell(inputs_x, inputs_y)
-    assert output.size() == inputs_x.size()
+    assert output.size() == inputs_y.size()
 
 
-def test_global_pool_cell():
-    inputs_x = torch.randn([2, 256, 32, 32])
-    inputs_y = torch.randn([2, 256, 32, 32])
+@pytest.mark.parametrize(
+    'inputs_x, inputs_y',
+    [(torch.randn([2, 256, 16, 16]), torch.randn([2, 256, 32, 32])),
+     (torch.randn([2, 256, 14, 7]), torch.randn([2, 256, 32, 32]))])
+def test_global_pool_cell(inputs_x, inputs_y):
     gp_cell = GlobalPoolingCell(with_out_conv=False)
     gp_cell_out = gp_cell(inputs_x, inputs_y, out_size=inputs_x.shape[-2:])
     assert (gp_cell_out.size() == inputs_x.size())
@@ -45,22 +56,40 @@ def test_global_pool_cell():
     assert (gp_cell_out.size() == inputs_x.size())
 
 
-def test_resize_methods():
+@pytest.mark.parametrize('target_size', [(256, 256), (128, 128), (64, 64),
+                                         (14, 7)])
+def test_resize_methods(target_size):
     inputs_x = torch.randn([2, 256, 128, 128])
-    target_resize_sizes = [(128, 128), (256, 256)]
-    resize_methods_list = ['nearest', 'bilinear']
+    h, w = inputs_x.shape[-2:]
+    target_h, target_w = target_size
+    if (h <= target_h) or w <= target_w:
+        rs_mode = 'upsample'
+    else:
+        rs_mode = 'downsample'
 
-    for method in resize_methods_list:
-        merge_cell = BaseMergeCell(upsample_mode=method)
-        for target_size in target_resize_sizes:
+    if rs_mode == 'upsample':
+        upsample_methods_list = ['nearest', 'bilinear']
+        for method in upsample_methods_list:
+            merge_cell = BaseMergeCell(upsample_mode=method)
             merge_cell_out = merge_cell._resize(inputs_x, target_size)
             gt_out = F.interpolate(inputs_x, size=target_size, mode=method)
             assert merge_cell_out.equal(gt_out)
-
-    target_size = (64, 64)  # resize to a smaller size
-    merge_cell = BaseMergeCell()
-    merge_cell_out = merge_cell._resize(inputs_x, target_size)
-    kernel_size = inputs_x.shape[-1] // target_size[-1]
-    gt_out = F.max_pool2d(
-        inputs_x, kernel_size=kernel_size, stride=kernel_size)
-    assert (merge_cell_out == gt_out).all()
+    elif rs_mode == 'downsample':
+        merge_cell = BaseMergeCell()
+        merge_cell_out = merge_cell._resize(inputs_x, target_size)
+        if h % target_h != 0 or w % target_w != 0:
+            pad_h = math.ceil(h / target_h) * target_h - h
+            pad_w = math.ceil(w / target_w) * target_w - w
+            pad_l = pad_w // 2
+            pad_r = pad_w - pad_l
+            pad_t = pad_h // 2
+            pad_b = pad_h - pad_t
+            pad = (pad_l, pad_r, pad_t, pad_b)
+            inputs_x = F.pad(inputs_x, pad, mode='constant', value=0.0)
+        kernel_size = (inputs_x.shape[-2] // target_h,
+                       inputs_x.shape[-1] // target_w)
+        gt_out = F.max_pool2d(
+            inputs_x, kernel_size=kernel_size, stride=kernel_size)
+        print(merge_cell_out.shape, gt_out.shape)
+        assert (merge_cell_out == gt_out).all()
+        assert merge_cell_out.shape[-2:] == target_size