"vllm/vscode:/vscode.git/clone" did not exist on "711aa9d5b6784d50c2f9ab89642566d5cc3fe4ab"
jais.py 13.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# Adapted from
5
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
25
from collections.abc import Iterable
26
from itertools import islice
27
28
29
30

import torch
from torch import nn

31
from vllm.attention import Attention
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
35
36
37
38
39
40
41
42
43
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
48
49
    ParallelLMHead,
    VocabParallelEmbedding,
)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.sequence import IntermediateTensors
52
from vllm.transformers_utils.configs import JAISConfig
53

54
from .interfaces import SupportsPP
55
56
57
58
59
60
from .utils import (
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
61

62
63
64
65
66
67
68
69

class SwiGLUActivation(nn.Module):
    """Gated linear unit with a SiLU gate.

    Computes ``x1 * silu(x2)``: ``x2`` is passed through SiLU and the
    result scales ``x1`` elementwise (standard tensor broadcasting applies).
    """

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        gate = nn.functional.silu(x2)
        return x1 * gate


def _get_alibi_slopes(n):
    """Return the list of ALiBi slopes for ``n`` attention heads.

    For a power-of-two ``n`` the slopes are the geometric sequence
    ``start, start**2, ..., start**n`` with ``start = 2 ** -(8 / n)``.
    Otherwise the slopes for the largest power of two below ``n`` are
    followed by every other slope of the sequence for twice that power,
    truncated so that exactly ``n`` values are returned.

    ``n <= 0`` raises ``ValueError`` (from ``math.log2``).
    """

    def _power_of_two_slopes(count):
        # 2 ** -(log2(count) - 3) == 8 / count, so base == 2 ** -(8 / count).
        base = 2 ** (-(2 ** -(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    if not math.log2(n).is_integer():
        nearest = 2 ** math.floor(math.log2(n))
        extra = _get_alibi_slopes(2 * nearest)[0::2][: n - nearest]
        return _power_of_two_slopes(nearest) + extra
    return _power_of_two_slopes(n)
82
83
84
85
86
87


class JAISAttention(nn.Module):
    """Multi-head self-attention with ALiBi position biases for JAIS.

    Heads are split evenly across tensor-parallel ranks. Each rank keeps the
    ALiBi slopes for its own contiguous slice ``[rank * num_heads,
    (rank + 1) * num_heads)`` of the global heads.

    Args:
        config: JAIS model config (reads hidden_size, num_attention_heads and
            the (mup_)scale_qk_dot_by_d flag).
        cache_config: forwarded to the vLLM ``Attention`` layer.
        quant_config: forwarded to the linear layers and ``Attention``.
        prefix: module path of this layer; sub-layers are registered as
            ``{prefix}.c_attn``, ``{prefix}.c_proj`` and ``{prefix}.attn``.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # Some configs name the flag without the "mup_" prefix; normalize it
        # onto the mup_ attribute (note: this writes to the shared config).
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        # muP scales QK^T by 1/d instead of the usual 1/sqrt(d).
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        # Pass the module path so quantization configs that select layers by
        # name (ignore lists, per-layer targets) can match these projections.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        # Keep only the slopes for the heads owned by this TP rank.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and project back to hidden_size.

        ``hidden_states`` presumably has shape (num_tokens, hidden_size) --
        TODO confirm against the model runner.
        """
        qkv, _ = self.c_attn(hidden_states)
        # This rank's fused output is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):
    """Feed-forward sub-layer of a JAIS block.

    With ``config.activation_function == "swiglu"`` this computes
    ``c_proj(c_fc(x) * silu(c_fc2(x)))``. Any other activation computes
    ``c_proj(act(c_fc(x)))``, with ``act`` resolved by name.

    Args:
        intermediate_size: width of the hidden projection.
        config: JAIS model config (reads hidden_size, activation_function).
        quant_config: forwarded to the linear layers.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: QuantizationConfig | None = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # Second up-projection, only for SwiGLU: its output goes through SiLU
        # and gates the c_fc output.
        self.c_fc2 = (
            ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=True,
                quant_config=quant_config,
            )
            if self.swiglu
            else None
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        if self.swiglu:
            self.act = SwiGLUActivation()
        else:
            # SwiGLUActivation needs two inputs and cannot serve the
            # single-input path (it would raise TypeError in forward), so
            # resolve the configured activation by name instead.
            from vllm.model_executor.layers.activation import get_act_fn

            self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the up-projection(s), the activation, and c_proj."""
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            up, _ = self.c_fc(hidden_states)
            hidden_states = self.act(up, gate)
        else:
            hidden_states, _ = self.c_fc(hidden_states)
            hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):
    """A single pre-LayerNorm JAIS decoder layer.

    Runs ``x + attn(ln_1(x))``, then ``y + mlp(ln_2(y))``.

    Args:
        config: JAIS model config (reads hidden_size, n_inner and
            layer_norm_epsilon here; also passed to the sub-modules).
        cache_config: forwarded to the attention sub-module.
        quant_config: forwarded to the attention and MLP sub-modules.
        prefix: module path of this layer; attention is named
            ``{prefix}.attn``.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon
        # An unset n_inner falls back to a 4x expansion of the hidden size.
        if config.n_inner is None:
            mlp_width = 4 * width
        else:
            mlp_width = config.n_inner

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = JAISAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = JAISMLP(mlp_width, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Attention sub-layer: normalize, attend, add the skip connection.
        skip = hidden_states
        attn_out = self.attn(hidden_states=self.ln_1(hidden_states))
        hidden_states = attn_out + skip

        # MLP sub-layer: normalize, feed forward, add the skip connection.
        skip = hidden_states
        mlp_out = self.mlp(self.ln_2(hidden_states))
        return skip + mlp_out


231
@support_torch_compile
class JAISModel(nn.Module):
    """JAIS decoder stack: embeddings, transformer blocks, final LayerNorm.

    Pipeline-parallel aware. Only the first PP rank embeds tokens, and ranks
    other than the last return ``IntermediateTensors`` instead of the
    normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # GPT-2-style config options that this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned absolute position embeddings are only created for non-ALiBi
        # models; ALiBi biases are applied inside attention instead.
        self.wpe = (
            nn.Embedding(config.max_position_embeddings, self.embed_dim)
            if config.position_embedding_type != "alibi"
            else None
        )
        # Accept both the plain and the muP-prefixed config attribute names.
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: JAISBlock(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], config.n_embd
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> IntermediateTensors | torch.Tensor:
        """Run this PP rank's slice of layers.

        Returns ``IntermediateTensors`` on non-last PP ranks. On the last rank
        it returns the final-LayerNorm output.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            if self.wpe is not None:
                position_embeds = self.wpe(position_ids)
                hidden_states = inputs_embeds + position_embeds
            else:
                hidden_states = inputs_embeds
            # Multiply out-of-place. Without wpe, hidden_states aliases the
            # caller's inputs_embeds, which an in-place `*=` would silently
            # modify. The scale is still cast to the activation dtype first.
            hidden_states = hidden_states * torch.tensor(
                float(self.embeddings_scale), dtype=hidden_states.dtype
            )
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


307
class JAISLMHeadModel(nn.Module, SupportsPP):
    """JAIS causal LM: ``JAISModel`` followed by a (possibly tied) LM head.

    When ``config.tie_word_embeddings`` is set, the LM head is the input
    embedding ``transformer.wte``. Otherwise a separate ``ParallelLMHead`` is
    created and loaded from the checkpoint's ``lm_head.weight``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        # Logit scale: plain attribute if present, else the muP product.
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale
        self.logits_processor = LogitsProcessor(
            vocab_size=config.vocab_size, scale=self.output_logits_scale
        )
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> IntermediateTensors | torch.Tensor:
        """Run the transformer; logits are computed in ``compute_logits``."""
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to (scaled) vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load HF JAIS checkpoint tensors into this model.

        Args:
            weights: iterable of (checkpoint name, tensor) pairs.

        Returns:
            The set of parameter names that were loaded.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        tied = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            is_lm_head = "lm_head.weight" in name
            if is_lm_head and tied:
                # Tied: lm_head *is* transformer.wte, which is loaded from its
                # own checkpoint entry. An untied head must still be loaded.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue
            # Only transformer weights live under "transformer."; the untied
            # lm_head is a top-level module.
            if not is_lm_head and not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF GPT-2-style implementation uses Conv1D instead of Linear,
            # so these projection weights must be transposed (exactly once).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                conv1d_name in name for conv1d_name in ("c_attn", "c_proj", "c_fc")
            ):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params