# coding=utf-8
# Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team.  All rights
# reserved.
# Copyright 2023 Cerebras Systems.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Jais model compatible with HuggingFace weights."""

import math
from typing import Iterable, List, Optional, Tuple

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import JAISConfig


class SwiGLUActivation(nn.Module):
    """Gating half of a SwiGLU feed-forward layer.

    Given two projections of the same input, returns ``x1 * silu(x2)``:
    ``x2`` is passed through SiLU and used to gate ``x1`` element-wise
    (standard tensor broadcasting applies).
    """

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        gate = nn.functional.silu(x2)
        return x1 * gate


def _get_alibi_slopes(n):
    """Return the list of ALiBi slopes for ``n`` attention heads.

    When ``n`` is a power of two the slopes are the geometric sequence
    ``base, base**2, ..., base**n`` with
    ``base = 2 ** -(2 ** -(log2(n) - 3))`` (mathematically ``2 ** (-8 / n)``).
    Otherwise the slopes for the largest power of two below ``n`` are used,
    followed by every other slope of the sequence for twice that power of
    two, truncated so that exactly ``n`` slopes are returned.
    """

    def _power_of_two_slopes(count):
        base = 2**(-(2**-(math.log2(count) - 3)))
        return [base * base**k for k in range(count)]

    if not math.log2(n).is_integer():
        lower_pow2 = 2**math.floor(math.log2(n))
        # Interleave in the odd-indexed slopes of the next power of two.
        extra = _get_alibi_slopes(2 * lower_pow2)[0::2][:n - lower_pow2]
        return _power_of_two_slopes(lower_pow2) + extra
    return _power_of_two_slopes(n)


