# Adapted from
# https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
from typing import Iterable, List, Optional, Set, Tuple, Union

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import JAISConfig

from .interfaces import SupportsPP
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)


class SwiGLUActivation(nn.Module):
    """Gated linear unit with a SiLU gate: returns ``x1 * silu(x2)``."""

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        # x2 is the gate branch; x1 is passed through unactivated.
        gate = nn.functional.silu(x2)
        return x1 * gate


def _get_alibi_slopes(n):
    """Return the list of ``n`` ALiBi head slopes.

    For a power-of-two ``n`` the slopes form a geometric sequence whose first
    term and ratio are both ``2 ** (-(2 ** -(log2(n) - 3)))``. Otherwise the
    slopes for the largest power of two below ``n`` are extended with every
    other slope of the next power of two, truncated to ``n`` entries in total.
    """

    def _geometric_slopes(count):
        base = 2**(-(2**-(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    if math.log2(n).is_integer():
        return _geometric_slopes(n)

    lower_pow2 = 2**math.floor(math.log2(n))
    extra = _get_alibi_slopes(2 * lower_pow2)[::2][:n - lower_pow2]
    return _geometric_slopes(lower_pow2) + extra


class JAISAttention(nn.Module):
    """Multi-head self-attention for JAIS with ALiBi slopes.

    Heads are split evenly across tensor-parallel ranks. Each rank keeps the
    ALiBi slopes for its own contiguous slice of heads.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads

        # The flag may be spelled `scale_qk_dot_by_d` or, in muP-style
        # configs, `mup_scale_qk_dot_by_d`. Read it into a local instead of
        # writing back into the config object shared by every layer.
        if hasattr(config, "scale_qk_dot_by_d"):
            scale_qk_dot_by_d = config.scale_qk_dot_by_d
        else:
            scale_qk_dot_by_d = config.mup_scale_qk_dot_by_d
        # Scale q.k by head_dim**-1 when the flag is set, else by
        # head_dim**-0.5.
        self.attn_scale_power = 1.0 if scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Slopes are computed for all heads, then sliced to this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to Q/K/V, attend, and project back to ``hidden_size``."""
        qkv, _ = self.c_attn(hidden_states)
        # c_attn is built without a separate KV head count, so the fused
        # output is presumably three equal-width Q/K/V slices.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):
    """Feed-forward sub-block of a JAIS layer.

    When ``config.activation_function == "swiglu"`` this computes
    ``c_proj(c_fc(x) * silu(c_fc2(x)))``. Otherwise it computes
    ``c_proj(act(c_fc(x)))``, where ``act`` is the activation named by the
    config.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # Gate projection; only present for the SwiGLU variant.
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # SwiGLUActivation takes two inputs (value, gate), so it only fits
        # the gated path. The ungated path needs a one-input activation.
        if self.swiglu:
            self.act = SwiGLUActivation()
        else:
            self.act = get_act_fn(config.activation_function)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward transform to ``hidden_states``."""
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            value, _ = self.c_fc(hidden_states)
            hidden_states = self.act(value, gate)
        else:
            hidden_states, _ = self.c_fc(hidden_states)
            hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):
    """One JAIS transformer layer.

    LayerNorm is applied before attention and before the MLP. Each
    sub-layer's output is added back onto its input (residual connection).
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # `n_inner` overrides the default 4x MLP expansion when set.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention then MLP, each wrapped in a residual connection."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


@support_torch_compile
class JAISModel(nn.Module):
    """JAIS decoder stack: embeddings, transformer blocks and final norm.

    Under pipeline parallelism this rank runs layers
    ``[start_layer, end_layer)``. The first rank builds the embeddings and
    the last rank applies ``ln_f``. Intermediate ranks exchange
    ``hidden_states`` through ``IntermediateTensors``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        # Config options this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned absolute position embeddings are used only when the model
        # is not configured for ALiBi.
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale

        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: JAISBlock(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config),
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the (unscaled) token embeddings for ``input_ids``."""
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run this rank's slice of the decoder.

        Returns ``IntermediateTensors`` on non-last pipeline ranks. On the
        last rank it returns the ``ln_f``-normalized hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is None:
                inputs_embeds = self.get_input_embeddings(input_ids)
            if self.wpe is not None:
                position_embeds = self.wpe(position_ids)
                hidden_states = inputs_embeds + position_embeds
            else:
                hidden_states = inputs_embeds
            # Scale out of place: without position embeddings
            # `hidden_states` is the very tensor passed in as `inputs_embeds`,
            # which the caller may own. The scale is cast to the activation
            # dtype before multiplying.
            hidden_states = hidden_states * torch.tensor(
                float(self.embeddings_scale), dtype=hidden_states.dtype)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            # kv_caches holds only this rank's layers, indexed from zero.
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class JAISLMHeadModel(nn.Module, SupportsPP):
    """JAIS causal language model: ``JAISModel`` plus an LM head.

    Logits are scaled by ``output_logits_scale`` inside the logits processor.
    That value comes from ``config.width_scale`` when present, and otherwise
    from ``mup_output_alpha * mup_width_scale``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(vllm_config=vllm_config,
                                     prefix=maybe_prefix(
                                         prefix, "transformer"))
        # With tied embeddings the head is the input embedding module itself,
        # so both share one weight.
        if self.config.tie_word_embeddings:
            self.lm_head = self.transformer.wte
        else:
            self.lm_head = ParallelLMHead(self.config.vocab_size,
                                          self.config.hidden_size)
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                                scale=self.output_logits_scale)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return token embeddings from the underlying transformer."""
        return self.transformer.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Delegate to ``JAISModel.forward``; logits are computed separately."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors,
                                         inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF-format JAIS weights and return the parameter names loaded.

        Attention-mask buffers and ``relative_pe`` entries are skipped.
        ``lm_head.weight`` is skipped only when embeddings are tied, because
        the head then shares ``transformer.wte``. Untied heads are loaded
        into the separate ``ParallelLMHead``.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            is_lm_head = "lm_head.weight" in name
            if is_lm_head and self.config.tie_word_embeddings:
                # Tied head: its weight is the embedding matrix, which is
                # loaded from the embedding entry.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                continue
            # The LM head lives outside `transformer`, so it must not be
            # prefixed.
            if not is_lm_head and not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (at most
            # once, even if a name matches several patterns).
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv1d_weight_name in name
                    for conv1d_weight_name in ("c_attn", "c_proj", "c_fc")):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params