utils.py 2.76 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch
import torch.nn.functional as F


def multiview_img_stack_batch(
        tensor_list: List[torch.Tensor],
        pad_size_divisor: int = 1,
        pad_value: Union[int, float] = 0) -> torch.Tensor:
    """Stack tensors into a batch, right/bottom-padding the last two dims.

    Compared to ``stack_batch`` in ``mmengine.model.utils``,
    ``multiview_img_stack_batch`` further handles multiview images. The
    difference is which dims are excluded from padding: ``stack_batch``
    sets ``padded_sizes[:, 0] = 0`` (only the first dim is kept), while this
    function sets ``padded_sizes[:, :-2] = 0``. As a result, every dim
    except the last two is kept as-is.

    Each tensor is padded up to the element-wise maximum shape of the batch,
    using right (last dim) and bottom (penultimate dim) padding. If
    ``pad_size_divisor > 0``, that target shape is additionally rounded up
    to a multiple of ``pad_size_divisor``. Because leading dims are never
    padded, all tensors must already agree on them, or ``torch.stack`` will
    fail.

    Args:
        tensor_list (List[Tensor]): A non-empty list of tensors with the
            same number of dims. Only the last two dims may differ between
            tensors.
        pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
            to ensure the last two dims are divisible by
            ``pad_size_divisor``. This depends on the model, and many
            models need to be divisible by 32. Values ``<= 0`` disable the
            rounding. Defaults to 1.
        pad_value (int or float): The padding value. Defaults to 0.

    Returns:
        Tensor: The stacked tensor. It has one more dim than the inputs,
        and its first dim has size ``len(tensor_list)``.
    """
    assert isinstance(
        tensor_list,
        list), f'Expected input type to be list, but got {type(tensor_list)}'
    assert tensor_list, '`tensor_list` could not be an empty list'
    assert len({
        tensor.ndim
        for tensor in tensor_list
    }) == 1, ('Expected the dimensions of all tensors must be the same, '
              f'but got {[tensor.ndim for tensor in tensor_list]}')

    dim = tensor_list[0].dim()
    num_img = len(tensor_list)
    all_sizes: torch.Tensor = torch.Tensor(
        [tensor.shape for tensor in tensor_list])
    max_sizes = torch.max(all_sizes, dim=0)[0]
    if pad_size_divisor > 0:
        # Round the target shape up to a multiple of ``pad_size_divisor``.
        # A zero divisor would yield inf/nan, and a negative one would shrink
        # the target, so non-positive values simply skip the rounding.
        max_sizes = torch.ceil(
            max_sizes / pad_size_divisor) * pad_size_divisor
    padded_sizes = max_sizes - all_sizes
    # Only the last two dims are padded. Every leading dim (presumably views
    # and channels for multiview input) is left unpadded, unlike
    # ``stack_batch``, which only excludes dim 0.
    padded_sizes[:, :-2] = 0
    if padded_sizes.sum() == 0:
        return torch.stack(tensor_list)
    # `pad` is the second argument of `F.pad`. A pad of (1, 2, 3, 4) pads the
    # last dim with 1 (left) and 2 (right), and the penultimate dim with
    # 3 (top) and 4 (bottom). `F.pad` orders dims in reverse relative to
    # `padded_sizes`, so `padded_sizes` is reversed. Only the odd indices of
    # `pad` are filled, so padding is applied on the right/bottom side only.
    pad = torch.zeros(num_img, 2 * dim, dtype=torch.int)
    pad[:, 1::2] = padded_sizes[:, range(dim - 1, -1, -1)]
    batch_tensor = [
        F.pad(tensor, tuple(tensor_pad), value=pad_value)
        for tensor, tensor_pad in zip(tensor_list, pad.tolist())
    ]
    return torch.stack(batch_tensor)