# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch BERT model."""

import torch
from transformers import BertModel, BertConfig
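
# transformer_engine is an optional dependency: when it is missing, `te` is set to
# None below and FP8 precisions fail fast with a clear error at model creation.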
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
    te = None

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class BertBenchmarkModel(torch.nn.Module):
    """The BERT model for benchmarking."""
    def __init__(self, config, num_classes):
        """Constructor.

        Args:
            config (BertConfig): Configurations of BERT model.
            num_classes (int): The number of objects for classification.
        """
        super().__init__()
        self._bert = BertModel(config)
        self._linear = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
              shape (batch_size, sequence_length).

        Return:
            result (torch.FloatTensor): Last layer hidden-state of the first token of the sequence
              (classification token) further processed by a Linear layer, shape (batch_size, num_classes).
        """
        outputs = self._bert(input)
        result = self._linear(outputs[1])
        return result
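

# A minimal forward-pass sketch for BertBenchmarkModel (illustrative only; the small
# config values are arbitrary and BertConfig's default vocab_size of 30522 is assumed):
#
#   config = BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=2, intermediate_size=256)
#   model = BertBenchmarkModel(config, num_classes=10)
#   tokens = torch.randint(0, config.vocab_size, (4, 32), dtype=torch.long)
#   logits = model(tokens)  # -> torch.Size([4, 10])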


class PytorchBERT(PytorchBase):
    """The BERT benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self._config = None
        self._fp8_recipe = None
        self._supported_precision = [
            Precision.FLOAT32,
            Precision.FLOAT16,
            Precision.FP8_HYBRID,
            Precision.FP8_E4M3,
        ]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the BERT-specified arguments.

        BERT model reference: https://huggingface.co/docs/transformers/model_doc/bert
        """
        super().add_parser_arguments()

        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Number of classes.')
        self._parser.add_argument('--hidden_size', type=int, default=1024, required=False, help='Hidden size.')
        self._parser.add_argument(
            '--num_hidden_layers', type=int, default=24, required=False, help='The number of hidden layers.'
        )
        self._parser.add_argument(
            '--num_attention_heads', type=int, default=16, required=False, help='The number of attention heads.'
        )
        self._parser.add_argument(
            '--intermediate_size', type=int, default=4096, required=False, help='Intermediate size.'
        )
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        Return:
            True if dataset is created successfully.
        """
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
        )
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.
        """
        self._config = BertConfig(
            hidden_size=self._args.hidden_size,
            num_hidden_layers=self._args.num_hidden_layers,
            num_attention_heads=self._args.num_attention_heads,
            intermediate_size=self._args.intermediate_size
        )

        enable_fp8 = precision.name.startswith('FP8_')
        if enable_fp8 and te is None:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: Cannot find transformer_engine.'
            )
            return False
        if enable_fp8 and not self._gpu_available:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: FP8 is only supported on GPU.'
            )
            return False

        try:
            self._model = BertBenchmarkModel(self._config, self._args.num_classes)
            if enable_fp8:
                # str.strip removes a character set, not a prefix, so slice off the
                # literal 'FP8_' prefix to map the precision name to a te Format
                # ('FP8_HYBRID' -> Format.HYBRID, 'FP8_E4M3' -> Format.E4M3).
                self._fp8_recipe = DelayedScaling(
                    fp8_format=Format[precision.name[len('FP8_'):]],
                    amax_history_len=16,
                    amax_compute_algo='max',
                )
                self._to_te_model(self._model.to(dtype=torch.float16))
            else:
                self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except BaseException as e:
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

        self._target = self._create_target(self._args.num_classes)

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            A tuple of (step_times_ms, info), where step_times_ms lists the duration of
              every training step in milliseconds.
        """
        duration = []
        periodic = {'loss': [], 'act_mean': [], 'step': []}
        curr_step = 0
        while True:
            for idx, sample in enumerate(self._dataloader):
                start = self._timer()
                if self._gpu_available:
                    sample = sample.cuda()
                if self._args.exclude_copy_time:
                    start = self._timer()
                self._optimizer.zero_grad()
                if self._fp8_recipe is not None:
                    with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                        output = self._model(sample)
                else:
                    output = self._model(sample)
                logits = output
                # Use FP32 logits for loss only when determinism is enabled; otherwise
                # keep logits in their native precision to preserve benchmark semantics.
                enable_determinism = getattr(self._args, 'enable_determinism', False)
                logits_for_loss = logits.float() if enable_determinism else logits
                loss = self._loss_fn(logits_for_loss, self._target)
                loss.backward()
                self._optimizer.step()
                end = self._timer()
                curr_step += 1
                if curr_step > self._args.num_warmup:
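                    # Save the step time of every training step, unit is millisecond.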
                    duration.append((end - start) * 1000)
                    self.record_determinism_fingerprint(curr_step, loss, logits, periodic, self._args.check_frequency)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, self._args.check_frequency):
                    return duration, self._finalize_periodic_logging(periodic)

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        with torch.no_grad():
            self._model.eval()
            while True:
                for idx, sample in enumerate(self._dataloader):
                    start = self._timer()
                    if self._gpu_available:
                        sample = sample.cuda()
                    if self._args.exclude_copy_time:
                        start = self._timer()
                    if self._fp8_recipe is not None:
                        with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                            self._model(sample)
                    else:
                        self._model(sample)
                    end = self._timer()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        # Save the step time of every inference step, unit is millisecond.
                        duration.append((end - start) * 1000)
                        self._log_step_time(curr_step, precision, duration)
                    if self._is_finished(curr_step, end, self._args.check_frequency):
                        return duration


# Register BERT Large benchmark.
# Reference: https://huggingface.co/transformers/v3.3.1/pretrained_models.html
BenchmarkRegistry.register_benchmark(
    'pytorch-bert-large',
    PytorchBERT,
    parameters='--hidden_size=1024 --num_hidden_layers=24 --num_attention_heads=16 --intermediate_size=4096'
)

# Register BERT Base benchmark.
# Reference: https://huggingface.co/transformers/v3.3.1/pretrained_models.html
BenchmarkRegistry.register_benchmark(
    'pytorch-bert-base',
    PytorchBERT,
    parameters='--hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12 --intermediate_size=3072'
)
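
# A minimal launch sketch (illustrative; it mirrors the example scripts shipped with
# SuperBench and assumes the standard BenchmarkRegistry entry points). Run it from a
# separate script so importing this module stays side-effect free:
#
#   from superbench.benchmarks import BenchmarkRegistry, Framework, Platform
#
#   context = BenchmarkRegistry.create_benchmark_context(
#       'bert-large', platform=Platform.CUDA,
#       parameters='--batch_size 1 --duration 120 --seq_len 512 --precision float32',
#       framework=Framework.PYTORCH
#   )
#   benchmark = BenchmarkRegistry.launch_benchmark(context)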