"vscode:/vscode.git/clone" did not exist on "0313cf854d87a41c84efb69e89a79cd7b5897593"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
# Copyright 2023 The vLLM team.
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only BLOOM model compatible with HuggingFace weights."""
21

Woosuk Kwon's avatar
Woosuk Kwon committed
22
import math
23
from collections.abc import Iterable
24
from itertools import islice
Woosuk Kwon's avatar
Woosuk Kwon committed
25
26
27
28

import torch
from torch import nn
from transformers import BloomConfig
29
30
import os
import re
Woosuk Kwon's avatar
Woosuk Kwon committed
31

32
from vllm.attention.layer import Attention
33
from vllm.compilation.decorators import support_torch_compile
34
from vllm.config import CacheConfig, VllmConfig
35
36
37
38
39
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
40
from vllm.model_executor.layers.activation import get_act_fn
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
46
from vllm.model_executor.layers.logits_processor import LogitsProcessor
47
from vllm.model_executor.layers.quantization import QuantizationConfig
48
from vllm.model_executor.layers.vocab_parallel_embedding import (
49
50
51
    ParallelLMHead,
    VocabParallelEmbedding,
)
52
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
53
from vllm.sequence import IntermediateTensors
54
55
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
Woosuk Kwon's avatar
Woosuk Kwon committed
56

57
from .interfaces import SupportsPP, SupportsQuant
58
59
60
61
62
63
64
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
65

Woosuk Kwon's avatar
Woosuk Kwon committed
66
67

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    """Return the ALiBi slope of every attention head.

    Let ``p`` be the largest power of two that is ``<= total_num_heads``.
    The first ``p`` slopes are ``base ** k`` for ``k = 1..p`` with
    ``base = 2 ** (-8 / p)``. When ``total_num_heads`` is not a power of
    two, the remaining heads use the odd powers (1, 3, 5, ...) of
    ``2 ** (-4 / p)``, i.e. the base that ``2 * p`` heads would use.

    Args:
        total_num_heads: Number of attention heads across all
            tensor-parallel ranks. Must be positive (``math.log2`` raises
            ``ValueError`` otherwise).

    Returns:
        A 1-D ``float32`` tensor with ``total_num_heads`` entries.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        # Head count is not a power of two: fill in the leftover heads with
        # every other slope of the next power of two.
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(
            closest_power_of_2, total_num_heads - closest_power_of_2
        )
        extra_powers = torch.arange(
            start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention: fused QKV projection, ALiBi attention, output
    projection.

    Heads are divided evenly across tensor-parallel ranks; each rank keeps
    only the ALiBi slopes of the heads it owns.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projection layers and the attention backend.

        Args:
            config: Model config; ``hidden_size`` and ``n_head`` are read.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # Create the alibi slopes and slice them to this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scaling,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

        # Always define both attributes so readers never hit AttributeError
        # on an unquantized model; quant_method stays None in that case.
        self.quant_config = quant_config
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            position_ids: Accepted for interface compatibility and ignored.
            hidden_states: Input to the fused QKV projection.

        Returns:
            The output of the ``dense`` projection applied to the attention
            result.
        """
        del position_ids  # Unused.
        # The parallel linear layers return a tuple; only the first element
        # is used here.
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """BLOOM feed-forward block: ``hidden -> 4 * hidden -> hidden`` with a
    GELU activation in between."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the two projections and the activation.

        Args:
            config: Model config; only ``hidden_size`` is read.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU and down-projection to ``x``."""
        # The parallel linear layers return a tuple; only the first element
        # is used here.
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM transformer layer: LayerNorm, self-attention, LayerNorm and
    MLP, each sub-layer followed by a residual addition."""

    def __init__(
        self,
        config: BloomConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the layer's sub-modules.

        Args:
            config: Model config; ``hidden_size``, ``layer_norm_epsilon`` and
                ``apply_residual_connection_post_layernorm`` are read.
            cache_config: Forwarded to ``BloomAttention``.
            quant_config: Forwarded to ``BloomAttention`` and ``BloomMLP``.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
        )
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon
        )
        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
        # When True, residuals are taken from the LayerNorm outputs instead
        # of the LayerNorm inputs (see forward).
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the layer on ``hidden_states``.

        Args:
            position_ids: Passed through to the attention module.
            hidden_states: Layer input.

        Returns:
            The MLP output plus its residual.
        """
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Pick the residual source for the attention sub-layer.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Pick the residual source for the MLP sub-layer.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


248
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM transformer stack: word embeddings, embedding LayerNorm, the
    decoder blocks assigned to this pipeline stage, and the final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model from ``vllm_config``.

        Args:
            vllm_config: Supplies the HF model config, the cache config and
                the quantization config.
            prefix: Dotted module path used to name the decoder blocks.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        self.config = config

        self.embed_dim = config.hidden_size

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(
            self.embed_dim, eps=config.layer_norm_epsilon
        )

        # Transformer blocks
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: BloomBlock(
                config, cache_config, quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.h",
        )

        # Final Layer Norm
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.hidden_size
        )

        # Always define both attributes so readers never hit AttributeError
        # on an unquantized model; quant_method stays None in that case.
        self.quant_config = quant_config
        self.quant_method = (
            quant_config.get_name() if quant_config is not None else None
        )

        # Feature switches read once from the environment at construction.
        # Only use_llama_nn is consulted in this class (see load_weights).
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.use_gemm_pad = os.environ.get("GEMM_PAD") == "1"
        self.use_fa_pad = os.environ.get("FA_PAD") == "1"

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the word embeddings for ``input_ids``."""
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this pipeline stage's share of the model.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is None.
            position_ids: Passed through to every decoder block.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank but the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            On the last pipeline rank, the hidden states after the final
            LayerNorm; otherwise an ``IntermediateTensors`` carrying
            ``"hidden_states"`` for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            hidden_states = self.word_embeddings_layernorm(hidden_states)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(position_ids, hidden_states)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})
        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(parameter name, tensor)`` pairs. Names that
                ``is_pp_missing_parameter`` reports as missing are skipped.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a name that is not skipped has no matching
                parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                num_heads = self.config.num_attention_heads
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim]
                        + (num_heads, 3, -1)
                        + loaded_weight_shape[output_dim + 1 :]
                    )
                    loaded_weight = loaded_weight.transpose(output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        # The LLAMA_NN re-layout only applies to unquantized weights.
        if self.use_llama_nn and self.quant_method is None:
            self._apply_llama_nn_layout(params_dict, loaded_params)
        return loaded_params

    def _apply_llama_nn_layout(
        self,
        params_dict: dict[str, nn.Parameter],
        loaded_params: set[str],
    ) -> None:
        """Re-lay out the attention and MLP projection weights in place."""
        layer_keys = (
            "self_attention.query_key_value.weight",
            "self_attention.dense.weight",
            "mlp.dense_h_to_4h.weight",
            "mlp.dense_4h_to_h.weight",
        )
        # Escape the keys so "." matches a literal dot, and compile once
        # instead of re-joining the pattern for every parameter.
        key_pattern = re.compile("|".join(re.escape(key) for key in layer_keys))

        for layername in loaded_params:
            if key_pattern.search(layername) is None:
                continue
            weight = params_dict[layername]

            _weight = torch.zeros_like(weight.data)
            ori_shape = _weight.shape

            # NOTE(review): trans_w16_gemm is an opaque custom op; it
            # presumably writes a transposed copy of weight.data into
            # _weight -- confirm against the op's implementation.
            ops.trans_w16_gemm(_weight, weight.data, ori_shape[0], ori_shape[1])
            weight.data.copy_(_weight)

            # Expose the same storage with ori_shape[1] as the leading dim.
            weight.data = weight.data.reshape(ori_shape[1], -1)
384

Woosuk Kwon's avatar
Woosuk Kwon committed
385

386
class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """BLOOM with a language-modelling head for causal generation."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the transformer, the LM head and the logits processor.

        Args:
            vllm_config: Supplies the HF model config and the quantization
                config, and is forwarded to ``BloomModel``.
            prefix: Dotted module path used to name the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        if self.config.tie_word_embeddings:
            # Tied embeddings: the head reuses the input embedding module.
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the word embeddings for ``input_ids``."""
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the transformer and return its result unchanged."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Turn ``hidden_states`` into logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``lm_head.weight``.

        Names lacking the ``transformer.`` prefix get it added before
        loading.

        Returns:
            The set of names reported by ``AutoWeightsLoader.load_weights``.
        """
        loader = AutoWeightsLoader(self, skip_prefixes=["lm_head.weight"])
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)


def _add_transformer_prefix(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
    """Lazily yield ``weights`` with every name under ``transformer.``.

    Names that already start with ``"transformer."`` pass through unchanged;
    all others get the prefix prepended. Tensors are yielded untouched.
    """
    for name, tensor in weights:
        if not name.startswith("transformer."):
            name = "transformer." + name
        yield name, tensor