# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
Woosuk Kwon's avatar
Woosuk Kwon committed
19
import math
20
from typing import Iterable, List, Optional, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
23
24

import torch
from torch import nn
from transformers import BloomConfig
25
26
import os
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
27

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.compilation.decorators import support_torch_compile
30
from vllm.config import CacheConfig
31
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
32
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
33
from vllm.model_executor.layers.activation import get_act_fn
34
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
39
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
45
46
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Woosuk Kwon's avatar
Woosuk Kwon committed
47

48
49
50
51
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

Woosuk Kwon's avatar
Woosuk Kwon committed
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the ALiBi slope of every attention head as a 1-D tensor.

    The first ``p`` slopes, where ``p`` is the largest power of two not
    exceeding ``total_num_heads``, are consecutive integer powers of a base
    derived from ``p``. When ``total_num_heads`` is not itself a power of
    two, the remaining heads take the odd powers of the base derived from
    ``2 * p``.

    Args:
        total_num_heads: Number of attention heads across all ranks.

    Returns:
        A float32 tensor holding ``total_num_heads`` slopes.
    """

    def _slope_base(pow2_heads: int) -> torch.Tensor:
        # Scalar base whose integer powers form the slope sequence.
        return torch.tensor(
            2**(-(2**-(math.log2(pow2_heads) - 3))),
            dtype=torch.float32,
        )

    floor_pow2 = 2**math.floor(math.log2(total_num_heads))
    main_exponents = torch.arange(1, 1 + floor_pow2, dtype=torch.int32)
    main_slopes = torch.pow(_slope_base(floor_pow2), main_exponents)

    # Exact power of two: the geometric sequence already covers every head.
    if floor_pow2 == total_num_heads:
        return main_slopes

    leftover_heads = min(floor_pow2, total_num_heads - floor_pow2)
    odd_exponents = torch.arange(start=1,
                                 end=1 + 2 * leftover_heads,
                                 step=2,
                                 dtype=torch.int32)
    tail_slopes = torch.pow(_slope_base(2 * floor_pow2), odd_exponents)
    return torch.cat([main_slopes, tail_slopes], dim=0)


class BloomAttention(nn.Module):
    """Self-attention layer of a BLOOM block, biased with ALiBi slopes.

    The attention heads are divided evenly across tensor-parallel ranks;
    each rank keeps only the ALiBi slopes of the heads it owns.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the fused QKV projection, output projection and attention op.

        Args:
            config: Model configuration; ``hidden_size`` and ``n_head`` are
                read here.
            cache_config: Forwarded unchanged to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``;
                ``None`` means no quantization config was supplied.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        # hidden_size must split evenly into heads.
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        # Every tensor-parallel rank must receive the same number of heads.
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Create the alibi slopes and slice them.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config)

        # Define both attributes unconditionally so that reading
        # `quant_config` on an unquantized instance yields None instead of
        # raising AttributeError.
        self.quant_config = quant_config
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and project the result.

        Args:
            position_ids: Accepted for interface compatibility and discarded.
            hidden_states: Input to the fused QKV projection.
            kv_cache: Passed through to the attention op.
            attn_metadata: Passed through to the attention op.

        Returns:
            The output of the ``dense`` projection.
        """
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """Feed-forward sub-layer: hidden -> 4 * hidden -> GELU -> hidden."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two projections and the activation between them.

        Args:
            config: Model configuration; only ``hidden_size`` is read.
            quant_config: Forwarded to both linear layers and to the
                activation factory.
        """
        super().__init__()
        model_dim = config.hidden_size
        inner_dim = 4 * model_dim

        self.dense_h_to_4h = ColumnParallelLinear(model_dim,
                                                  inner_dim,
                                                  quant_config=quant_config)
        self.gelu_impl = get_act_fn("gelu", quant_config, inner_dim)
        self.dense_4h_to_h = RowParallelLinear(inner_dim,
                                               model_dim,
                                               quant_config=quant_config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU and down-projection to ``x``."""
        # Each linear layer returns a pair; only its first element is used.
        x = self.dense_h_to_4h(x)[0]
        x = self.gelu_impl(x)
        x = self.dense_4h_to_h(x)[0]
        return x


