jais.py 14 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
# Adapted from
5
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
25
from collections.abc import Iterable
26
from itertools import islice
27
28
29
30

import torch
from torch import nn

31
from vllm.compilation.decorators import support_torch_compile
32
from vllm.config import CacheConfig, VllmConfig
33
34
35
36
37
from vllm.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
38
from vllm.model_executor.layers.attention import Attention
39
40
41
42
43
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
44
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.vocab_parallel_embedding import (
47
48
49
    ParallelLMHead,
    VocabParallelEmbedding,
)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.sequence import IntermediateTensors
52
from vllm.transformers_utils.configs.jais import JAISConfig
53

54
from .interfaces import SupportsPP
55
from .utils import (
56
    AutoWeightsLoader,
57
58
59
60
61
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
62

63
64
65
66
67
68
69
70

class SwiGLUActivation(nn.Module):
    """Gated SiLU activation: multiplies ``x1`` by ``silu(x2)`` element-wise."""

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Return ``x1 * silu(x2)``.

        Args:
            x1: Tensor that is scaled by the gate.
            x2: Tensor passed through SiLU to form the gate.

        Returns:
            The element-wise product of ``x1`` and ``silu(x2)``.
        """
        gate = nn.functional.silu(x2)
        return x1 * gate


def _get_alibi_slopes(n):
    def get_slopes_power_of_2(n):
71
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
72
73
74
75
76
77
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
78
79
80
81
82
        closest_power_of_2 = 2 ** math.floor(math.log2(n))
        return (
            get_slopes_power_of_2(closest_power_of_2)
            + _get_alibi_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
        )
83
84
85
86
87
88


class JAISAttention(nn.Module):
    """Tensor-parallel self-attention layer used by each JAIS block.

    Holds a fused QKV projection (``c_attn``), the attention op (``attn``)
    and the output projection (``c_proj``). When the config selects the
    ``"alibi"`` position embedding type, per-head ALiBi slopes for the heads
    owned by this tensor-parallel rank are passed to the attention op.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections and the attention op.

        Args:
            config: Model configuration (hidden size, head count, scaling
                and position-embedding options are read from it).
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to the linear layers and ``Attention``.
            prefix: Dotted module-name prefix used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        assert total_num_heads % tensor_model_parallel_world_size == 0, (
            "num_attention_heads must be divisible by the tensor-parallel size"
        )
        # Heads are split evenly across tensor-parallel ranks.
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # Configs may expose the flag under either name; mirror the short
        # name onto the ``mup_`` attribute that is read below.
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        # Scale by 1/d when the flag is set, otherwise by 1/sqrt(d).
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        self.use_alibi = config.position_embedding_type == "alibi"
        alibi_slopes = None
        if self.use_alibi:
            # Compute slopes for all heads, then keep only the contiguous
            # range of heads that belongs to this tensor-parallel rank.
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end]

        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention followed by the output projection.

        Args:
            hidden_states: Input activations for this layer.

        Returns:
            The projected attention output.
        """
        qkv, _ = self.c_attn(hidden_states)
        # The fused projection output is split into three equal parts along
        # the last dimension: query, key and value.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):
    """Feed-forward network of a JAIS block.

    With ``activation_function == "swiglu"`` two parallel up-projections
    (``c_fc`` and ``c_fc2``) are combined as ``c_fc(x) * silu(c_fc2(x))``.
    For any other activation name only ``c_fc`` is used and the activation
    is looked up by name.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections and select the activation.

        Args:
            intermediate_size: Width of the hidden (inner) projection.
            config: Model configuration; ``hidden_size`` and
                ``activation_function`` are read from it.
            quant_config: Forwarded to the linear layers.
            prefix: Dotted module-name prefix used to name the sub-layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        # The second up-projection (the gate) only exists for SwiGLU.
        self.c_fc2 = (
            ColumnParallelLinear(
                hidden_size,
                intermediate_size,
                bias=True,
                quant_config=quant_config,
                prefix=f"{prefix}.c_fc2",
            )
            if self.swiglu
            else None
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        if self.swiglu:
            self.act = SwiGLUActivation()
        else:
            # SwiGLUActivation takes two inputs, so using it here would make
            # forward() fail for single-input activations. Resolve the
            # configured activation by name instead.
            from vllm.model_executor.layers.activation import get_act_fn

            self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the feed-forward network.

        Args:
            hidden_states: Input activations for this layer.

        Returns:
            The output of the down-projection ``c_proj``.
        """
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            hidden_states, _ = self.c_fc(hidden_states)
            hidden_states = self.act(hidden_states, gate)
        else:
            hidden_states, _ = self.c_fc(hidden_states)
            hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):
    """One JAIS transformer layer: pre-LayerNorm attention and MLP, each
    wrapped in a residual connection."""

    def __init__(
        self,
        config: JAISConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the two LayerNorms, the attention layer and the MLP.

        Args:
            config: Model configuration.
            cache_config: Forwarded to the attention layer.
            quant_config: Forwarded to the attention layer and the MLP.
            prefix: Dotted module-name prefix used to name the sub-layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # Fall back to 4x the hidden size when the config leaves n_inner unset.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(
            config, cache_config, quant_config, prefix=f"{prefix}.attn"
        )
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply attention and MLP sub-layers with residual connections.

        Args:
            hidden_states: Input activations for this layer.

        Returns:
            The layer output, same role as the input for the next layer.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


241
@support_torch_compile
class JAISModel(nn.Module):
    """JAIS transformer backbone: embeddings, the stack of ``JAISBlock``
    layers assigned to this pipeline stage, and the final LayerNorm."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embeddings, layers and final norm.

        Args:
            vllm_config: Engine configuration; the HF model config, cache
                config and quantization config are read from it.
            prefix: Dotted module-name prefix used to name the sub-layers.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # These config options are not implemented by this model.
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned position embeddings are only created when ALiBi is not used.
        self.wpe = (
            nn.Embedding(config.max_position_embeddings, self.embed_dim)
            if config.position_embedding_type != "alibi"
            else None
        )
        # Configs may expose the embedding scale under either name.
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: JAISBlock(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        # Use the same hidden width as the embeddings and LayerNorms above.
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states"], self.embed_dim
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids`` (no scaling applied)."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> IntermediateTensors | torch.Tensor:
        """Run the layers owned by this pipeline stage.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                ``inputs_embeds`` is not given.
            position_ids: Positions, used only when learned position
                embeddings (``wpe``) exist.
            intermediate_tensors: Hidden states received from the previous
                pipeline rank; required on non-first ranks.
            inputs_embeds: Optional precomputed token embeddings.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` on non-last
            pipeline ranks, otherwise the final normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.embed_input_ids(input_ids)
            if self.wpe is not None:
                position_embeds = self.wpe(position_ids)
                hidden_states = inputs_embeds + position_embeds
            else:
                hidden_states = inputs_embeds
            # Scale out of place: without ``wpe``, ``hidden_states`` aliases
            # ``inputs_embeds`` and an in-place multiply would modify a
            # tensor supplied by the caller.
            hidden_states = hidden_states * torch.tensor(
                float(self.embeddings_scale), dtype=hidden_states.dtype
            )
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in islice(self.h, self.start_layer, self.end_layer):
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: Iterable of ``(parameter name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            KeyError: If a non-skipped name has no matching parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
        for name, loaded_weight in weights:
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            # Transpose at most once, even if several keys match the name.
            if name.endswith(".weight") and any(
                key in name for key in conv1d_weight_names
            ):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

344

345
class JAISLMHeadModel(nn.Module, SupportsPP):
    """JAIS causal language model: ``JAISModel`` backbone plus an LM head
    and a scaled logits processor."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and logits processor.

        Args:
            vllm_config: Engine configuration; the HF model config and
                quantization config are read from it.
            prefix: Dotted module-name prefix used to name the sub-modules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
        )
        if self.config.tie_word_embeddings:
            # Share the input embedding module as the output projection.
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        # Configs may expose the logits scale either directly or as the
        # product of two ``mup_`` factors.
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = config.mup_output_alpha * config.mup_width_scale
        self.logits_processor = LogitsProcessor(
            vocab_size=config.vocab_size, scale=self.output_logits_scale
        )
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token-embedding lookup to the backbone."""
        return self.transformer.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> IntermediateTensors | torch.Tensor:
        """Run the backbone and return whatever it returns.

        Args:
            input_ids: Token ids, or ``None`` when embeddings are supplied.
            positions: Position ids forwarded to the backbone.
            intermediate_tensors: Tensors from the previous pipeline rank.
            inputs_embeds: Optional precomputed token embeddings.
        """
        hidden_states = self.transformer(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the logits
        processor (which applies ``output_logits_scale``)."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``.

        Checkpoint entries under ``lm_head.`` are skipped when the word
        embeddings are tied, because ``lm_head`` is then the same module as
        the input embedding.

        Returns:
            The set of loaded parameter names reported by the loader.
        """
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)