collate.py 2.75 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
2
3
4
5
6
import collections

import torch
import torch.nn.functional as F
from torch.utils.data.dataloader import default_collate

Kai Chen's avatar
Kai Chen committed
7
from ..utils import DataContainer
Kai Chen's avatar
Kai Chen committed
8
9
10
11
12
13
14
15

# Raise the soft limit on open file descriptors (RLIMIT_NOFILE) to 4096 as a
# workaround for https://github.com/pytorch/pytorch/issues/973
import resource
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
_soft_nofile, _hard_nofile = rlimit
# Never request more than the hard limit: setrlimit() raises ValueError when
# the soft limit would exceed it.
_wanted_nofile = 4096
if _hard_nofile != resource.RLIM_INFINITY:
    _wanted_nofile = min(_wanted_nofile, _hard_nofile)
# Only ever raise the soft limit; a higher (or unlimited) value already in
# effect is left alone instead of being lowered to 4096.
if _soft_nofile != resource.RLIM_INFINITY and _soft_nofile < _wanted_nofile:
    resource.setrlimit(resource.RLIMIT_NOFILE, (_wanted_nofile, _hard_nofile))


def _pad_and_stack(samples):
    """Pad the 3D tensors of one GPU group to a common size and stack them.

    Every sample is padded at the end of its last two dimensions up to the
    largest extent found within ``samples``, using the sample's own
    ``padding_value``, and the padded tensors are then stacked with
    ``default_collate``.

    Args:
        samples (Sequence[DataContainer]): Non-empty group of containers
            wrapping 3D tensors that agree in their first dimension.

    Returns:
        The result of ``default_collate`` on the padded tensors.
    """
    first = samples[0]
    assert isinstance(first.data, torch.Tensor)
    # TODO: handle tensors other than 3d
    assert first.dim() == 3
    # Start from this group's own first sample so that the target size (and
    # the channel check) never depends on samples belonging to another group.
    c, h, w = first.size()
    for sample in samples:
        assert c == sample.size(0)
        h = max(h, sample.size(1))
        w = max(w, sample.size(2))
    padded_samples = [
        F.pad(
            sample.data,
            # F.pad takes the last dimension first: (left, right, top, bottom)
            (0, w - sample.size(2), 0, h - sample.size(1)),
            value=sample.padding_value)
        for sample in samples
    ]
    return default_collate(padded_samples)


def collate(batch, samples_per_gpu=1):
    """Puts each data field into a tensor/DataContainer with outer dimension
    batch size.

    Extend default_collate to add support for :type:`~mmdet.DataContainer`.
    The batch is split into consecutive groups of ``samples_per_gpu`` items
    and there are 3 cases for data containers.
    1. cpu_only = True, e.g., meta data: each group becomes a list of the
       samples' data.
    2. cpu_only = False, stack = True, e.g., images tensors: each group is
       padded to a common size and stacked.
    3. cpu_only = False, stack = False, e.g., gt bboxes: each group becomes
       a list of the samples' data.

    Args:
        batch (Sequence): Samples to collate. Items may be DataContainers,
            sequences or mappings of collatable items (handled
            recursively), or anything ``default_collate`` accepts.
        samples_per_gpu (int): Size of each group; ``len(batch)`` must be
            divisible by it when the items are DataContainers.

    Returns:
        A DataContainer holding one entry per group, a list (for sequence
        items), a dict (for mapping items), or the result of
        ``default_collate``.

    Raises:
        TypeError: If ``batch`` is not a sequence.
    """
    # The ABC aliases in ``collections`` were removed in Python 3.10.
    from collections.abc import Mapping, Sequence

    if not isinstance(batch, Sequence):
        # Report the type: arbitrary objects have no ``dtype`` attribute.
        raise TypeError("{} is not supported.".format(type(batch)))

    first = batch[0]
    if isinstance(first, DataContainer):
        assert len(batch) % samples_per_gpu == 0
        groups = [
            batch[i:i + samples_per_gpu]
            for i in range(0, len(batch), samples_per_gpu)
        ]
        if first.cpu_only:
            stacked = [[sample.data for sample in group] for group in groups]
            return DataContainer(
                stacked, first.stack, first.padding_value, cpu_only=True)
        if first.stack:
            stacked = [_pad_and_stack(group) for group in groups]
        else:
            stacked = [[sample.data for sample in group] for group in groups]
        return DataContainer(stacked, first.stack, first.padding_value)

    # Strings are sequences of strings, so transposing them would recurse
    # forever; they are handed to default_collate instead.
    if isinstance(first, Sequence) and not isinstance(first, (str, bytes)):
        transposed = zip(*batch)
        return [collate(samples, samples_per_gpu) for samples in transposed]
    if isinstance(first, Mapping):
        return {
            key: collate([d[key] for d in batch], samples_per_gpu)
            for key in first
        }
    return default_collate(batch)