class BloomBlock(nn.Module):
    """One BLOOM transformer layer: layer norm, attention, layer norm, MLP.

    ``config.apply_residual_connection_post_layernorm`` selects whether each
    residual branch starts from the layer-norm output or from the tensor
    that was fed into the layer norm.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Instantiate both layer norms, the attention and the MLP.

        Args:
            config: Model configuration.
            cache_config: Forwarded to the attention sub-layer.
            quant_config: Forwarded to the attention and MLP sub-layers.
        """
        super().__init__()
        norm_width = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        self.input_layernorm = nn.LayerNorm(norm_width, eps=norm_eps)
        self.self_attention = BloomAttention(config, cache_config,
                                             quant_config)
        self.post_attention_layernorm = nn.LayerNorm(norm_width,
                                                     eps=norm_eps)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states`` and return the new states."""
        residual_after_norm = self.apply_residual_connection_post_layernorm

        # Attention sub-layer with its residual connection.
        normed = self.input_layernorm(hidden_states)
        residual = normed if residual_after_norm else hidden_states
        attn_out = self.self_attention(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_out = attn_out + residual

        # MLP sub-layer with its residual connection.
        normed = self.post_attention_layernorm(attn_out)
        residual = normed if residual_after_norm else attn_out
        return self.mlp(normed) + residual


230
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM backbone: embeddings, the stack of blocks and the final norm.

    Under pipeline parallelism this rank runs only the layers in
    ``[start_layer, end_layer)``; the first rank embeds the tokens and the
    last rank applies the final layer norm.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the embedding, the transformer blocks and the final norm.

        Args:
            config: Model configuration.
            cache_config: Forwarded to every block.
            quant_config: Forwarded to every block.
            prefix: Name prefix handed to ``make_layers`` for the block list.
        """
        super().__init__()
        self.embed_dim = config.hidden_size
        norm_eps = config.layer_norm_epsilon

        # Token embedding followed by its layer norm.
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                      self.embed_dim)
        self.word_embeddings_layernorm = nn.LayerNorm(self.embed_dim,
                                                      eps=norm_eps)

        def _new_block(prefix: str) -> BloomBlock:
            del prefix  # Unused.
            return BloomBlock(config, cache_config, quant_config)

        # Transformer blocks.
        layer_range_and_blocks = make_layers(config.num_hidden_layers,
                                             _new_block,
                                             prefix=f"{prefix}.h")
        self.start_layer, self.end_layer, self.h = layer_range_and_blocks

        # Final layer norm.
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the model.

        Args:
            input_ids: Token ids; only read on the first pipeline rank.
            position_ids: Forwarded to every block.
            kv_caches: One entry per layer run on this rank, indexed
                relative to ``start_layer``.
            attn_metadata: Forwarded to every block.
            intermediate_tensors: Must not be ``None`` on ranks other than
                the first; its ``"hidden_states"`` entry is the input.

        Returns:
            ``IntermediateTensors`` holding the hidden states on every rank
            except the last, which returns the layer-normed hidden states.
        """
        if not get_pp_group().is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            hidden_states = self.word_embeddings_layernorm(
                self.word_embeddings(input_ids))

        local_layers = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layers):
            hidden_states = self.h[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if get_pp_group().is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


291
class BloomForCausalLM(nn.Module, SupportsPP):
    """BLOOM backbone plus LM head, logits processor and sampler.

    Three environment variables are read once at construction time and
    control an optional post-load weight pass in ``load_weights``:
    ``LLAMA_NN`` enables the pass, ``GEMM_PAD`` and ``FA_PAD`` enable the
    padding steps inside it. Each is active only when set to ``"1"``.
    """

    # Substrings identifying the parameters touched by the post-load pass.
    # They are matched literally (plain `in` tests) rather than as regular
    # expressions, so the dots cannot act as wildcards.
    _QKV_WEIGHT_KEY = "self_attention.query_key_value.weight"
    _QKV_BIAS_KEY = "self_attention.query_key_value.bias"
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attention.query_key_value.weight",
        "self_attention.dense.weight",
        "mlp.dense_h_to_4h.weight",
        "mlp.dense_4h_to_h.weight",
    )

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            cache_config: Forwarded to the backbone.
            quant_config: Forwarded to the backbone; ``None`` means no
                quantization config was supplied.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(config, cache_config, quant_config)
        if self.config.tie_word_embeddings:
            # Tied embeddings: the LM head is the input embedding module.
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # `quant_config` is already stored above; only its name is derived.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the backbone and return its output unchanged."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Return the logits processor's output for ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Return the sampler's output for ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the model's parameters.

        ``lm_head.weight`` entries are skipped, names lacking the
        ``transformer.`` prefix receive it, and parameters that
        ``is_pp_missing_parameter`` reports as missing are skipped. Fused
        QKV tensors are reordered before being handed to the parameter's
        weight loader.

        When ``LLAMA_NN`` is enabled and no quantization method is set, a
        second pass pads and re-lays out the attention and MLP projection
        weights in place.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a name that is not skipped has no matching
                parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        if not (self.use_llama_nn and self.quant_method is None):
            return

        # NOTE(review): `pad_weight`, `gemm_bank_conf` and
        # `ops.trans_w16_gemm` are defined outside this file; their exact
        # semantics are not visible here. The call order below is kept as-is
        # because each step reads the shape left by the previous one.
        for layername, weight in params_dict.items():
            if self.use_fa_pad and self._QKV_BIAS_KEY in layername:
                weight.data = pad_weight(weight.data, 32)

            if not any(key in layername
                       for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue

            if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]):
                weight.data = pad_weight(weight.data, 32)

            # The QKV weight is padded here only if the check on its
            # current (possibly already padded) first dimension is false.
            if (self.use_fa_pad and self._QKV_WEIGHT_KEY in layername
                    and not gemm_bank_conf(weight.data.shape[0])):
                weight.data = pad_weight(weight.data, 32)

            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            # Write the converted layout into the scratch tensor, copy it
            # back, then view the data with the second original dimension
            # leading.
            ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0],
                               _weight.shape[1])
            weight.data.copy_(_weight)

            weight.data = weight.data.reshape(ori_shape[1], -1)