bloom_model.py 14.9 KB
Newer Older
yuguo960516's avatar
bloom  
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import oneflow as flow
from oneflow import nn

from libai.config import configurable
from libai.inference.generator.generation_utils import Generator
from libai.layers import Embedding, LayerNorm, LMLogits
from libai.models.utils import init_method_normal, scaled_init_method_normal
from libai.utils import distributed as dist
from projects.BLOOM.modeling.mask import _expand_mask, _make_causal_mask, build_alibi_tensor
from projects.BLOOM.modeling.transformers import BloomBlock


class BloomModel(nn.Module):
    @configurable
    def __init__(
        self,
        vocab_size,
        hidden_size,
        hidden_layers,
        n_head,
        padding_idx,
        pretraining_tp=1,
        slow_but_exact=False,
        initializer_range=0.02,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0,
        attention_dropout=0,
        amp_enabled=False,
        layer_norm_epsilon=1e-12,
        cfg=None,
    ):
        super().__init__()
        self.cfg = cfg
        self.embed_dim = hidden_size
        self.num_heads = n_head
        self.hidden_layers = hidden_layers
        init_method = init_method_normal(initializer_range)
        scaled_init_method = scaled_init_method_normal(initializer_range, hidden_layers)

        self.word_embeddings = Embedding(
            vocab_size,
            self.embed_dim,
            padding_idx=padding_idx,
            init_method=init_method,
            amp_enabled=amp_enabled,
            layer_idx=0,
        )

        self.word_embeddings_layernorm = LayerNorm(
            self.embed_dim, eps=layer_norm_epsilon, layer_idx=0
        )

        self.h = flow.nn.ModuleList(
            [
                BloomBlock(
                    hidden_size=hidden_size,
                    n_head=n_head,
                    layer_norm_epsilon=layer_norm_epsilon,
                    hidden_dropout=hidden_dropout,
                    attention_dropout=attention_dropout,
                    pretraining_tp=pretraining_tp,
                    slow_but_exact=slow_but_exact,
                    init_method=init_method,
                    output_layer_init_method=scaled_init_method,
                    apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,  # noqa
                    layer_idx=i,
                )
                for i in range(hidden_layers)
            ]
        )

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=layer_norm_epsilon, layer_idx=hidden_layers - 1)

    @classmethod
    def from_config(cls, cfg):
        return {
            "vocab_size": cfg.vocab_size,
            "hidden_size": cfg.hidden_size,
            "hidden_layers": cfg.hidden_layers,
            "n_head": cfg.n_head,
            "padding_idx": cfg.padding_idx,
            "pretraining_tp": cfg.pretraining_tp,
            "slow_but_exact": cfg.slow_but_exact,
            "apply_residual_connection_post_layernorm": cfg.apply_residual_connection_post_layernorm,  # noqa
            "hidden_dropout": cfg.hidden_dropout,
            "attention_dropout": cfg.attention_dropout,
            "amp_enabled": cfg.amp_enabled,
            "layer_norm_epsilon": cfg.layer_norm_epsilon,
            "cfg": cfg,
        }

    def _prepare_attn_mask(
        self,
        attention_mask,
        input_shape,
        past_key_values_length,
    ):
        combined_attention_mask = None
        _, src_length = input_shape

        if src_length > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape, past_key_values_length=past_key_values_length
            )

        # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
        expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
        combined_attention_mask = (
            expanded_attn_mask
            if combined_attention_mask is None
            else expanded_attn_mask | combined_attention_mask
        )

        return combined_attention_mask

    def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
        """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
        if head_mask.dim() == 1:
            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
        elif head_mask.dim() == 2:
            head_mask = (
                head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            )  # We can specify head_mask for each layer
        assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}"
        head_mask = head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
        return head_mask

    def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
        """
        Prepare the head mask if needed.

        Args:
            head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`,
            *optional*):
                The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for
                discard).
            num_hidden_layers (`int`):
                The number of hidden layers in the model.
            is_attention_chunked: (`bool`, *optional*, defaults to `False`):
                Whether or not the attentions scores are computed by chunks or not.

        Returns:
            `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x
            seq_length]` or list with `[None]` for each layer.
        """
        if head_mask is not None:
            head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
            if is_attention_chunked is True:
                head_mask = head_mask.unsqueeze(-1)
        else:
            head_mask = [None] * num_hidden_layers

        return head_mask

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
    ):
        input_ids = (
            input_ids.to_global(placement=dist.get_layer_placement(0))
            if input_ids is not None
            else input_ids
        )
        attention_mask = (
            attention_mask.to_global(placement=dist.get_layer_placement(0))
            if attention_mask is not None
            else attention_mask
        )
        head_mask = (
            head_mask.to_global(placement=dist.get_layer_placement(0))
            if head_mask is not None
            else head_mask
        )
        inputs_embeds = (
            inputs_embeds.to_global(placement=dist.get_layer_placement(0))
            if inputs_embeds is not None
            else inputs_embeds
        )

        if input_ids is not None:
            batch_size, seq_length = input_ids.shape

        if past_key_values is None:
            past_key_values = tuple([None] * len(self.h))

        head_mask = self.get_head_mask(head_mask, self.hidden_layers)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        hidden_states = self.word_embeddings_layernorm(inputs_embeds)
        presents = () if use_cache else None

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values[0] is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if attention_mask is None:
            attention_mask = flow.ones(
                (batch_size, seq_length_with_past),
                sbp=dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast]),
                placement=dist.get_layer_placement(0),
            )

        alibi = build_alibi_tensor(attention_mask, self.num_heads, hidden_states.dtype)

        causal_mask = self._prepare_attn_mask(
            attention_mask,
            input_shape=(batch_size, seq_length),
            past_key_values_length=past_key_values_length,
        )

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            outputs = block(
                hidden_states,
                layer_past=layer_past,
                attention_mask=causal_mask,
                head_mask=head_mask[i],
                use_cache=use_cache,
                output_attentions=output_attentions,
                alibi=alibi,
            )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

        hidden_states = self.ln_f(hidden_states)

        return {"last_hidden_state": hidden_states, "past_key_values": presents}


class BloomForCausalLM(nn.Module, Generator):
    @configurable
    def __init__(
        self,
        vocab_size,
        hidden_size,
        hidden_layers,
        n_head,
        padding_idx,
        pretraining_tp=1,
        slow_but_exact=False,
        initializer_range=0.02,
        apply_residual_connection_post_layernorm=False,
        hidden_dropout=0,
        attention_dropout=0,
        amp_enabled=False,
        layer_norm_epsilon=1e-12,
        cfg=None,
    ):
        super().__init__()
        self.cfg = cfg
        self.transformer = BloomModel(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            hidden_layers=hidden_layers,
            n_head=n_head,
            padding_idx=padding_idx,
            pretraining_tp=pretraining_tp,
            slow_but_exact=slow_but_exact,
            initializer_range=initializer_range,
            apply_residual_connection_post_layernorm=apply_residual_connection_post_layernorm,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            amp_enabled=amp_enabled,
            layer_norm_epsilon=layer_norm_epsilon,
            cfg=cfg,
        )
        self.lm_head = LMLogits(vocab_size, bias=False)

    @classmethod
    def from_config(cls, cfg):
        return {
            "vocab_size": cfg.vocab_size,
            "hidden_size": cfg.hidden_size,
            "hidden_layers": cfg.hidden_layers,
            "n_head": cfg.n_head,
            "padding_idx": cfg.padding_idx,
            "pretraining_tp": cfg.pretraining_tp,
            "slow_but_exact": cfg.slow_but_exact,
            "apply_residual_connection_post_layernorm": cfg.apply_residual_connection_post_layernorm,  # noqa
            "hidden_dropout": cfg.hidden_dropout,
            "attention_dropout": cfg.attention_dropout,
            "amp_enabled": cfg.amp_enabled,
            "layer_norm_epsilon": cfg.layer_norm_epsilon,
            "cfg": cfg,
        }

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        head_mask=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
    ):

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        hidden_states = transformer_outputs["last_hidden_state"]

        lm_logits = self.lm_head(hidden_states, self.transformer.word_embeddings.weight)

        return {
            "logits": lm_logits,
            "past_key_values": transformer_outputs["past_key_values"],
            "hidden_states": transformer_outputs["last_hidden_state"],
            # "attentions": transformer_outputs.attentions,
        }

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ) -> dict:
        # only last token for input_ids if past is not None
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)

            if past_key_values[0][0].shape[0] == input_ids.shape[0]:
                past_key_values = self._convert_to_bloom_cache(past_key_values)

        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _reorder_cache(self, past, beam_idx):
        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))

        device_to_beam_idx = {
            past_state.device: beam_idx.to(past_state.device)
            for layer_past in past
            for past_state in layer_past
        }
        reordered_past = tuple(
            (
                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
            )
            for layer_past in standardized_past
        )
        return self._convert_to_bloom_cache(reordered_past)

    def _convert_to_standard_cache(
        past_key_value,
        batch_size,
    ):
        """
        Standardizes the format of the cache so as to match most implementations,
        i.e. to tuple(tuple([batch_size, num_heads, ...]))
        """
        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
        num_heads = batch_size_times_num_heads // batch_size
        return tuple(
            (
                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )

    def _convert_to_bloom_cache(past_key_value):
        """
        Converts the cache to the format expected by Bloom,
        i.e. to tuple(tuple([batch_size * num_heads, ...]))
        """
        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
        batch_size_times_num_heads = batch_size * num_heads
        return tuple(
            (
                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
            )
            for layer_past in past_key_value
        )