# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch BERT model."""

import torch
from transformers import BertModel, BertConfig
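
# transformer_engine is optional: it is only needed for the FP8 precisions, so
# the ImportError is tolerated here and checked again in _create_model.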
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
    te = None

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class BertBenchmarkModel(torch.nn.Module):
    """The BERT model for benchmarking."""
    def __init__(self, config, num_classes):
        """Constructor.

        Args:
            config (BertConfig): Configurations of BERT model.
            num_classes (int): The number of classes for classification.
        """
        super().__init__()
        self._bert = BertModel(config)
        self._linear = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
              shape (batch_size, sequence_length).

        Return:
            result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
              (classification token) further processed by a Linear layer, shape (batch_size, num_classes).
        """
        outputs = self._bert(input)
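        # outputs[1] is the pooled output: the last-layer [CLS] hidden state
        # passed through BertPooler (Linear + Tanh).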
        result = self._linear(outputs[1])
        return result


class TeBertBenchmarkModel(torch.nn.Module):
    """BERT model using Transformer Engine."""
    def __init__(self, config, num_classes):
        """Constructor.

        Args:
            config (BertConfig): Configurations of BERT model.
            num_classes (int): The number of classes for classification.
        """
        super().__init__()

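        # A plain token embedding stands in for HuggingFace's BertEmbeddings;
        # position and token-type embeddings are omitted for benchmarking.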
        self._embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        # Build the BERT encoder stack from te.TransformerLayer.
        # Input shape: (seq_len, batch_size, hidden_size).
        # Each layer must be a distinct instance: repeating a single instance in
        # the ModuleList would share one set of weights across all layers.
        self._encoder_layers = torch.nn.ModuleList(
            [
                te.TransformerLayer(
                    config.hidden_size,
                    config.intermediate_size,
                    config.num_attention_heads,
                    apply_residual_connection_post_layernorm=True,
                    output_layernorm=True,
                    layer_type='encoder',
                ) for _ in range(config.num_hidden_layers)
            ]
        )
        # BertPooler used in huggingface transformers
        # https://github.com/huggingface/transformers/blob/accad48e/src/transformers/models/bert/modeling_bert.py#L893
        self._pooler = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.hidden_size),
            torch.nn.Tanh(),
        )
        self._linear = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
              shape (batch_size, sequence_length).

        Return:
            out (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
              (classification token) further processed by a Linear layer, shape (batch_size, num_classes).
        """
        out = self._embedding(input.movedim(0, -1))
        for layer in self._encoder_layers:
            out = layer(out, attention_mask=None)
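        # Move back to (batch_size, seq_len, hidden_size) and pool the first
        # ([CLS]) token before the classification head.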
        out = self._linear(self._pooler(out.movedim(0, 1)[:, 0]))
        return out


class PytorchBERT(PytorchBase):
    """The BERT benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self._config = None
        self._fp8_recipe = None
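        # The FP8 precisions additionally require transformer_engine and an
        # FP8-capable GPU; both prerequisites are verified in _create_model.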
        self._supported_precision = [
            Precision.FLOAT32,
            Precision.FLOAT16,
            Precision.FP8_HYBRID,
            Precision.FP8_E4M3,
            Precision.FP8_E5M2,
        ]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the BERT-specified arguments.

        BERT model reference: https://huggingface.co/transformers/model_doc/bert.html
        """
        super().add_parser_arguments()

        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Number of classes.')
        self._parser.add_argument('--hidden_size', type=int, default=1024, required=False, help='Hidden size.')
        self._parser.add_argument(
            '--num_hidden_layers', type=int, default=24, required=False, help='The number of hidden layers.'
        )
        self._parser.add_argument(
            '--num_attention_heads', type=int, default=16, required=False, help='The number of attention heads.'
        )
        self._parser.add_argument(
            '--intermediate_size', type=int, default=4096, required=False, help='Intermediate size.'
        )
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        Return:
            True if dataset is created successfully.
        """
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
        )
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.
        """
        self._config = BertConfig(
            hidden_size=self._args.hidden_size,
            num_hidden_layers=self._args.num_hidden_layers,
            num_attention_heads=self._args.num_attention_heads,
            intermediate_size=self._args.intermediate_size
        )

        enable_fp8 = precision.name.startswith('FP8_')
        if enable_fp8 and te is None:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: Cannot find transformer_engine.'
            )
            return False
        if enable_fp8 and not self._gpu_available:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: FP8 is only supported on GPU.'
            )
            return False

        try:
            if enable_fp8:
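                # DelayedScaling keeps a sliding window of per-tensor amax values
                # to derive FP8 scaling factors; here a 16-step history reduced
                # with 'max' is used.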
                self._fp8_recipe = DelayedScaling(
                    # Slice off the 'FP8_' prefix (str.strip removes a character
                    # set, not a prefix): e.g. FP8_E4M3 -> Format.E4M3.
                    fp8_format=Format[precision.name[len('FP8_'):]],
                    amax_history_len=16,
                    amax_compute_algo='max',
                )
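                # fp8_autocast only affects Transformer Engine modules, so the
                # rest of the model is cast to float16.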
                self._model = TeBertBenchmarkModel(self._config, self._args.num_classes).to(dtype=torch.float16)
            else:
                self._model = BertBenchmarkModel(self._config, self._args.num_classes)
                self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except BaseException as e:
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

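        # Random classification targets, generated once and reused every step.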
        self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
        if self._gpu_available:
            self._target = self._target.cuda()

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            The step-time list of every training step.
        """
        duration = []
        curr_step = 0
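        # check_frequency is forwarded to _is_finished so the stop condition is
        # only evaluated every 100 steps.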
        check_frequency = 100
        while True:
            for idx, sample in enumerate(self._dataloader):
                start = self._timer()
                if self._gpu_available:
                    sample = sample.cuda()
                self._optimizer.zero_grad()
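                # Run the forward pass under FP8 autocast when an FP8 recipe is
                # configured; Transformer Engine layers then execute in FP8.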
                if self._fp8_recipe is not None:
                    with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                        output = self._model(sample)
                else:
                    output = self._model(sample)
                loss = self._loss_fn(output, self._target)
                loss.backward()
                self._optimizer.step()
                end = self._timer()
                curr_step += 1
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        with torch.no_grad():
            self._model.eval()
            while True:
                for idx, sample in enumerate(self._dataloader):
                    start = self._timer()
                    if self._gpu_available:
                        sample = sample.cuda()
                    if self._fp8_recipe is not None:
                        with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                            self._model(sample)
                    else:
                        self._model(sample)
                    end = self._timer()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        # Save the step time of every training/inference step, unit is millisecond.
                        duration.append((end - start) * 1000)
                        self._log_step_time(curr_step, precision, duration)
                    if self._is_finished(curr_step, end):
                        return duration


# Register BERT Large benchmark.
# Reference: https://huggingface.co/transformers/pretrained_models.html
BenchmarkRegistry.register_benchmark(
    'pytorch-bert-large',
    PytorchBERT,
    parameters='--hidden_size=1024 --num_hidden_layers=24 --num_attention_heads=16 --intermediate_size=4096'
)

# Register BERT Base benchmark.
# Reference: https://huggingface.co/transformers/pretrained_models.html
BenchmarkRegistry.register_benchmark(
    'pytorch-bert-base',
    PytorchBERT,
    parameters='--hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12 --intermediate_size=3072'
)
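
# Example usage (a sketch assuming SuperBench's benchmark context API; the
# exact benchmark name and parameters may differ):
#   from superbench.benchmarks import BenchmarkRegistry, Framework, Platform
#   context = BenchmarkRegistry.create_benchmark_context(
#       'bert-large', platform=Platform.CUDA,
#       parameters='--batch_size 1 --precision float16', framework=Framework.PYTORCH
#   )
#   benchmark = BenchmarkRegistry.launch_benchmark(context)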