bloom.py 17.1 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
5
# Copyright 2023 The vLLM team.
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only BLOOM model compatible with HuggingFace weights."""
Woosuk Kwon's avatar
Woosuk Kwon committed
20
import math
21
from typing import Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
24
25

import torch
from torch import nn
from transformers import BloomConfig
26
27
import os
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
28

29
from vllm.attention import Attention, AttentionMetadata
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
33
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.model_executor.layers.activation import get_act_fn
35
36
37
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
38
from vllm.model_executor.layers.logits_processor import LogitsProcessor
39
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
40
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
41
from vllm.model_executor.layers.vocab_parallel_embedding import (
42
    ParallelLMHead, VocabParallelEmbedding)
43
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
44
from vllm.model_executor.sampling_metadata import SamplingMetadata
45
from vllm.sequence import IntermediateTensors
46
47
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Woosuk Kwon's avatar
Woosuk Kwon committed
48

49
50
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
51
52
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
53

Woosuk Kwon's avatar
Woosuk Kwon committed
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention layer with ALiBi biases and tensor parallelism.

    The layer owns a fused QKV projection, the attention op itself and the
    output projection. Attention heads are split evenly across the
    tensor-parallel group, and each rank keeps only the ALiBi slopes of the
    heads it owns.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projections and the attention op.

        Args:
            config: BLOOM config; ``hidden_size`` and ``n_head`` are read.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Module-name prefix; ``Attention`` gets ``{prefix}.attn``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        # The hidden size must divide evenly into heads.
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        # Every tensor-parallel rank must own the same number of heads.
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Slopes are computed for all heads, then cut down to the contiguous
        # range of heads that belongs to this tensor-parallel rank.
        tp_rank = get_tensor_model_parallel_rank()
        first_head = tp_rank * self.num_heads
        all_slopes = _get_alibi_slopes(self.total_num_heads)
        rank_slopes = all_slopes[first_head:first_head +
                                 self.num_heads].tolist()

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.head_dim**-0.5,
                              alibi_slopes=rank_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

        # `quant_config` is only stored on the module when one was supplied;
        # `quant_method` is always present (None when unquantized).
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()
            self.quant_config = quant_config

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply self-attention and the output projection.

        ``position_ids`` is accepted for interface compatibility but is not
        used by this layer.
        """
        del position_ids  # Unused.
        fused_qkv, _ = self.query_key_value(hidden_states)
        # The fused projection is split into three equal parts on the last dim.
        query, key, value = fused_qkv.chunk(chunks=3, dim=-1)
        context = self.attn(query, key, value, kv_cache, attn_metadata)
        projected, _ = self.dense(context)
        return projected


