transformers.py 4.14 KB
Newer Older
yuguo960516's avatar
bloom  
yuguo960516 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from oneflow import nn

from libai.layers import LayerNorm
from libai.utils import distributed as dist
from projects.BLOOM.modeling.attention import BloomAttention
from projects.BLOOM.modeling.mlp import BloomMLP


class BloomBlock(nn.Module):
    """A single BLOOM transformer layer.

    Data flow: ``input_layernorm -> self_attention -> post_attention_layernorm -> mlp``.
    For each sub-layer this block only *selects* the residual branch and hands it
    to the sub-module (``BloomAttention`` / ``BloomMLP``). The residual addition
    itself is not visible in this file and presumably happens inside those modules
    (confirm in ``attention.py`` / ``mlp.py``).

    Args:
        hidden_size: hidden width; forwarded to both layer norms, the attention
            module and the MLP.
        n_head: number of attention heads; forwarded to the attention module and
            stored as ``self.num_heads``.
        layer_norm_epsilon: ``eps`` used by both layer norms.
        hidden_dropout: forwarded to the attention module and the MLP; also
            stored as ``self.hidden_dropout``.
        attention_dropout: forwarded to the attention module.
        pretraining_tp: forwarded unchanged to the attention module and the MLP.
        slow_but_exact: forwarded unchanged to the attention module and the MLP.
        init_method: forwarded unchanged to the attention module and the MLP.
        output_layer_init_method: forwarded unchanged to the attention module
            and the MLP.
        apply_residual_connection_post_layernorm: if True, the residual of each
            sub-layer is the layer-normalized input; otherwise it is the
            un-normalized input of that sub-layer.
        layer_idx: index of this layer. It selects the placement inputs are moved
            to in ``forward`` and is forwarded to every sub-module.
    """

    def __init__(
        self,
        hidden_size,
        n_head,
        layer_norm_epsilon,
        hidden_dropout,
        attention_dropout,
        pretraining_tp,
        slow_but_exact,
        init_method,
        output_layer_init_method,
        apply_residual_connection_post_layernorm,
        layer_idx=0,
    ):
        super().__init__()

        # Sub-modules are assigned in the same order as before so that module
        # registration (and hence parameter/state_dict ordering) is unchanged.
        self.input_layernorm = LayerNorm(hidden_size, eps=layer_norm_epsilon, layer_idx=layer_idx)
        self.num_heads = n_head
        self.self_attention = BloomAttention(
            hidden_size=hidden_size,
            n_head=n_head,
            hidden_dropout=hidden_dropout,
            attention_dropout=attention_dropout,
            pretraining_tp=pretraining_tp,
            slow_but_exact=slow_but_exact,
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_idx=layer_idx,
        )
        self.post_attention_layernorm = LayerNorm(
            hidden_size, eps=layer_norm_epsilon, layer_idx=layer_idx
        )

        # NOTE(review): passed positionally; this order must match BloomMLP's
        # constructor signature, which is defined outside this file.
        self.mlp = BloomMLP(
            hidden_size,
            pretraining_tp,
            slow_but_exact,
            hidden_dropout,
            init_method,
            output_layer_init_method,
            layer_idx,
        )

        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states,
        alibi,
        attention_mask,
        layer_past=None,
        head_mask=None,
        use_cache: bool = False,
        output_attentions: bool = False,
    ):
        """Run this transformer layer.

        Args:
            hidden_states: layer input, expected shape
                ``(batch_size, seq_length, hidden_size)``.
            alibi: ALiBi bias; moved to this layer's placement and passed to the
                attention module.
            attention_mask: optional mask (may be None); moved to this layer's
                placement when given and passed to the attention module.
            layer_past: forwarded to the attention module.
            head_mask: forwarded to the attention module.
            use_cache: forwarded to the attention module. It also decides whether
                the attention module's first extra output is kept (see Returns).
            output_attentions: forwarded to the attention module.

        Returns:
            A tuple whose first element is the MLP output. If ``use_cache`` is
            True, it is followed by all of ``attn_outputs[1:]``. Otherwise
            ``attn_outputs[1]`` is dropped and only ``attn_outputs[2:]`` follow.
            This assumes the attention module returns
            ``(output, present, [attentions])`` — TODO confirm against
            BloomAttention.
        """
        # Move inputs onto this layer's placement (required for pipeline
        # parallelism). The placement is looked up once and reused.
        placement = dist.get_layer_placement(self.layer_idx)
        hidden_states = hidden_states.to_global(placement=placement)
        alibi = alibi.to_global(placement=placement)
        if attention_mask is not None:
            attention_mask = attention_mask.to_global(placement=placement)

        # Self-attention sub-layer; the residual branch is picked here.
        layernorm_output = self.input_layernorm(hidden_states)
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        attn_outputs = self.self_attention(
            layernorm_output,
            residual,
            layer_past=layer_past,
            attention_mask=attention_mask,
            alibi=alibi,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attention_output = attn_outputs[0]
        extra_outputs = attn_outputs[1:]

        # MLP sub-layer; same residual-selection rule as above.
        layernorm_output = self.post_attention_layernorm(attention_output)
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        output = self.mlp(layernorm_output, residual)

        # Without caching, drop the first extra attention output (presumably
        # the `present` cache entry) so callers only see what they asked for.
        if use_cache:
            return (output,) + extra_outputs
        return (output,) + extra_outputs[1:]  # hidden_states, [attentions]