bloom.py 14.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py
5
# Copyright 2023 The vLLM team.
Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2022 HuggingFace Inc. team and BigScience workshop.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Woosuk Kwon's avatar
Woosuk Kwon committed
19
"""Inference-only BLOOM model compatible with HuggingFace weights."""
Woosuk Kwon's avatar
Woosuk Kwon committed
20
import math
21
from typing import Iterable, List, Optional, Set, Tuple, Union
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
24
25
26

import torch
from torch import nn
from transformers import BloomConfig

27
from vllm.attention import Attention, AttentionMetadata
28
from vllm.compilation.decorators import support_torch_compile
29
from vllm.config import CacheConfig, VllmConfig
30
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
31
                              get_tensor_model_parallel_world_size)
Woosuk Kwon's avatar
Woosuk Kwon committed
32
from vllm.model_executor.layers.activation import get_act_fn
33
34
35
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
36
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
38
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
39
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    ParallelLMHead, VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors
Woosuk Kwon's avatar
Woosuk Kwon committed
44

45
46
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
47
48
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
49

Woosuk Kwon's avatar
Woosuk Kwon committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77

def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


class BloomAttention(nn.Module):
    """BLOOM self-attention with ALiBi position bias, sharded across TP ranks.

    The fused QKV projection and the output projection are tensor-parallel
    linears. Each rank owns ``n_head // tp_world_size`` heads and the matching
    slice of the ALiBi slopes.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the projections and the attention backend for one layer.

        Args:
            config: HF BLOOM config; ``hidden_size`` and ``n_head`` are read.
            cache_config: Forwarded to the ``Attention`` layer.
            quant_config: Forwarded to every linear and to ``Attention``.
            prefix: Module path of this layer; sub-module names are derived
                from it.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.total_num_heads = config.n_head
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size, (
            f"hidden_size ({self.hidden_size}) must be divisible by "
            f"n_head ({self.total_num_heads})")

        tp_world_size = get_tensor_model_parallel_world_size()
        assert self.total_num_heads % tp_world_size == 0, (
            f"n_head ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel size ({tp_world_size})")
        self.num_heads = self.total_num_heads // tp_world_size

        # Forward the module path (as done for `self.attn` below) so that
        # name-based quantization settings can identify these layers.
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # Create the alibi slopes for all heads and keep this rank's slice.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(self.total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end].tolist()

        scaling = self.head_dim**-0.5
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scaling,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention for this rank's heads and project back to hidden size.

        ``position_ids`` is accepted for interface parity but ignored: position
        information enters via the ALiBi slopes given to ``self.attn``.
        """
        del position_ids  # Unused: positions are encoded through ALiBi.
        qkv, _ = self.query_key_value(hidden_states)
        # The fused output is split into equal thirds along the last dim;
        # load_weights reorders the checkpoint so the thirds are q, k, v.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        output, _ = self.dense(attn_output)
        return output