class BloomMLP(nn.Module):
    """BLOOM feed-forward block: expand to 4x hidden size, GELU, project back."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two parallel linear layers and the activation.

        Args:
            config: BLOOM config; only ``hidden_size`` is read.
            quant_config: Forwarded to both linear layers.
        """
        super().__init__()
        model_width = config.hidden_size
        inner_width = 4 * model_width

        self.dense_h_to_4h = ColumnParallelLinear(
            model_width,
            inner_width,
            quant_config=quant_config,
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            inner_width,
            model_width,
            quant_config=quant_config,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run up-projection, GELU and down-projection on ``x``."""
        # Each parallel linear returns a pair; only the first item is used.
        expanded, _ = self.dense_h_to_4h(x)
        activated = self.gelu_impl(expanded)
        contracted, _ = self.dense_4h_to_h(activated)
        return contracted


class BloomBlock(nn.Module):
    """One BLOOM transformer layer: layernorm, attention, layernorm, MLP.

    ``config.apply_residual_connection_post_layernorm`` selects whether each
    residual branch starts from the layernorm output or from the
    un-normalized input of that sub-layer.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the sub-modules of the layer.

        Args:
            config: BLOOM config (``hidden_size``, ``layer_norm_epsilon`` and
                ``apply_residual_connection_post_layernorm`` are read here).
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention module and the MLP.
            prefix: Module-name prefix; attention gets
                ``{prefix}.self_attention``.
        """
        super().__init__()
        width = config.hidden_size
        ln_eps = config.layer_norm_epsilon

        self.input_layernorm = nn.LayerNorm(width, eps=ln_eps)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(width, eps=ln_eps)
        self.mlp = BloomMLP(config, quant_config)
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states`` and return the new hidden states."""
        residual_from_ln = self.apply_residual_connection_post_layernorm

        # Attention sub-layer.
        normed = self.input_layernorm(hidden_states)
        skip = normed if residual_from_ln else hidden_states
        attn_out = self.self_attention(
            position_ids=position_ids,
            hidden_states=normed,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attn_out = attn_out + skip

        # MLP sub-layer.
        normed = self.post_attention_layernorm(attn_out)
        skip = normed if residual_from_ln else attn_out
        return self.mlp(normed) + skip


240
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM transformer stack without the language-model head.

    Consists of the word embedding followed by a layernorm, the transformer
    blocks assigned to this pipeline stage, and a final layernorm that is
    applied on the last pipeline rank only.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build embeddings, the local range of blocks and the final norm.

        Args:
            vllm_config: Supplies the HF model config, cache config and
                quantization config.
            prefix: Module-name prefix; blocks are created under
                ``{prefix}.h``.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.embed_dim = config.hidden_size
        ln_eps = config.layer_norm_epsilon

        # Token embedding and the layernorm applied right after it.
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(self.embed_dim,
                                                      eps=ln_eps)

        def _build_block(prefix: str) -> BloomBlock:
            # Factory handed to make_layers; the keyword name `prefix` is
            # kept identical to the original lambda's parameter.
            return BloomBlock(config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        # Transformer blocks plus the [start_layer, end_layer) index range
        # that forward() iterates over.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            _build_block,
            prefix=f"{prefix}.h")

        # Final layernorm.
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=ln_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and apply the embedding layernorm."""
        embedded = self.word_embeddings(input_ids)
        return self.word_embeddings_layernorm(embedded)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the blocks owned by this pipeline stage.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. Other ranks read
        ``hidden_states`` from ``intermediate_tensors``. Ranks other than the
        last return ``IntermediateTensors``; the last rank returns the
        hidden states after the final layernorm.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        # kv_caches is indexed relative to the first local layer.
        local_layers = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(local_layers):
            hidden_states = self.h[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        return self.ln_f(hidden_states)


307
class BloomForCausalLM(nn.Module, SupportsPP):
    """BLOOM with a language-model head, logits processor and sampler.

    Three environment variables are read once at construction time:
    ``LLAMA_NN``, ``GEMM_PAD`` and ``FA_PAD`` (each enabled by the value
    ``'1'``). Only ``LLAMA_NN`` is consulted by the code in this class; the
    other two flags are stored as attributes but not otherwise used here.
    """

    # Name fragments of the linear-layer weights whose storage layout is
    # converted after loading when LLAMA_NN=1 and the model is unquantized.
    # Matched literally (plain substring test), not as regex patterns.
    _LLAMA_NN_WEIGHT_KEYS = (
        "self_attention.query_key_value.weight",
        "self_attention.dense.weight",
        "mlp.dense_h_to_4h.weight",
        "mlp.dense_4h_to_h.weight",
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, the LM head and the sampling utilities.

        Args:
            vllm_config: Supplies the HF model config and quantization config.
            prefix: Module-name prefix; the transformer is created under
                ``maybe_prefix(prefix, "transformer")``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        # With tied embeddings the head is the word-embedding module itself.
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

        # Name of the quantization scheme, or None for unquantized weights.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the transformer's input embeddings for ``input_ids``."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the transformer and return its output unchanged."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Compute logits from ``hidden_states`` using the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(name, tensor)`` pairs from the checkpoint. Names that
                do not start with ``"transformer."`` get that prefix added.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a (prefixed) name is not a parameter of this model.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            # The head weight is never loaded from the checkpoint. With tied
            # embeddings the head shares the word-embedding parameter.
            # NOTE(review): in the untied case the head keeps its initial
            # values - confirm that is intended.
            if name == "lm_head.weight":
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            # Skip parameters that live on another pipeline stage.
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                loaded_weight = self._reorder_fused_qkv(param, loaded_weight)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._convert_to_llama_nn_layout(params_dict, loaded_params)

        return loaded_params

    def _reorder_fused_qkv(self, param: nn.Parameter,
                           loaded_weight: torch.Tensor) -> torch.Tensor:
        """Convert a fused QKV tensor from BLOOM's layout to the required one.

        NOTE: BLOOM's fused QKV's output_dim has the shape of
        (num_heads * 3 * head_size), while the required shape is
        (3 * num_heads * head_size). Thus, we need weight conversion.
        Tensors whose parameter has no ``output_dim`` attribute are returned
        unchanged.
        """
        output_dim = getattr(param, "output_dim", None)
        if output_dim is None:
            return loaded_weight

        num_heads = self.config.num_attention_heads
        original_shape = loaded_weight.shape
        # Split output_dim into (num_heads, 3, head_size), swap the first two
        # of those axes, then flatten back to the original shape.
        split = loaded_weight.view(original_shape[:output_dim] +
                                   (num_heads, 3, -1) +
                                   original_shape[output_dim + 1:])
        swapped = split.transpose(output_dim, output_dim + 1)
        return swapped.reshape(original_shape)

    def _convert_to_llama_nn_layout(self, params_dict: dict,
                                    loaded_params: Set[str]) -> None:
        """Rewrite loaded linear weights in place for the LLAMA_NN path.

        Only parameters whose name contains one of
        ``_LLAMA_NN_WEIGHT_KEYS`` are touched. Each one is passed through
        ``ops.trans_w16_gemm`` and then reshaped to ``(shape[1], -1)``.
        NOTE(review): ``trans_w16_gemm`` presumably re-packs the weight for
        an NN-layout GEMM - confirm against ``vllm._custom_ops``.
        """
        for layername in loaded_params:
            if not any(key in layername
                       for key in self._LLAMA_NN_WEIGHT_KEYS):
                continue

            weight = params_dict[layername]
            converted = torch.zeros_like(weight.data)
            ori_shape = converted.shape

            ops.trans_w16_gemm(converted, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(converted)
            weight.data = weight.data.reshape(ori_shape[1], -1)
445