test_pytorch_base.py 8.52 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Tests for PytorchBase module."""

import time
import numbers

import torch

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision, Platform, BenchmarkContext, ReturnCode
from superbench.benchmarks.model_benchmarks.model_base import Optimizer, DistributedImpl, DistributedBackend
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class MNISTModel(torch.nn.Module):
    """Small convolutional classifier used as the model under benchmark.

    Two 3x3 convolutions are followed by 2x2 max pooling, dropout and two
    fully connected layers producing 10 log-probabilities per sample.
    """
    def __init__(self):
        """Create the layers of the network."""
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = torch.nn.Dropout(p=0.25)
        self.dropout2 = torch.nn.Dropout(p=0.5)
        # 9216 = 64 channels * 12 * 12, the flattened feature size for 28x28 inputs
        # (28 -> 26 -> 24 after the two convolutions, then 12 after pooling).
        self.fc1 = torch.nn.Linear(in_features=9216, out_features=128)
        self.fc2 = torch.nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x (torch.Tensor): Image tensor.

        Return:
            output (torch.Tensor): Tensor of the log_softmax result.
        """
        functional = torch.nn.functional

        # Convolutional feature extractor.
        features = functional.relu(self.conv1(x))
        features = functional.relu(self.conv2(features))
        features = self.dropout1(functional.max_pool2d(features, 2))

        # Classifier head on the flattened features (batch dimension is kept).
        hidden = torch.flatten(features, 1)
        hidden = self.dropout2(functional.relu(self.fc1(hidden)))
        logits = self.fc2(hidden)
        return functional.log_softmax(logits, dim=1)


class PytorchMNIST(PytorchBase):
    """The MNIST benchmark class.

    A minimal PytorchBase subclass that trains and runs inference with
    MNISTModel on randomly generated 1x28x28 inputs.
    """
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self._supported_precision = [Precision.FLOAT32]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.functional.nll_loss

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        The dataset holds batch_size * (num_warmup + num_steps) random
        samples of shape 1x28x28.

        Return:
            True if dataset is created successfully.
        """
        samples_count = (self._args.batch_size * (self._args.num_warmup + self._args.num_steps))
        self._dataset = TorchRandomDataset([samples_count, 1, 28, 28], self._world_size, dtype=torch.float32)
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Also creates the random target labels (one per sample in a batch)
        that the training step computes its loss against.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            True if the model is created successfully, False otherwise.
        """
        try:
            self._model = MNISTModel()
            self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except BaseException as e:
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

        # Random class labels in [0, 10), matching the 10 outputs of MNISTModel.
        self._target = torch.LongTensor(self._args.batch_size).random_(10)
        if self._gpu_available:
            self._target = self._target.cuda()

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            The step-time list of every training step, in milliseconds.
            Steps with index below num_warmup are not recorded.
        """
        duration = []
        dtype = getattr(torch, precision.value)
        total_steps = len(self._dataloader)
        for idx, sample in enumerate(self._dataloader):
            sample = sample.to(dtype=dtype)
            start = time.time()
            if self._gpu_available:
                sample = sample.cuda()
            self._optimizer.zero_grad()
            output = self._model(sample)
            loss = self._loss_fn(output, self._target)
            loss.backward()
            self._optimizer.step()
            if self._gpu_available:
                # CUDA kernels run asynchronously; wait for them to finish so the
                # measured time covers the real work, as _inference_step does.
                torch.cuda.synchronize()
            end = time.time()
            if idx % 10 == 0:
                logger.info('Train step [{}/{} ({:.0f}%)]'.format(idx, total_steps, 100. * idx / total_steps))
            if idx >= self._args.num_warmup:
                duration.append((end - start) * 1000)

        return duration

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation, in milliseconds.
            Steps with index below num_warmup are not recorded.
        """
        duration = []
        dtype = getattr(torch, precision.value)
        total_steps = len(self._dataloader)
        with torch.no_grad():
            self._model.eval()
            for idx, sample in enumerate(self._dataloader):
                sample = sample.to(dtype=dtype)
                start = time.time()
                if self._gpu_available:
                    sample = sample.cuda()
                self._model(sample)
                if self._gpu_available:
                    # Wait for the asynchronous CUDA kernels before reading the clock.
                    torch.cuda.synchronize()
                end = time.time()
                if idx % 10 == 0:
                    logger.info('Inference step [{}/{} ({:.0f}%)]'.format(idx, total_steps, 100. * idx / total_steps))
                if idx >= self._args.num_warmup:
                    duration.append((end - start) * 1000)
        return duration


