# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
import math
22
from collections.abc import Iterable
23
from itertools import islice
24
from typing import Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import BloomConfig
29
30
import os
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention import Attention
33
from vllm.compilation.decorators import support_torch_compile
34
from vllm.config import CacheConfig, VllmConfig
35
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
36
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
37
from vllm.model_executor.layers.activation import get_act_fn
38
39
40
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
from vllm.model_executor.layers.quantization import QuantizationConfig
43
from vllm.model_executor.layers.vocab_parallel_embedding import (
44
    ParallelLMHead, VocabParallelEmbedding)
45
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
46
from vllm.model_executor.sampling_metadata import SamplingMetadata
47
from vllm.sequence import IntermediateTensors
48
49
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Woosuk Kwon's avatar
Woosuk Kwon committed
50

51
from .interfaces import SupportsPP, SupportsQuant
52
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
53
54
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
55

Woosuk Kwon's avatar
Woosuk Kwon committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention with a fused QKV projection and ALiBi slopes.

    The attention heads are split evenly across tensor-parallel ranks; each
    rank keeps only the ALiBi slopes of the heads it owns.

    Args:
        config: BLOOM configuration; ``hidden_size`` and ``n_head`` are read.
        cache_config: Forwarded unchanged to ``Attention``.
        quant_config: Forwarded to both linear layers and to ``Attention``.
        prefix: Module name prefix; ``Attention`` receives
            ``f"{prefix}.attn"``.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        # hidden_size must be an exact multiple of the head count.
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        # Every tensor-parallel rank must own the same number of heads.
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # Always define both attributes so that reading them on a
        # non-quantized model yields None instead of raising AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention on ``hidden_states``.

        Args:
            position_ids: Accepted for interface compatibility and discarded;
                positional information comes from the ALiBi slopes.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the ``dense`` projection applied to the attention
            result.
        """
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # NOTE: an FA_PAD path that dropped the last 32 columns of `qkv`
        # (non-quantized case only) used to sit here and is currently
        # disabled.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """BLOOM feed-forward block: ``hidden -> 4 * hidden -> hidden`` with GELU.

    Args:
        config: BLOOM configuration; only ``hidden_size`` is read.
        quant_config: Forwarded to both linear layers.
    """

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU and down-projection to ``x``."""
        # Both parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM transformer layer.

    Data flow: LayerNorm -> self-attention -> residual add -> LayerNorm ->
    MLP -> residual add. ``config.apply_residual_connection_post_layernorm``
    selects whether each residual is taken after or before the preceding
    LayerNorm.

    Args:
        config: BLOOM configuration.
        cache_config: Forwarded to ``BloomAttention``.
        quant_config: Forwarded to ``BloomAttention`` and ``BloomMLP``.
        prefix: Module name prefix; the attention submodule receives
            ``f"{prefix}.self_attention"``.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states``.

        Args:
            position_ids: Passed through to the attention submodule.
            hidden_states: Layer input.

        Returns:
            The layer output (MLP result plus its residual).
        """
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Residual for the attention branch: normalized or raw input.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Residual for the MLP branch: normalized or raw attention output.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


236
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM decoder stack: embeddings, transformer blocks and final LayerNorm.

    Only the layers in ``[start_layer, end_layer)`` are executed by this
    instance; the first pipeline rank embeds the input and the last one
    applies the final LayerNorm.

    Behaviour is additionally controlled by environment variables read in
    ``__init__``: ``LLAMA_NN``, ``GEMM_PAD`` and ``FA_PAD`` (each enabled by
    the value ``'1'``). Only ``LLAMA_NN`` is consulted by the code in this
    class; the other two flags are stored but currently unused here.
    """

    # Substrings identifying the weights that get re-laid-out after loading
    # when LLAMA_NN is enabled on a non-quantized model.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attention.query_key_value.weight",
        "self_attention.dense.weight",
        "mlp.dense_h_to_4h.weight",
        "mlp.dense_4h_to_h.weight",
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon)

        # Transformer blocks
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: BloomBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

        # Always define both attributes so that reading them on a
        # non-quantized model yields None instead of raising AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and apply the embedding LayerNorm."""
        return self.word_embeddings_layernorm(self.word_embeddings(input_ids))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this pipeline stage.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            position_ids: Passed to every block.
            intermediate_tensors: Output of the previous pipeline stage;
                required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            On the last pipeline rank, the hidden states after the final
            LayerNorm; otherwise an ``IntermediateTensors`` carrying
            ``"hidden_states"`` for the next stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(parameter_name, tensor)`` pairs; names must match
                ``self.named_parameters()`` unless the parameter lives on
                another pipeline stage.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name is neither pipeline-missing nor a parameter
                of this module.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_for_llama_nn(params_dict, loaded_params)
        return loaded_params

    def _relayout_for_llama_nn(
        self,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
    ) -> None:
        """Re-lay-out the loaded linear weights through ``trans_w16_gemm``.

        Only parameters whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS`` are touched. The GEMM_PAD / FA_PAD weight
        padding (``pad_weight`` / ``gemm_bank_conf``) that once ran here is
        currently disabled.
        """
        for layername in loaded_params:
            # Literal substring match; a regex built from these names would
            # treat every '.' as a wildcard.
            if not any(key in layername
                       for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue

            weight = params_dict[layername]
            ori_shape = weight.data.shape
            _weight = torch.zeros_like(weight.data)

            # NOTE(review): presumably writes a re-ordered/transposed copy of
            # `weight.data` into `_weight` -- confirm against the custom op.
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(_weight)

            # Expose the same storage with the two leading dims swapped.
            weight.data = weight.data.reshape(ori_shape[1], -1)
371

Woosuk Kwon's avatar
Woosuk Kwon committed
372

373
class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """BLOOM decoder with a language-model head.

    The LM head shares the word-embedding module when
    ``config.tie_word_embeddings`` is set; otherwise a separate
    ``ParallelLMHead`` is created.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the decoder's input embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the decoder and return its result unchanged."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` through the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head.weight``.

        Names that do not already start with ``transformer.`` are given that
        prefix before being handed to ``AutoWeightsLoader``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns (annotated as
            the set of loaded parameter names).
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)


def _add_transformer_prefix(
    weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
    for name, tensor in weights:
        if not name.startswith('transformer.'):
            name = 'transformer.' + name
        yield name, tensor