# pytorch_gpt2.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Module of the Pytorch GPT2 model."""

import torch
from transformers import GPT2Model, GPT2Config

from superbench.common.utils import logger
from superbench.benchmarks import BenchmarkRegistry, Precision
from superbench.benchmarks.model_benchmarks.model_base import Optimizer
from superbench.benchmarks.model_benchmarks.pytorch_base import PytorchBase
from superbench.benchmarks.model_benchmarks.random_dataset import TorchRandomDataset


class GPT2BenchmarkModel(torch.nn.Module):
    """The GPT2 model for benchmarking.

    A GPT2Model backbone followed by a single linear layer that maps every token's
    hidden state to ``num_classes`` outputs.
    """
    def __init__(self, config, num_classes):
        """Constructor.

        Args:
            config (GPT2Config): Configurations of GPT2 model.
            num_classes (int): The number of objects for classification.
        """
        super().__init__()
        # NOTE: the attribute keeps its historical '_bert' name although it holds a GPT2 backbone;
        # renaming it would change the parameter names exposed through state_dict().
        self._bert = GPT2Model(config)
        self._linear = torch.nn.Linear(config.hidden_size, num_classes)

    def forward(self, input):
        """Forward propagation function.

        Args:
            input (torch.LongTensor): Indices of input sequence tokens in the vocabulary,
              shape (batch_size, sequence_length).

        Return:
            result (torch.Tensor): Last layer hidden-states of every token of the sequence
              further processed by a Linear layer, shape (batch_size, sequence_length, num_classes).
              Callers that need one prediction per sample select a single token position themselves.
        """
        # The first element of the backbone output holds the hidden states of the whole sequence.
        outputs = self._bert(input)
        result = self._linear(outputs[0])
        return result