class BloomMLP(nn.Module):
    """BLOOM feed-forward network: hidden -> 4 * hidden -> GELU -> hidden."""

    def __init__(
        self,
        config: BloomConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the up/down tensor-parallel projections.

        Args:
            config: HF BLOOM config; only ``hidden_size`` is read.
            quant_config: Forwarded to both linear layers.
            prefix: Module path of this MLP (defaults to "" for backward
                compatibility); used to name the linear sub-layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.dense_h_to_4h = ColumnParallelLinear(
            hidden_size,
            4 * hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_h_to_4h",
        )
        self.gelu_impl = get_act_fn("gelu")
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_4h_to_h",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply up-projection, GELU and down-projection to ``x``."""
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        x, _ = self.dense_4h_to_h(x)
        return x


class BloomBlock(nn.Module):
    """One BLOOM decoder layer: pre-LN attention and pre-LN MLP with residuals.

    ``apply_residual_connection_post_layernorm`` selects whether each residual
    branch carries the LayerNorm output or the un-normalized input.
    """

    def __init__(
        self,
        config: BloomConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the two LayerNorms, the attention and the MLP sub-layers.

        Args:
            config: HF BLOOM config.
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to attention and MLP.
            prefix: Module path of this layer; sub-module names derive from it.
        """
        super().__init__()
        hidden_size = config.hidden_size

        self.input_layernorm = nn.LayerNorm(hidden_size,
                                            eps=config.layer_norm_epsilon)
        self.self_attention = BloomAttention(config,
                                             cache_config,
                                             quant_config,
                                             prefix=f"{prefix}.self_attention")
        self.post_attention_layernorm = nn.LayerNorm(
            hidden_size, eps=config.layer_norm_epsilon)
        # Pass the module path so the MLP's linears are named like the
        # attention's sub-layers.
        self.mlp = BloomMLP(config, quant_config, prefix=f"{prefix}.mlp")
        self.apply_residual_connection_post_layernorm = (
            config.apply_residual_connection_post_layernorm)

    def forward(
        self,
        position_ids: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers and return the new hidden states."""
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Residual for the attention sub-layer: normalized or raw input,
        # depending on apply_residual_connection_post_layernorm.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        # Self attention.
        attention_output = self.self_attention(
            position_ids=position_ids,
            hidden_states=layernorm_output,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        attention_output = attention_output + residual
        layernorm_output = self.post_attention_layernorm(attention_output)

        # Residual for the MLP sub-layer, chosen the same way.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = attention_output

        # MLP.
        output = self.mlp(layernorm_output) + residual
        return output


229
@support_torch_compile
class BloomModel(nn.Module):
    """BLOOM transformer trunk: token embeddings, decoder blocks, final LayerNorm.

    Under pipeline parallelism this rank only runs the blocks in
    ``[self.start_layer, self.end_layer)`` as returned by ``make_layers``.
    Intermediate ranks return ``IntermediateTensors`` instead of the final
    normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        ln_eps = hf_config.layer_norm_epsilon

        self.embed_dim = hf_config.hidden_size

        # Token embedding followed by its dedicated LayerNorm.
        self.word_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            self.embed_dim,
        )
        self.word_embeddings_layernorm = nn.LayerNorm(self.embed_dim,
                                                      eps=ln_eps)

        # The factory keeps the keyword name `prefix`, matching the lambda
        # signature make_layers was given before.
        def _build_block(prefix: str) -> BloomBlock:
            return BloomBlock(hf_config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        self.start_layer, self.end_layer, self.h = make_layers(
            hf_config.num_hidden_layers, _build_block, prefix=f"{prefix}.h")

        # LayerNorm applied after the last block (last PP rank only).
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=ln_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` and apply the embedding LayerNorm."""
        token_embeds = self.word_embeddings(input_ids)
        return self.word_embeddings_layernorm(token_embeds)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's share of the decoder stack.

        On the first PP rank the input is ``inputs_embeds`` if given (used
        as-is; presumably already passed through get_input_embeddings —
        TODO confirm with callers), else the normalized embeddings of
        ``input_ids``. Other ranks read ``intermediate_tensors``.
        ``kv_caches`` is indexed relative to ``self.start_layer``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        elif inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds

        layer_ids = range(self.start_layer, self.end_layer)
        for cache_idx, layer_idx in enumerate(layer_ids):
            hidden_states = self.h[layer_idx](
                position_ids,
                hidden_states,
                kv_caches[cache_idx],
                attn_metadata,
            )

        if pp_group.is_last_rank:
            return self.ln_f(hidden_states)
        return IntermediateTensors({"hidden_states": hidden_states})


296
class BloomForCausalLM(nn.Module, SupportsPP):
    """BLOOM causal language model: ``BloomModel`` trunk plus a vocabulary head.

    When ``config.tie_word_embeddings`` is set, the head is the trunk's
    ``word_embeddings`` module itself. Otherwise it is a separate
    ``ParallelLMHead`` that must be loaded from the checkpoint.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = BloomModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(
                                          prefix, "transformer"))
        if self.config.tie_word_embeddings:
            # Output projection shares the input embedding module.
            self.lm_head = self.transformer.word_embeddings
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the trunk's (LayerNorm-ed) token embeddings."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to the trunk; returns hidden states or PP intermediates."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``self.lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF BLOOM checkpoint tensors into this model.

        Names without the ``transformer.`` prefix are given it, except
        ``lm_head.*`` entries. Those are skipped when embeddings are tied
        (the shared embedding is loaded under the transformer name) and are
        loaded as-is into the separate head otherwise.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        num_heads = self.config.num_attention_heads
        for name, loaded_weight in weights:
            if name.startswith("lm_head."):
                if self.config.tie_word_embeddings:
                    continue
                # Untied: keep the name so it maps to self.lm_head instead
                # of a non-existent "transformer.lm_head.*" parameter.
            elif not name.startswith("transformer."):
                name = "transformer." + name
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]

            if "query_key_value" in name:
                # NOTE: BLOOM's fused QKV's output_dim has the shape of
                # (num_heads * 3 * head_size), while the
                # required shape is (3 * num_heads * head_size).
                # Thus, we need weight conversion.
                output_dim = getattr(param, "output_dim", None)
                if output_dim is not None:
                    loaded_weight_shape = loaded_weight.shape
                    loaded_weight = loaded_weight.view(
                        loaded_weight_shape[:output_dim] + (num_heads, 3, -1) +
                        loaded_weight_shape[output_dim + 1:])
                    loaded_weight = loaded_weight.transpose(
                        output_dim, output_dim + 1)
                    loaded_weight = loaded_weight.reshape(loaded_weight_shape)

            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params