def test_pytorch_base():
    """Test PytorchBase class.

    Registers the PytorchMNIST benchmark, runs it on CPU for both train and
    inference, then checks the reported results and several PytorchBase
    helper methods.
    """
    # Register the MNIST benchmark defined in this file.
    BenchmarkRegistry.register_benchmark('pytorch-mnist', PytorchMNIST)

    # Launch benchmark for testing.
    num_steps = 64
    context = BenchmarkContext(
        'pytorch-mnist',
        Platform.CPU,
        parameters='--batch_size=32 --num_warmup=8 --num_steps={} --model_action train inference --no_gpu'.
        format(num_steps)
    )

    # The parameters must be valid; everything below depends on a launched benchmark.
    assert (BenchmarkRegistry.check_parameters(context))

    benchmark = BenchmarkRegistry.launch_benchmark(context)

    assert (benchmark.name == 'pytorch-mnist')
    assert (benchmark.return_code == ReturnCode.SUCCESS)

    # Test results: one run per metric, with one raw value per measured step.
    for metric in [
        'steptime_train_float32', 'steptime_inference_float32', 'throughput_train_float32',
        'throughput_inference_float32'
    ]:
        assert (len(benchmark.raw_data[metric]) == 1)
        assert (len(benchmark.raw_data[metric][0]) == num_steps)
        assert (len(benchmark.result[metric]) == 1)
        assert (isinstance(benchmark.result[metric][0], numbers.Number))

    # Test _cal_params_count().
    assert (benchmark._cal_params_count() == 1199882)

    # Test _judge_gpu_availability().
    assert (benchmark._gpu_available is False)

    # Test _init_distributed_setting().
    assert (benchmark._args.distributed_impl is None)
    assert (benchmark._args.distributed_backend is None)
    assert (benchmark._init_distributed_setting() is True)
    # NOTE(review): DDP and MIRRORED are expected to fail here, presumably because no
    # distributed environment is set up in this test - confirm against PytorchBase.
    benchmark._args.distributed_impl = DistributedImpl.DDP
    benchmark._args.distributed_backend = DistributedBackend.NCCL
    assert (benchmark._init_distributed_setting() is False)
    benchmark._args.distributed_impl = DistributedImpl.MIRRORED
    assert (benchmark._init_distributed_setting() is False)

    # Test _init_dataloader().
    benchmark._args.distributed_impl = None
    assert (benchmark._init_dataloader() is True)
    benchmark._args.distributed_impl = DistributedImpl.DDP
    assert (benchmark._init_dataloader() is False)
    benchmark._args.distributed_impl = DistributedImpl.MIRRORED
    assert (benchmark._init_dataloader() is False)

    # Test _create_optimizer().
    assert (isinstance(benchmark._optimizer, torch.optim.AdamW))
    benchmark._optimizer_type = Optimizer.ADAM
    assert (benchmark._create_optimizer() is True)
    assert (isinstance(benchmark._optimizer, torch.optim.Adam))
    benchmark._optimizer_type = Optimizer.SGD
    assert (benchmark._create_optimizer() is True)
    assert (isinstance(benchmark._optimizer, torch.optim.SGD))
    benchmark._optimizer_type = None
    assert (benchmark._create_optimizer() is False)