jais.py 13.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
23
from typing import Iterable, List, Optional, Tuple, Union
24
25
26
27

import torch
from torch import nn

28
from vllm.attention import Attention, AttentionMetadata
29
from vllm.config import CacheConfig
30
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
31
                              get_tensor_model_parallel_world_size)
32
33
34
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
35
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
37
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
38
39
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
40
    VocabParallelEmbedding)
41
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
42
from vllm.model_executor.sampling_metadata import SamplingMetadata
43
from vllm.sequence import IntermediateTensors, SamplerOutput
44
from vllm.transformers_utils.configs import JAISConfig
45

46
47
from .utils import is_pp_missing_parameter, make_layers

48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

class SwiGLUActivation(nn.Module):
    """Gated SiLU activation: returns ``x1 * silu(x2)``.

    Stateless (no parameters). ``x2`` is the gating input passed through
    SiLU; ``x1`` is multiplied element-wise by the gate.
    """

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        gate = nn.functional.silu(x2)
        return x1 * gate


def _get_alibi_slopes(n):
    """Return the list of ``n`` ALiBi head slopes.

    For a power-of-two ``n`` the slopes form a geometric sequence whose
    first term and ratio are both ``2 ** -(8 / n)``. Otherwise the slopes
    for the largest power of two below ``n`` are used, followed by every
    other slope of the next power of two to fill the remaining heads.
    """

    def _power_of_two_slopes(count):
        base = 2**(-(2**-(math.log2(count) - 3)))
        return [base * base**idx for idx in range(count)]

    log_n = math.log2(n)
    if log_n.is_integer():
        return _power_of_two_slopes(n)

    lower_pow2 = 2**math.floor(log_n)
    # Interleave in slopes from the next power of two for the extra heads.
    extra = _get_alibi_slopes(2 * lower_pow2)[0::2][:n - lower_pow2]
    return _power_of_two_slopes(lower_pow2) + extra


