pytorch_llama.py 10.5 KB
Newer Older
pdr's avatar
pdr committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch Llama2 model."""

import torch
from transformers import LlamaModel, LlamaConfig
try:
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import Format, DelayedScaling
except ImportError:
    te = None

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class LlamaBenchmarkModel(torch.nn.Module):
    """The Llama model for benchmarking.

    Wraps a Hugging Face ``LlamaModel`` backbone with a linear classification head that is
    applied to the last-layer hidden state of every token.
    """
    def __init__(self, config, num_classes):
        """Constructor.

        Args:
            config (LlamaConfig): Configurations of Llama model.
            num_classes (int): The number of objects for classification.
        """
        super().__init__()
        self._llama = LlamaModel(config)
        self._linear = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
              shape (batch_size, sequence_length).

        Return:
            result (torch.FloatTensor): Last layer hidden-states of every token in the sequence,
              each projected by the Linear layer, shape (batch_size, sequence_length, num_classes).
              Callers select the token position they want to classify.
        """
        # Element 0 of the LlamaModel output is the last layer's hidden states,
        # shape (batch_size, sequence_length, hidden_size).
        hidden_states = self._llama(input)[0]
        return self._linear(hidden_states)


class PytorchLlama(PytorchBase):
    """The Llama benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        self._config = None
        # Transformer Engine recipe; set by _create_model only for FP8 precisions.
        # When it is None, forward passes run without fp8_autocast.
        self._fp8_recipe = None
        self._supported_precision = [
            Precision.FLOAT32,
            Precision.FLOAT16,
            Precision.FP8_HYBRID,
            Precision.FP8_E4M3,
        ]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the Llama-specified arguments.

        Llama2 model reference: https://huggingface.co/docs/transformers/model_doc/llama2
        """
        super().add_parser_arguments()

        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
        self._parser.add_argument('--hidden_size', type=int, default=1280, required=False, help='Hidden size.')
        self._parser.add_argument(
            '--num_hidden_layers', type=int, default=36, required=False, help='The number of hidden layers.'
        )
        self._parser.add_argument(
            '--num_attention_heads', type=int, default=20, required=False, help='The number of attention heads.'
        )
        self._parser.add_argument(
            '--intermediate_size',
            type=int,
            default=11008,
            required=False,
            help='Dimension of the MLP representations.'
        )
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')
        self._parser.add_argument(
            '--num_key_value_heads',
            type=int,
            default=None,
            required=False,
            help='The number of key_value heads that should be used to implement Grouped Query Attention.'
        )

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        Return:
            True if dataset is created successfully.
        """
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
        )
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            True if the model is created successfully, False otherwise.
        """
        # Reset any recipe left over from a previous (FP8) precision run on this instance,
        # otherwise a later non-FP8 run would still execute under te.fp8_autocast.
        self._fp8_recipe = None

        self._config = LlamaConfig(
            hidden_size=self._args.hidden_size,
            num_hidden_layers=self._args.num_hidden_layers,
            num_attention_heads=self._args.num_attention_heads,
            num_key_value_heads=self._args.num_key_value_heads,
            intermediate_size=self._args.intermediate_size,
            max_position_embeddings=4096,    # Maximum sequence length that llama2 supports
            rms_norm_eps=1e-05,    # Llama2 default for epsilon used by the rms normalization layers
        )

        fp8_prefix = 'FP8_'
        enable_fp8 = precision.name.startswith(fp8_prefix)
        if enable_fp8 and te is None:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: Cannot find transformer_engine.'
            )
            return False
        if enable_fp8 and not self._gpu_available:
            logger.error(
                f'Create model with fp8 failed - model: {self._name}, precision: {precision},'
                ' message: FP8 is only supported on GPU.'
            )
            return False

        try:
            self._model = LlamaBenchmarkModel(self._config, self._args.num_classes)
            if enable_fp8:
                # Map e.g. FP8_HYBRID -> Format.HYBRID and FP8_E4M3 -> Format.E4M3 by slicing off the
                # prefix; str.strip('FP8_') would remove any of F/P/8/_ from both ends of the name.
                self._fp8_recipe = DelayedScaling(
                    fp8_format=Format[precision.name[len(fp8_prefix):]],
                    amax_history_len=16,
                    amax_compute_algo='max',
                )
                self._to_te_model(self._model.to(dtype=torch.float16))
            else:
                self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except BaseException as e:
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

        self._target = self._create_target(self._args.num_classes)

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            A tuple of (step_times_ms, info) of every training step.
        """
        duration = []
        # Passed to record_determinism_fingerprint each measured step and summarized by
        # _finalize_periodic_logging when the benchmark finishes.
        periodic = {'loss': [], 'act_mean': [], 'step': []}
        curr_step = 0
        while True:
            for sample in self._dataloader:
                start = self._timer()
                if self._gpu_available:
                    sample = sample.cuda()
                if self._args.exclude_copy_time:
                    # Restart the timer so the host-to-device copy is not measured.
                    start = self._timer()
                self._optimizer.zero_grad()
                if self._fp8_recipe is not None:
                    with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                        output = self._model(sample)
                else:
                    output = self._model(sample)
                # Take the prediction at the last sequence position of every sample in the batch.
                logits = output[range(self._args.batch_size), -1]
                # Use FP32 logits for loss only when determinism is enabled; otherwise
                # keep logits in their native precision to preserve benchmark semantics.
                enable_determinism = getattr(self._args, 'enable_determinism', False)
                logits_for_loss = logits.float() if enable_determinism else logits
                loss = self._loss_fn(logits_for_loss, self._target)
                loss.backward()
                self._optimizer.step()
                end = self._timer()
                curr_step += 1
                if curr_step > self._args.num_warmup:
                    duration.append((end - start) * 1000)
                    self.record_determinism_fingerprint(curr_step, loss, logits, periodic, self._args.check_frequency)
                    self._log_step_time(curr_step, precision, duration)
                if self._is_finished(curr_step, end, self._args.check_frequency):
                    return duration, self._finalize_periodic_logging(periodic)

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16.

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        with torch.no_grad():
            self._model.eval()
            while True:
                for sample in self._dataloader:
                    start = self._timer()
                    if self._gpu_available:
                        sample = sample.cuda()
                    if self._args.exclude_copy_time:
                        # Restart the timer so the host-to-device copy is not measured.
                        start = self._timer()
                    if self._fp8_recipe is not None:
                        with te.fp8_autocast(enabled=True, fp8_recipe=self._fp8_recipe):
                            self._model(sample)
                    else:
                        self._model(sample)
                    end = self._timer()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        duration.append((end - start) * 1000)
                        self._log_step_time(curr_step, precision, duration)
                    if self._is_finished(curr_step, end, self._args.check_frequency):
                        return duration


# Register the Llama2 benchmarks with 7b, 13b and 70b parameters, in that order.
# Each entry pairs a registry name with the model-shape parameters for that size.
for _llama2_benchmark_name, _llama2_benchmark_parameters in (
    (
        'pytorch-llama2-7b',
        '--hidden_size=4096 --num_hidden_layers=32 --num_attention_heads=32 --num_key_value_heads=32 '
        '        --intermediate_size=11008',
    ),
    (
        'pytorch-llama2-13b',
        '--hidden_size=5120 --num_hidden_layers=40 --num_attention_heads=40 --num_key_value_heads=40 '
        '        --intermediate_size=13824',
    ),
    (
        'pytorch-llama2-70b',
        '--hidden_size=8192 --num_hidden_layers=80 --num_attention_heads=64 --num_key_value_heads=8 '
        '        --intermediate_size=28672',
    ),
):
    BenchmarkRegistry.register_benchmark(
        _llama2_benchmark_name, PytorchLlama, parameters=_llama2_benchmark_parameters
    )
# Do not leave the loop variables behind in the module namespace.
del _llama2_benchmark_name, _llama2_benchmark_parameters