jais.py 14.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
# Adapted from
4
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
24
from typing import Iterable, Optional, Set, Tuple, Union
25
26
27
28

import torch
from torch import nn

29
from vllm.attention import Attention
30
from vllm.compilation.decorators import support_torch_compile
31
from vllm.config import CacheConfig, VllmConfig
32
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
33
                              get_tensor_model_parallel_world_size)
34
35
36
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
37
from vllm.model_executor.layers.logits_processor import LogitsProcessor
38
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
39
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
40
from vllm.model_executor.layers.vocab_parallel_embedding import (
41
    ParallelLMHead, VocabParallelEmbedding)
42
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
43
from vllm.model_executor.sampling_metadata import SamplingMetadata
44
from vllm.sequence import IntermediateTensors
45
from vllm.transformers_utils.configs import JAISConfig
46

47
48
from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
49
50
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
51

52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78

class SwiGLUActivation(nn.Module):
    """Gated activation ``x1 * silu(x2)``.

    ``x1`` is the value branch and ``x2`` the gate branch; both must be
    broadcast-compatible tensors.
    """

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        gate = torch.nn.functional.silu(x2)
        return x1 * gate


def _get_alibi_slopes(n):

    def get_slopes_power_of_2(n):
        start = 2**(-(2**-(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    if math.log2(n).is_integer():
        return get_slopes_power_of_2(n)
    else:
        closest_power_of_2 = 2**math.floor(math.log2(n))
        return (get_slopes_power_of_2(closest_power_of_2) + _get_alibi_slopes(
            2 * closest_power_of_2)[0::2][:n - closest_power_of_2])


class JAISAttention(nn.Module):
    """Multi-head self-attention with ALiBi position bias for JAIS.

    Heads are split evenly across tensor-parallel ranks; each rank keeps the
    ALiBi slopes of its own contiguous slice of heads.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # Some configs expose the flag as ``scale_qk_dot_by_d``; mirror it
        # onto the ``mup_`` name that is read on the next line.
        if hasattr(config, "scale_qk_dot_by_d"):
            config.mup_scale_qk_dot_by_d = config.scale_qk_dot_by_d
        # The attention scale is 1/head_dim when the flag is set, otherwise
        # the usual 1/sqrt(head_dim).
        self.attn_scale_power = 1.0 if config.mup_scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        # Prefixes let name-based quantization configs identify these layers.
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        # Select the ALiBi slopes for the heads owned by this TP rank.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply self-attention and the output projection."""
        qkv, _ = self.c_attn(hidden_states)
        # No separate KV head count is passed to c_attn, so q, k and v are
        # assumed to have equal width and split evenly.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):
    """Feed-forward sub-layer of a JAIS block.

    With ``activation_function == "swiglu"`` this computes
    ``c_proj(c_fc(x) * silu(c_fc2(x)))``; otherwise it computes
    ``c_proj(act(c_fc(x)))`` with the configured single-input activation.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # The gate projection only exists for SwiGLU.
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        if self.swiglu:
            self.act = SwiGLUActivation()
        else:
            # SwiGLUActivation needs two inputs and cannot serve the
            # single-input path, so look up the configured activation.
            from vllm.model_executor.layers.activation import get_act_fn
            self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, activate (optionally gated), and project back."""
        up, _ = self.c_fc(hidden_states)
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            hidden_states = self.act(up, gate)
        else:
            hidden_states = self.act(up)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):
    """A single pre-LayerNorm JAIS transformer layer.

    Computes ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        width = config.hidden_size
        eps = config.layer_norm_epsilon
        # ``n_inner`` overrides the MLP width; otherwise use 4x the width.
        mlp_width = config.n_inner
        if mlp_width is None:
            mlp_width = 4 * width

        self.ln_1 = nn.LayerNorm(width, eps=eps)
        self.attn = JAISAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(width, eps=eps)
        self.mlp = JAISMLP(mlp_width, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run the attention and MLP sub-layers, each with a residual add."""
        attn_out = self.attn(self.ln_1(hidden_states))
        hidden_states = attn_out + hidden_states

        mlp_out = self.mlp(self.ln_2(hidden_states))
        return hidden_states + mlp_out


216
@support_torch_compile
class JAISModel(nn.Module):
    """JAIS decoder stack: embeddings, transformer layers, final LayerNorm.

    Under pipeline parallelism the first rank builds the scaled input
    embeddings. Later ranks take ``hidden_states`` from
    ``intermediate_tensors``. Only the last rank applies ``ln_f``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned absolute position embeddings exist only without ALiBi.
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: JAISBlock(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config,
                                     prefix=prefix),
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run this rank's slice of the decoder.

        Returns ``IntermediateTensors`` on non-last pipeline ranks, otherwise
        the final normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            if self.wpe is not None:
                position_embeds = self.wpe(position_ids)
                hidden_states = inputs_embeds + position_embeds
            else:
                hidden_states = inputs_embeds
            # Out-of-place on purpose: without wpe, hidden_states *is* the
            # (possibly caller-owned) inputs_embeds tensor, and an in-place
            # multiply would rescale the caller's data.
            hidden_states = hidden_states * torch.tensor(
                float(self.embeddings_scale), dtype=hidden_states.dtype)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for layer in self.h[self.start_layer:self.end_layer]:
            hidden_states = layer(hidden_states)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


288
class JAISLMHeadModel(nn.Module, SupportsPP):
    """JAIS causal language model: ``JAISModel`` plus an LM head."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # Tied: the LM head is the embedding module itself.
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                                scale=self.output_logits_scale)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run the decoder and return its hidden states."""
        hidden_states = self.transformer(input_ids, positions,
                                         intermediate_tensors, inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to (scaled) vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF checkpoint weights; return the set of loaded param names.

        HF JAIS stores linear layers as Conv1D, so the ``c_attn``/``c_proj``/
        ``c_fc`` (incl. ``c_fc2``) weight matrices are transposed on load.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        tied = self.config.tie_word_embeddings
        conv1d_weight_names = ("c_attn", "c_proj", "c_fc")
        for name, loaded_weight in weights:
            is_lm_head = "lm_head.weight" in name
            if is_lm_head and tied:
                # lm_head is transformer.wte here; its weight is loaded
                # through the embedding entry.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue
            # An untied lm_head lives at the top level, not under
            # "transformer.", so it must not be prefixed.
            if not is_lm_head and not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Note(zhuohan): the transpose below might break quantized models.
            if name.endswith(".weight") and any(
                    conv_name in name for conv_name in conv1d_weight_names):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params