"docs/vscode:/vscode.git/clone" did not exist on "9e14887ff8829422af025c80a62a30cc9202bea8"
bloom.py 13.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
6
# Copyright 2023 The vLLM team.
Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
9
10
11
12
13
14
15
16
17
18
19
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
20
"""Inference-only BLOOM model compatible with HuggingFace weights."""
21

Woosuk Kwon's avatar
Woosuk Kwon committed
22
import math
23
from collections.abc import Iterable
24
from itertools import islice
25
from typing import Optional, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29
30

import torch
from torch import nn
from transformers import BloomConfig

31
from vllm.attention import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
35
36
37
38
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
Woosuk Kwon's avatar
Woosuk Kwon committed
39
from vllm.model_executor.layers.activation import get_act_fn
40
41
42
43
44
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
45
from vllm.model_executor.layers.logits_processor import LogitsProcessor
46
from vllm.model_executor.layers.quantization import QuantizationConfig
47
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
49
50
    ParallelLMHead,
    VocabParallelEmbedding,
)
51
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
52
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
53

54
from .interfaces import SupportsPP, SupportsQuant
55
56
57
58
59
60
61
from .utils import (
    AutoWeightsLoader,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
62

Woosuk Kwon's avatar
Woosuk Kwon committed
63
64

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
65
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
Woosuk Kwon's avatar
Woosuk Kwon committed
66
    base = torch.tensor(
67
        2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))),
Woosuk Kwon's avatar
Woosuk Kwon committed
68
69
70
71
72
73
74
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
75
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))),
Woosuk Kwon's avatar
Woosuk Kwon committed
76
77
            dtype=torch.float32,
        )
78
79
80
81
82
83
84
        num_remaining_heads = min(
            closest_power_of_2, total_num_heads - closest_power_of_2
        )
        extra_powers = torch.arange(
            start=1, end=1 + 2 * num_remaining_heads, step=2, dtype=torch.int32
        )
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
88
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention with ALiBi slopes, sharded over TP ranks.

    The fused QKV projection and the output projection are tensor-parallel.
    Each rank keeps ``n_head // tp_world_size`` heads and the ALiBi slopes
    for exactly those heads.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size, (
            f"hidden_size ({self.hidden_size}) must be divisible by "
            f"n_head ({self.total_num_heads})"
        )

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0, (
            f"n_head ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel size ({tp_world_size})"
        )
        self.num_heads = self.total_num_heads // tp_world_size

        # Give the projections fully-qualified names, the same way self.attn
        # receives f"{prefix}.attn" below. NOTE(review): presumably consumed
        # by quant_config for name-based per-layer lookups.
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # Compute slopes for all heads, then keep this rank's contiguous slice.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scaling,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        ``position_ids`` is accepted for interface parity but ignored; ALiBi
        slopes (passed to ``Attention``) supply the positional signal.
        """
        del position_ids  # Unused.
        qkv, _ = self.query_key_value(hidden_states)
        # load_weights reorders the checkpoint's fused QKV so that the last
        # dim splits cleanly into [q | k | v].
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """BLOOM feed-forward network: hidden -> 4*hidden -> GELU -> hidden."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # Name the linears (as BloomAttention names its submodules) so
        # name-based quantization lookups see the real layer path.
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU, and down-projection to ``x``."""
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM decoder layer: LN -> attention -> residual -> LN -> MLP -> residual.

    ``config.apply_residual_connection_post_layernorm`` selects whether each
    residual branch uses the layer-norm output or the pre-norm input.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
        )
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon
        )
        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm
        )

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run one decoder layer and return the updated hidden states."""
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Residual for the attention branch.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Residual for the MLP branch.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


234
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM decoder stack.

    Stack layout: word embeddings plus the embedding LayerNorm, the
    (pipeline-sliced) transformer blocks ``h``, and the final LayerNorm
    ``ln_f``. On non-last pipeline ranks the un-normalized hidden states are
    handed off as ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        eps = hf_config.layer_norm_epsilon

        self.config = hf_config
        self.embed_dim = hf_config.hidden_size

        # Token embedding followed by BLOOM's extra embedding LayerNorm.
        self.word_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(self.embed_dim, eps=eps)

        # make_layers invokes the factory with the keyword ``prefix``.
        def _build_block(prefix: str) -> nn.Module:
            return BloomBlock(hf_config, cache_config, quant_config, prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers,
            _build_block,
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=eps)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], hf_config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.word_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of layers.

        Returns the final-normalized hidden states on the last pipeline rank,
        otherwise ``IntermediateTensors`` carrying ``hidden_states``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        else:
            embeds = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            hidden_states = self.word_embeddings_layernorm(embeds)

        for block in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = block(position_ids, hidden_states)

        if pp_group.is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names belonging to layers not on this pipeline rank are skipped. Fused
        QKV tensors are re-laid-out before loading (see comment below).

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        num_heads = self.config.num_attention_heads
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            output_dim = getattr(param, "output_dim", None)

            if "query_key_value" in name and output_dim is not None:
                # BLOOM stores the fused QKV output dim as
                # (num_heads, 3, head_size); QKVParallelLinear expects
                # (3, num_heads, head_size). Swap the two leading factors.
                full_shape = loaded_weight.shape
                split_shape = (
                    full_shape[:output_dim]
                    + (num_heads, 3, -1)
                    + full_shape[output_dim + 1 :]
                )
                loaded_weight = (
                    loaded_weight.view(split_shape)
                    .transpose(output_dim, output_dim + 1)
                    .reshape(full_shape)
                )

            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params

Woosuk Kwon's avatar
Woosuk Kwon committed
327

328
class BloomForCausalLM(nn.Module, SupportsPP, SupportsQuant):
    """BLOOM causal LM: a ``BloomModel`` backbone plus an LM head.

    When ``config.tie_word_embeddings`` is set, the LM head *is* the backbone's
    ``word_embeddings`` module; otherwise a separate ``ParallelLMHead`` is
    created.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone; logits are produced separately by compute_logits."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load a checkpoint, prefixing backbone names with ``transformer.``.

        ``lm_head.weight`` is skipped only when the head is tied to
        ``word_embeddings`` (it would just duplicate that table). An untied
        ``ParallelLMHead`` owns its own parameter, which must be loaded.
        """
        skip_prefixes = ["lm_head.weight"] if self.config.tie_word_embeddings else []
        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        weights = _add_transformer_prefix(weights)
        return loader.load_weights(weights)


def _add_transformer_prefix(
381
    weights: Iterable[tuple[str, torch.Tensor]],
382
383
) -> Iterable[tuple[str, torch.Tensor]]:
    for name, tensor in weights:
384
385
        if not name.startswith("transformer."):
            name = "transformer." + name
386
        yield name, tensor