class JAISAttention(nn.Module):
    """Multi-head self-attention with ALiBi slopes for the JAIS model.

    Heads are split evenly across tensor-parallel ranks; each rank builds
    attention over its ``num_attention_heads // world_size`` heads and the
    matching slice of the ALiBi slopes.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the fused QKV projection, output projection and attention.

        Args:
            config: model config; reads ``hidden_size``,
                ``num_attention_heads`` and the qk-scaling flag
                (``scale_qk_dot_by_d`` if present, else
                ``mup_scale_qk_dot_by_d``).
            cache_config: forwarded to ``Attention``.
            quant_config: forwarded to the linear layers.

        Raises:
            ValueError: if the head count is not divisible by the
                tensor-parallel world size.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        if total_num_heads % tensor_model_parallel_world_size != 0:
            raise ValueError(
                f"num_attention_heads ({total_num_heads}) must be divisible "
                f"by the tensor parallel world size "
                f"({tensor_model_parallel_world_size}).")
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // total_num_heads
        # Some configs expose the flag as ``scale_qk_dot_by_d``; prefer it when
        # present. It is read into a local rather than written back onto the
        # (shared) config object.
        if hasattr(config, "scale_qk_dot_by_d"):
            scale_qk_dot_by_d = config.scale_qk_dot_by_d
        else:
            scale_qk_dot_by_d = config.mup_scale_qk_dot_by_d
        # Flag set: scores scaled by 1/head_dim; otherwise 1/sqrt(head_dim).
        self.attn_scale_power = 1.0 if scale_qk_dot_by_d else 0.5
        self.scale = self.head_dim**-self.attn_scale_power

        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        # Slopes are computed for all heads, then sliced to this rank's heads.
        tp_rank = get_tensor_model_parallel_rank()
        head_start = tp_rank * self.num_heads
        head_end = (tp_rank + 1) * self.num_heads
        alibi_slopes = _get_alibi_slopes(total_num_heads)
        alibi_slopes = alibi_slopes[head_start:head_end]
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            scale=self.scale,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: layer input (presumably ``(num_tokens,
                hidden_size)`` — confirm against the caller).
            kv_cache: this layer's KV cache, passed through to ``Attention``.
            attn_metadata: passed through to ``Attention``.

        Returns:
            The attention output after the ``c_proj`` projection.
        """
        qkv, _ = self.c_attn(hidden_states)
        # Fused projection: split the last dim into equal Q, K, V parts.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class JAISMLP(nn.Module):
    """Feed-forward block of a JAIS decoder layer.

    With ``activation_function == "swiglu"`` the output is
    ``c_proj(c_fc(x) * silu(c_fc2(x)))``. For any other activation name the
    output is ``c_proj(act(c_fc(x)))``, with ``act`` looked up by name.
    """

    def __init__(
        self,
        intermediate_size: int,
        config: JAISConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and the activation.

        Args:
            intermediate_size: width of the hidden (up-projected) layer.
            config: model config; reads ``hidden_size`` and
                ``activation_function``.
            quant_config: forwarded to the linear layers (and to the
                activation lookup for non-swiglu activations).
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.swiglu = config.activation_function == "swiglu"
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        )
        # Second (gate) projection only exists for SwiGLU.
        self.c_fc2 = (ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
        ) if self.swiglu else None)
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
        )

        if self.swiglu:
            self.act = SwiGLUActivation()
        else:
            # SwiGLUActivation takes two tensors, so it cannot be applied to
            # the single c_fc output; use the named single-input activation.
            self.act = get_act_fn(config.activation_function, quant_config,
                                  intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block and return the projected output."""
        if self.swiglu:
            gate, _ = self.c_fc2(hidden_states)
            up, _ = self.c_fc(hidden_states)
            hidden_states = self.act(up, gate)
        else:
            hidden_states, _ = self.c_fc(hidden_states)
            hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class JAISBlock(nn.Module):
    """One JAIS decoder layer.

    Pre-LayerNorm layout: ``x = x + attn(ln_1(x))`` followed by
    ``x = x + mlp(ln_2(x))``.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the two LayerNorms, the attention and the MLP sub-layers.

        Args:
            config: model config; reads ``hidden_size``, ``n_inner`` and
                ``layer_norm_epsilon`` (and whatever the sub-layers read).
            cache_config: forwarded to the attention layer.
            quant_config: forwarded to the attention and MLP layers.
        """
        super().__init__()
        hidden_size = config.hidden_size
        # ``n_inner`` overrides the default MLP width of 4 * hidden_size.
        inner_dim = (config.n_inner
                     if config.n_inner is not None else 4 * hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = JAISAttention(config, cache_config, quant_config)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = JAISMLP(inner_dim, config, quant_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Run attention and MLP sub-layers, each with a residual add.

        Args:
            hidden_states: layer input.
            kv_cache: this layer's KV cache, passed to the attention layer.
            attn_metadata: passed to the attention layer.

        Returns:
            The layer output, same shape as ``hidden_states``.
        """
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class JAISModel(nn.Module):
    """JAIS transformer trunk.

    Consists of token (and optional learned position) embeddings, a stack of
    ``JAISBlock`` layers, and a final LayerNorm.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build embeddings, decoder layers and the final LayerNorm.

        Args:
            config: model config.
            cache_config: forwarded to every decoder layer.
            quant_config: forwarded to every decoder layer.

        Raises:
            ValueError: if the config enables an option this implementation
                does not support.
        """
        super().__init__()
        self.config = config
        # Options the decoder layers below do not implement.
        for unsupported in ("add_cross_attention",
                            "scale_attn_by_inverse_layer_idx",
                            "reorder_and_upcast_attn"):
            if getattr(config, unsupported):
                raise ValueError(
                    f"JAIS config option `{unsupported}` is not supported.")
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # Learned absolute positions are only used when ALiBi is not.
        self.wpe = (nn.Embedding(config.max_position_embeddings,
                                 self.embed_dim)
                    if config.position_embedding_type != "alibi" else None)
        if hasattr(config, "embeddings_scale"):
            self.embeddings_scale = config.embeddings_scale
        else:
            self.embeddings_scale = config.mup_embeddings_scale
        self.h = nn.ModuleList([
            JAISBlock(config, cache_config, quant_config)
            for _ in range(config.num_hidden_layers)
        ])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and run them through every decoder layer.

        Args:
            input_ids: token ids to embed.
            position_ids: positions, used only when learned position
                embeddings (``wpe``) are enabled.
            kv_caches: one KV cache per layer, indexed by layer number.
            attn_metadata: passed to every layer.

        Returns:
            Final hidden states after ``ln_f``.
        """
        inputs_embeds = self.wte(input_ids)
        if self.wpe is not None:
            hidden_states = inputs_embeds + self.wpe(position_ids)
        else:
            hidden_states = inputs_embeds
        # The scale is materialised as a 0-dim tensor in the activations'
        # dtype before the in-place multiply.
        hidden_states *= torch.tensor(float(self.embeddings_scale),
                                      dtype=hidden_states.dtype)

        # Index kv_caches by layer so a too-short list fails loudly instead
        # of silently skipping layers.
        for i, layer in enumerate(self.h):
            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class JAISLMHeadModel(nn.Module):
    """JAIS causal language model.

    Combines the ``JAISModel`` trunk with a logits head whose weights are
    tied to the token embedding, plus the sampler.
    """

    def __init__(
        self,
        config: JAISConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the trunk, the tied logits head and the sampler.

        Args:
            config: model config; also reads ``width_scale`` if present,
                else ``mup_output_alpha * mup_width_scale``, as the logits
                scale.
            cache_config: forwarded to the trunk.
            quant_config: stored and forwarded to the trunk.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = JAISModel(config, cache_config, quant_config)
        # The output projection shares the token-embedding matrix.
        self.lm_head_weight = self.transformer.wte.weight
        if hasattr(config, "width_scale"):
            self.output_logits_scale = config.width_scale
        else:
            self.output_logits_scale = (config.mup_output_alpha *
                                        config.mup_width_scale)
        self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size,
                                                scale=self.output_logits_scale)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the trunk's final hidden states (logits are computed in
        ``compute_logits``)."""
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the tied weights."""
        logits = self.logits_processor(self.lm_head_weight, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits`` with the model's sampler."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model's parameters.

        Args:
            weights: iterable of ``(name, tensor)`` pairs in the checkpoint's
                naming scheme; names without the ``transformer.`` prefix get
                it prepended.

        Raises:
            KeyError: if a (non-skipped) checkpoint name has no matching
                parameter.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # The logits head is tied to the token embedding
                # (``self.lm_head_weight``), so there is nothing to load.
                continue
            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" does not match and must not be skipped.
                continue
            if "relative_pe" in name:
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name
            param = params_dict[name]
            # The HF's GPT-2 style implementation uses Conv1D instead of
            # Linear for these projections, so their weights (not biases)
            # must be transposed -- exactly once, even if a name happened to
            # contain more than one of the substrings.
            # Note(zhuohan): the logic below might break quantized models.
            if name.endswith(".weight") and any(
                    conv1d_weight_name in name
                    for conv1d_weight_name in ("c_attn", "c_proj", "c_fc")):
                loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)