class PytorchGPT2(PytorchBase):
    """The GPT2 benchmark class."""
    def __init__(self, name, parameters=''):
        """Constructor.

        Args:
            name (str): benchmark name.
            parameters (str): benchmark parameters.
        """
        super().__init__(name, parameters)
        # The GPT2Config is built later in _create_model() from the parsed arguments.
        self._config = None
        self._supported_precision = [Precision.FLOAT32, Precision.FLOAT16]
        self._optimizer_type = Optimizer.ADAMW
        self._loss_fn = torch.nn.CrossEntropyLoss()

    def add_parser_arguments(self):
        """Add the GPT2-specified arguments.

        GPT2 model reference: https://huggingface.co/transformers/model_doc/gpt2.html
        """
        super().add_parser_arguments()

        self._parser.add_argument('--num_classes', type=int, default=100, required=False, help='Num of class.')
        self._parser.add_argument('--hidden_size', type=int, default=1280, required=False, help='Hidden size.')
        self._parser.add_argument(
            '--num_hidden_layers', type=int, default=36, required=False, help='The number of hidden layers.'
        )
        self._parser.add_argument(
            '--num_attention_heads', type=int, default=20, required=False, help='The number of attention heads.'
        )
        self._parser.add_argument('--seq_len', type=int, default=512, required=False, help='Sequence length.')

    def _generate_dataset(self):
        """Generate dataset for benchmarking according to shape info.

        The dataset is a TorchRandomDataset of long integers built from the
        [sample_count, seq_len] shape and the world size.

        Return:
            True if dataset is created successfully.
        """
        self._dataset = TorchRandomDataset(
            [self._args.sample_count, self._args.seq_len], self._world_size, dtype=torch.long
        )
        if len(self._dataset) == 0:
            logger.error('Generate random dataset failed - model: {}'.format(self._name))
            return False

        return True

    def _create_model(self, precision):
        """Construct the model for benchmarking.

        Also creates the random classification targets (one class index per sample of a batch)
        that the training step feeds to the loss function.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.

        Return:
            True if model is created successfully.
        """
        self._config = GPT2Config(
            n_embd=self._args.hidden_size, n_layer=self._args.num_hidden_layers, n_head=self._args.num_attention_heads
        )

        try:
            self._model = GPT2BenchmarkModel(self._config, self._args.num_classes)
            # precision.value names the torch dtype attribute to convert the model to.
            self._model = self._model.to(dtype=getattr(torch, precision.value))
            if self._gpu_available:
                self._model = self._model.cuda()
        except Exception as e:
            # Catch Exception rather than BaseException so that KeyboardInterrupt and SystemExit
            # are not reported as a model creation failure and keep propagating.
            logger.error(
                'Create model with specified precision failed - model: {}, precision: {}, message: {}.'.format(
                    self._name, precision, str(e)
                )
            )
            return False

        self._target = torch.LongTensor(self._args.batch_size).random_(self._args.num_classes)
        if self._gpu_available:
            self._target = self._target.cuda()

        return True

    def _train_step(self, precision):
        """Define the training process.

        Args:
            precision (Precision): precision of model and input data, such as float32, float16.
              Not read here; the model has already been converted in _create_model().

        Return:
            The step-time list of every training step.
        """
        duration = []
        curr_step = 0
        check_frequency = 100
        # Iterate over the dataloader repeatedly until _is_finished() reports completion.
        # NOTE(review): if the dataloader yields no batch at all, this loop never terminates.
        while True:
            for sample in self._dataloader:
                start = self._timer()
                if self._gpu_available:
                    sample = sample.cuda()
                self._optimizer.zero_grad()
                output = self._model(sample)
                # Use the output at the last sequence position of each sample as its class scores.
                loss = self._loss_fn(output[range(self._args.batch_size), -1], self._target)
                loss.backward()
                self._optimizer.step()
                end = self._timer()
                curr_step += 1
                if curr_step > self._args.num_warmup:
                    # Save the step time of every training/inference step, unit is millisecond.
                    duration.append((end - start) * 1000)
                if self._is_finished(curr_step, end, check_frequency):
                    return duration

    def _inference_step(self, precision):
        """Define the inference process.

        Args:
            precision (Precision): precision of model and input data,
              such as float32, float16. Not read here; the model has already been
              converted in _create_model().

        Return:
            The latency list of every inference operation.
        """
        duration = []
        curr_step = 0
        with torch.no_grad():
            self._model.eval()
            # Iterate over the dataloader repeatedly until _is_finished() reports completion.
            # NOTE(review): if the dataloader yields no batch at all, this loop never terminates.
            while True:
                for sample in self._dataloader:
                    start = self._timer()
                    if self._gpu_available:
                        sample = sample.cuda()
                    self._model(sample)
                    end = self._timer()
                    curr_step += 1
                    if curr_step > self._args.num_warmup:
                        # Save the step time of every training/inference step, unit is millisecond.
                        duration.append((end - start) * 1000)
                    if self._is_finished(curr_step, end):
                        return duration


# Register the standard GPT2 benchmark variants, from the smallest to the largest model.
# Reference: https://huggingface.co/transformers/pretrained_models.html
for _benchmark_name, _benchmark_parameters in (
    # GPT2 with 117M parameters.
    ('pytorch-gpt2-small', '--hidden_size=768 --num_hidden_layers=12 --num_attention_heads=12'),
    # GPT2 with 345M parameters.
    ('pytorch-gpt2-medium', '--hidden_size=1024 --num_hidden_layers=24 --num_attention_heads=16'),
    # GPT2 with 774M parameters.
    ('pytorch-gpt2-large', '--hidden_size=1280 --num_hidden_layers=36 --num_attention_heads=20'),
    # GPT2 with 1558M parameters.
    ('pytorch-gpt2-xl', '--hidden_size=1600 --num_hidden_layers=48 --num_attention_heads=25'),
):
    BenchmarkRegistry.register_benchmark(_benchmark_name, PytorchGPT2, parameters=_benchmark_parameters)