# pytorch_lstm.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch LSTM model."""

import torch

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class LSTMBenchmarkModel(torch.nn.Module):
    """The LSTM model for benchmarking.

    A stacked (optionally bidirectional) LSTM followed by a linear layer that maps the
    features of the last time step to ``num_classes`` outputs.
    """
    def __init__(self, input_size, hidden_size, num_layers, bidirectional, num_classes):
        """Constructor.

        Args:
            input_size (int): The number of expected features in the input.
            hidden_size (int): The number of features in the hidden state.
            num_layers (int): The number of recurrent layers.
            bidirectional (bool): If True, becomes a bidirectional LSTM.
            num_classes (int): The number of objects for classification.
        """
        super().__init__()
        self._lstm = torch.nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
        # A bidirectional LSTM emits the features of both directions side by side,
        # so the linear layer has to accept twice the hidden size.
        num_directions = 2 if bidirectional else 1
        self._linear = torch.nn.Linear(hidden_size * num_directions, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.FloatTensor): Tensor containing the features of the input sequence,
              shape (batch_size, sequence_length, input_size).

        Return:
            result (torch.FloatTensor): The output features from the last layer of the LSTM
              further processed by a Linear layer, shape (batch_size, num_classes).
        """
        # Per the PyTorch docs this resets the parameter data pointers so faster code paths
        # can be used; it is documented as a no-op outside of GPU + cuDNN.
        self._lstm.flatten_parameters()
        # The LSTM returns (output, (h_n, c_n)); only the per-step output is needed here.
        outputs = self._lstm(input)
        # batch_first=True, so axis 1 is time: keep the features of the last time step only.
        result = self._linear(outputs[0][:, -1, :])
        return result


class PytorchLSTM(PytorchBase):
    """The LSTM benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self._config = None
        self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
        self._optimizer_type = Optimizer.SGD
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the LSTM-specified arguments.

        LSTM model reference: https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        """
        super().add_parser_arguments()

        self._parser.add_argument(
            '--num_classes', type=int, default=100, required=False, help='The number of objects for classification.'
        )
        self._parser.add_argument(
            '--input_size', type=int, default=256, required=False, help='The number of expected features in the input.'
        )
        self._parser.add_argument(
            '--hidden_size', type=int, default=1024, required=False, help='The number of features in the hidden state.'
        )
        self._parser.add_argument(
            '--num_layers', type=int, default=8, required=False, help='The number of recurrent layers.'
        )

        self._parser.add_argument('--bidirectional', action='store_true', default=False, help='Bidirectional LSTM.')
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        Return:
            True if dataset is created successfully, False if the generated dataset is empty.
        """
        # One random sample per entry, each of shape (seq_len, input_size).
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len, self._args.input_size], self._world_size, dtype=torch.float32
        )
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            True if the model and its target are created successfully, False otherwise.
        """
        try:
            self._model = LSTMBenchmarkModel(
                self._args.input_size, self._args.hidden_size, self._args.num_layers, self._args.bidirectional,
                self._args.num_classes
            )
            self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except Exception as e:
            # Catch Exception rather than BaseException so that KeyboardInterrupt and
            # SystemExit still propagate instead of being reported as a model failure.
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

        # Target used by the loss in _train_step; built by a helper not defined in this class.
        self._target = self._create_target(self._args.num_classes)

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            A tuple of (step_times_ms, info) of every training step.
        """
        duration = []
        periodic = {'loss': [], 'act_mean': [], 'step': []}
        curr_step = 0
        # Both lookups are invariant for the whole run, so resolve them once up front.
        dtype = getattr(torch, precision.value)
        enable_determinism = getattr(self._args, 'enable_determinism', False)
        # Keep cycling over the dataloader until _is_finished() reports completion.
        while True:
            for sample in self._dataloader:
                sample = sample.to(dtype=dtype)
                start = self._timer()
                if self._gpu_available:
                    sample = sample.cuda()
                if self._args.exclude_copy_time:
                    # Restart the clock so the host-to-device copy is not part of the step time.
                    start = self._timer()
                self._optimizer.zero_grad()
                output = self._model(sample)
                # With determinism enabled, the loss is computed on float32 logits
                # (presumably for numerically stable, reproducible loss values - confirm).
                logits_for_loss = output.float() if enable_determinism else output
                loss = self._loss_fn(logits_for_loss, self._target)
                loss.backward()
                self._optimizer.step()
                end = self._timer()
                curr_step += 1
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training step after warmup, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self.record_determinism_fingerprint(curr_step, loss, output, periodic, self._args.check_frequency)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, self._args.check_frequency):
                    return duration, self._finalize_periodic_logging(periodic)

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        # Invariant for the whole run, so resolve the torch dtype once.
        dtype = getattr(torch, precision.value)
        with torch.no_grad():
            self._model.eval()
            # Keep cycling over the dataloader until _is_finished() reports completion.
            while True:
                for sample in self._dataloader:
                    sample = sample.to(dtype=dtype)
                    start = self._timer()
                    if self._gpu_available:
                        sample = sample.cuda()
                    if self._args.exclude_copy_time:
                        # Restart the clock so the host-to-device copy is not part of the latency.
                        start = self._timer()
                    self._model(sample)
                    end = self._timer()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        # Save the step time of every training/inference step, unit is millisecond.
                        duration.append((end - start) * 1000)
                        self._log_step_time(curr_step, precision, duration)
                    if self._is_finished(curr_step, end, self._args.check_frequency):
                        return duration


# Register PytorchLSTM with the benchmark registry as 'pytorch-lstm',
# together with its predefined parameter string.
BenchmarkRegistry.register_benchmark(
    'pytorch-lstm',
    PytorchLSTM,
    parameters='--input_size=256 --hidden_size=1024 --num_layers=8',
)