bloom.py 16.2 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
5
# Copyright 2023 The vLLM team.
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only BLOOM model compatible with HuggingFace weights."""
Woosuk Kwon's avatar
Woosuk Kwon committed
20
import math
21
from typing import Iterable, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
24
25

import torch
from torch import nn
from transformers import BloomConfig
26
27
import os
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
28

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
33
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import get_act_fn
35
36
37
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
45
46
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Woosuk Kwon's avatar
Woosuk Kwon committed
47

48
from .interfaces import SupportsPP, SupportsQuant, SupportsV0Only
49
from .utils import (is_pp_missing_parameter,
50
51
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
52

Woosuk Kwon's avatar
Woosuk Kwon committed
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention with ALiBi slopes, sharded over TP ranks.

    Each tensor-parallel rank owns ``total_num_heads // tp_world_size``
    consecutive heads and the matching slice of the ALiBi slopes.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        # Pass layer-name prefixes so quantization configs that select or
        # skip layers by name can resolve these projections.
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # This rank owns heads [tp_rank * num_heads, (tp_rank + 1) * num_heads),
        # so it keeps only those heads' ALiBi slopes.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = head_start + self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Always define both attributes (previously ``quant_config`` was
        # only set when quantized, which risked AttributeError).
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        ``position_ids`` is accepted for interface parity but ignored.
        Positional information comes from the ALiBi slopes given to
        ``self.attn``.
        """
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # NOTE: an experimental FA_PAD path that trimmed the last 32 qkv
        # columns at this point is currently disabled.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """BLOOM feed-forward network: expand to 4x width, GELU, project back."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        width = config.hidden_size
        expanded = 4 * width

        # Up-projection, activation, then down-projection (registered in
        # this order).
        self.dense_h_to_4h = ColumnParallelLinear(
            width, expanded, quant_config=quant_config)
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            expanded, width, quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``x``.

        The second value returned by each parallel linear layer is discarded.
        """
        up, _ = self.dense_h_to_4h(x)
        down, _ = self.dense_4h_to_h(self.gelu_impl(up))
        return down


class BloomBlock(nn.Module):
    """One BLOOM transformer layer: LayerNorm -> attention -> LayerNorm -> MLP.

    If ``config.apply_residual_connection_post_layernorm`` is true, each
    residual branch carries the LayerNorm output instead of that
    sub-layer's input.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon

        self.input_layernorm = nn.LayerNorm(width, eps=eps)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states`` and return the new hidden states."""
        residual_from_norm = self.apply_residual_connection_post_layernorm

        # Attention sub-layer.
        normed = self.input_layernorm(hidden_states)
        attn_out = self.self_attention(
            position_ids=position_ids,
            hidden_states=normed,
        )
        attn_out = attn_out + (normed if residual_from_norm else hidden_states)

        # MLP sub-layer.
        normed = self.post_attention_layernorm(attn_out)
        return self.mlp(normed) + (normed if residual_from_norm else attn_out)


233
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM decoder stack.

    The stack consists of the token embeddings plus their LayerNorm, the
    transformer blocks, and the final LayerNorm. Under pipeline parallelism
    it either produces final hidden states or hands intermediate tensors to
    the next stage.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        eps = hf_config.layer_norm_epsilon

        self.embed_dim = hf_config.hidden_size

        # Token embeddings, followed by BLOOM's extra embedding LayerNorm.
        self.word_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(self.embed_dim, eps=eps)

        # The builder's argument must stay named ``prefix``: make_layers
        # supplies each layer's name through that keyword.
        def build_block(prefix: str) -> BloomBlock:
            return BloomBlock(hf_config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers, build_block, prefix=f"{prefix}.h")

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and apply the embedding LayerNorm."""
        embedded = self.word_embeddings(input_ids)
        return self.word_embeddings_layernorm(embedded)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage's blocks.

        The first stage starts from ``inputs_embeds`` if given, otherwise
        from the embedded ``input_ids``. Later stages start from
        ``intermediate_tensors["hidden_states"]``. Non-final stages return
        ``IntermediateTensors``; the final stage returns hidden states after
        ``ln_f``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        for block in self.h[self.start_layer:self.end_layer]:
            hidden_states = block(position_ids, hidden_states)

        if pp_group.is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


292
class BloomForCausalLM(nn.Module, SupportsPP, SupportsV0Only, SupportsQuant):
    """BLOOM decoder stack with a language-modeling head."""

    # Suffixes of the dense weights that are re-laid-out with
    # ``ops.trans_w16_gemm`` when LLAMA_NN=1 and the model is unquantized.
    _LLAMA_NN_WEIGHT_SUFFIXES = (
        "self_attention.query_key_value.weight",
        "self_attention.dense.weight",
        "mlp.dense_h_to_4h.weight",
        "mlp.dense_4h_to_h.weight",
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # Name of the active quantization method, or None when unquantized.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Fork-specific switches, read once from the environment.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (LayerNorm-ed) input embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; logits are computed by ``compute_logits``."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project ``hidden_states`` to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HuggingFace BLOOM checkpoint tensors into this model.

        Checkpoint names without a ``transformer.`` prefix get one.
        ``lm_head.weight`` is handled by tying:

        * When the head is tied to the input embeddings, it is skipped,
          because it shares their storage.
        * When untied, it is loaded into the separate head. Previously it
          was always skipped, which left an untied head unloaded.

        Fused QKV tensors are reordered before loading (see inline note).

        Returns:
            The names of the parameters that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        tie_word_embeddings = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                if tie_word_embeddings:
                    continue
            elif not name.startswith("transformer."):
                name = "transformer." + name
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        return loaded_params

    def _apply_llama_nn_layout(self, params_dict, param_names) -> None:
        """Re-lay out the dense GEMM weights in place for LLAMA_NN mode.

        Only loaded parameters whose names end with one of
        ``_LLAMA_NN_WEIGHT_SUFFIXES`` are touched. Earlier experimental
        FA_PAD / GEMM_PAD padding of these weights is currently disabled.
        """
        for name in param_names:
            # Suffix match replaces an unescaped regex search, in which
            # '.' matched any character.
            if not name.endswith(self._LLAMA_NN_WEIGHT_SUFFIXES):
                continue
            param = params_dict[name]
            rows, cols = param.data.shape[0], param.data.shape[1]
            relaid = torch.zeros_like(param.data)
            # NOTE(review): trans_w16_gemm appears to write the transpose of
            # the weight into ``relaid`` (hence the (cols, -1) view below);
            # confirm against the custom op.
            ops.trans_w16_gemm(relaid, param.data, rows, cols)
            param.data.copy_(relaid)
            param.data = param.data.reshape(cols, -1)
417