class JAISAttention(nn.Module):
    """Self-attention layer of a JAIS block, using ALiBi position biases.

    Heads are split evenly across tensor-parallel ranks. Each rank computes
    the ALiBi slopes for all heads and keeps only the slice belonging to its
    own heads.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads

        # The flag is exposed under two names depending on the config
        # version; prefer the un-prefixed one. Read it into a local rather
        # than writing it back onto the (shared) config object.
        if hasattr(config, "scale_qk_dot_by_d"):
            scale_qk_dot_by_d = config.scale_qk_dot_by_d
        else:
            scale_qk_dot_by_d = config.mup_scale_qk_dot_by_d
        # When set, q.k is scaled by 1/head_dim instead of 1/sqrt(head_dim).
        self.attn_scale_power = 1.0 if scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Keep only the slopes of the heads owned by this TP rank.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              alibi_slopes=alibi_slopes,
                              cache_config=cache_config,
                              quant_config=quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run fused QKV projection, attention, and the output projection.

        ``kv_cache`` and ``attn_metadata`` are passed through unchanged to
        the vLLM ``Attention`` layer.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Fused QKV output is split into equal q/k/v parts on the last dim.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):
    """Feed-forward sub-layer of a JAIS block.

    With ``config.activation_function == "swiglu"`` it computes
    ``c_proj(c_fc(x) * silu(c_fc2(x)))``. For any other activation it
    computes ``c_proj(act(c_fc(x)))`` with an element-wise activation.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # Second (gate) projection only exists in the SwiGLU variant.
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # SwiGLUActivation takes two inputs, so it is only usable on the
        # gated path; the non-gated path needs a one-input activation
        # (previously it called SwiGLUActivation with one arg -> TypeError).
        self.act = (SwiGLUActivation() if self.swiglu else
                    self._pointwise_activation(config.activation_function))

    @staticmethod
    def _pointwise_activation(name: str) -> nn.Module:
        """Build the element-wise activation named ``name``.

        Names follow the HuggingFace ``activation_function`` convention.

        Raises:
            ValueError: if ``name`` is not a supported activation.
        """
        if name == "gelu":
            return nn.GELU()
        if name in ("gelu_new", "gelu_fast", "gelu_pytorch_tanh"):
            # These HF variants are the tanh approximation of GELU.
            return nn.GELU(approximate="tanh")
        if name == "relu":
            return nn.ReLU()
        if name in ("silu", "swish"):
            return nn.SiLU()
        raise ValueError(
            f"Unsupported activation_function for JAIS MLP: {name!r}")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            up, _ = self.c_fc(hidden_states)
            hidden_states = self.act(up, gate)
        else:
            hidden_states, _ = self.c_fc(hidden_states)
            hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):
    """One JAIS transformer layer: pre-LayerNorm attention and MLP.

    Each sub-layer normalizes its input first and adds the sub-layer output
    back onto the un-normalized residual stream.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        # Feed-forward width defaults to 4x hidden size when n_inner is unset.
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply attention then MLP, each with a residual connection."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class JAISModel(nn.Module):
    """JAIS decoder stack: token embedding, transformer blocks, final norm.

    Pipeline-parallel aware: only the first PP rank embeds input tokens,
    other ranks receive ``hidden_states`` via ``intermediate_tensors``, and
    only the last PP rank applies the final LayerNorm.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        # HF features this implementation does not support.
        assert not config.add_cross_attention
        assert not config.scale_attn_by_inverse_layer_idx
        assert not config.reorder_and_upcast_attn
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned absolute position embeddings are only built when the
        # config does not request ALiBi.
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
        # The scale is exposed under two names depending on config version.
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale

        # start_layer/end_layer bound the blocks run by this rank (see
        # forward); presumably make_layers assigns them per PP stage.
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: JAISBlock(config=config,
                                     cache_config=cache_config,
                                     quant_config=quant_config),
            prefix=f"{prefix}.h",
        )

        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run this rank's slice of the decoder.

        Returns ``IntermediateTensors`` with key ``"hidden_states"`` on
        non-last PP ranks, and the final-normed hidden states otherwise.
        ``kv_caches`` is indexed relative to ``self.start_layer``.
        """
        if get_pp_group().is_first_rank:
            inputs_embeds = self.wte(input_ids)
            if self.wpe is not None:
                position_embeds = self.wpe(position_ids)
                hidden_states = inputs_embeds + position_embeds
            else:
                hidden_states = inputs_embeds
            # Scale is cast to the activation dtype before multiplying.
            hidden_states *= torch.tensor(float(self.embeddings_scale),
                                          dtype=hidden_states.dtype)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class JAISLMHeadModel(nn.Module):
    """JAIS causal LM: decoder stack plus a tied LM head, logits and sampler.

    The LM head reuses the token-embedding module (``transformer.wte``), and
    logits are multiplied by a config-provided output scale.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(config, cache_config, quant_config)
        # Output projection is tied to the input embedding.
        self.lm_head = self.transformer.wte
        # The logit scale is exposed under two names depending on config
        # version; the older form is the product of two muP factors.
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                                scale=self.output_logits_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[IntermediateTensors, torch.Tensor]:
        """Run the decoder stack; see ``JAISModel.forward``."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to (scaled) vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        """Allocate a zero ``hidden_states`` buffer for PP communication."""
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load HF-format checkpoint tensors into this model's parameters.

        Tied LM-head weights, attention-mask buffers and ``relative_pe``
        entries are skipped; names are prefixed with ``transformer.`` when
        missing; parameters not owned by this PP rank are skipped.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # GPT-2 ties the weights of the embedding layer and the final
                # linear layer.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if "relative_pe" in name:
                # NOTE(review): presumably the HF ALiBi/relative-position
                # module, whose slopes are recomputed here; confirm.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # The HF's GPT-2 implementation uses Conv1D instead of Linear.
            # Because of this, we need to transpose the weights (once).
            # Note(zhuohan): the logic below might break quantized models.
            is_conv1d_weight = name.endswith(".weight") and any(
                key in name for key in ("c_attn", "c_proj", "c_fc"))
            if is_conv1d_weight:
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)