"""Tests for the one-shot dataloader utilities (``ConcatLoader``) in ``nni.retiarii.oneshot.pytorch``."""
import math
from typing import Union

import pytest
import torch
import pytorch_lightning
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader, Dataset

def _version_tuple(version: str) -> tuple:
    """Return the leading numeric release components of *version* as ints.

    Raw version strings compare lexicographically (``'1.10' < '1.6'`` is True),
    so parse them first: ``'1.6.0rc1'`` -> ``(1, 6, 0)``, ``'2.0.0+cpu'`` -> ``(2, 0, 0)``.
    An unparseable string yields ``()``, which sorts below any real version.
    """
    import re

    match = re.match(r'\d+(?:\.\d+)*', version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.group(0).split('.'))


# Skip the whole module on Lightning releases older than 1.0, whose APIs differ.
pytestmark = pytest.mark.skipif(_version_tuple(pytorch_lightning.__version__) < (1, 0), reason='Incompatible APIs')


class RandomDataset(Dataset):
    """Map-style dataset of ``length`` random vectors, each with ``size`` elements.

    The tensor is drawn once with ``torch.randn`` at construction time, and
    indexing returns the stored rows.
    """

    def __init__(self, size, length):
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.data[index]


class BoringModel(LightningModule):
    """Minimal LightningModule: a single ``Linear(32, 2)`` layer.

    Every step's loss is the sum of the layer's outputs, so no targets are needed.
    In the visible tests it serves as a placeholder model for a ``Trainer``.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def _summed_output(self, batch):
        # Collapse all outputs to one scalar so it can serve as a loss.
        return self(batch).sum()

    def training_step(self, batch, batch_idx):
        loss = self._summed_output(batch)
        self.log('train_loss', loss)
        return dict(loss=loss)

    def validation_step(self, batch, batch_idx):
        self.log('valid_loss', self._summed_output(batch))

    def test_step(self, batch, batch_idx):
        self.log('test_loss', self._summed_output(batch))

    def configure_optimizers(self):
        params = self.layer.parameters()
        return torch.optim.SGD(params, lr=0.1)



def test_concat_loader():
    """ConcatLoader over two flat loaders should yield 'a''s 3 batches, then 'b''s 4, each tagged with its key."""
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    dataloader = ConcatLoader({
        'a': DataLoader(range(10), batch_size=4),  # ceil(10 / 4) = 3 batches
        'b': DataLoader(range(20), batch_size=5),  # 20 / 5 = 4 batches
    })
    assert len(dataloader) == 7
    for step, (data, label) in enumerate(dataloader):
        expected_label, max_batch = ('a', 4) if step < 3 else ('b', 5)
        assert len(data) <= max_batch
        assert label == expected_label


def test_concat_loader_nested():
    """A list of loaders under one key should be yielded as a 2-element list of batches for that key's 3 steps."""
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    nested_a = [DataLoader(range(10), batch_size=4), DataLoader(range(20), batch_size=6)]
    dataloader = ConcatLoader({'a': nested_a, 'b': DataLoader(range(20), batch_size=5)})
    assert len(dataloader) == 7
    for step, (data, label) in enumerate(dataloader):
        if step >= 3:
            assert label == 'b'
            continue
        assert isinstance(data, list) and len(data) == 2
        assert label == 'a'


@pytest.mark.parametrize('replace_sampler_ddp', [False, True])
@pytest.mark.parametrize('is_min_size_mode', [True])
@pytest.mark.parametrize('num_devices', ['auto', 1, 3, 10])
def test_concat_loader_with_ddp(
    replace_sampler_ddp: bool, is_min_size_mode: bool, num_devices: Union[int, str]
):
    """Check ConcatLoader's length before and after a DDP Trainer attaches it.

    Inspired by tests/trainer/test_supporters.py in lightning.

    With ``replace_sampler_ddp`` enabled, each sub-loader is expected to be
    sharded across ``trainer.num_devices`` ranks. With batch_size=1, its batch
    count then becomes ``ceil(n / num_devices)``. Otherwise the length stays unchanged.
    """
    from nni.retiarii.oneshot.pytorch.dataloader import ConcatLoader

    mode = 'min_size' if is_min_size_mode else 'max_size_cycle'
    dim = 3
    n1, n2, n3 = 8, 6, 9
    dataloader = ConcatLoader({
        'a': {
            'a1': DataLoader(RandomDataset(dim, n1), batch_size=1),
            'a2': DataLoader(RandomDataset(dim, n2), batch_size=1),
        },
        'b': DataLoader(RandomDataset(dim, n3), batch_size=1),
    }, mode=mode)

    # 'a' combines two loaders; its expected length follows the combine mode.
    n_a = min(n1, n2) if is_min_size_mode else max(n1, n2)
    expected_length_before_ddp = n3 + n_a
    assert len(dataloader) == expected_length_before_ddp, f'len(dataloader)={len(dataloader)}'

    model = BoringModel()
    trainer = Trainer(
        strategy='ddp',
        accelerator='cpu',
        devices=num_devices,
        replace_sampler_ddp=replace_sampler_ddp,
    )
    trainer._data_connector.attach_data(
        model=model, train_dataloaders=dataloader, val_dataloaders=None, datamodule=None
    )
    if replace_sampler_ddp:
        # Each rank is expected to receive ceil(n / num_devices) samples per sub-loader.
        expected_length_after_ddp = (
            math.ceil(n3 / trainer.num_devices) + math.ceil(n_a / trainer.num_devices)
        )
    else:
        expected_length_after_ddp = expected_length_before_ddp

    trainer.reset_train_dataloader(model=model)
    assert trainer.train_dataloader is not None
    assert trainer.train_dataloader.mode == mode
    assert trainer.num_training_batches == expected_length_after_ddp, (
        f'num_devices={trainer.num_devices}, num_training_batches={trainer.num_training_batches}'
    )