seed_oss.py 17.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2025 The Seed team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only SeedOss model compatible with HuggingFace weights."""
25

26
from collections.abc import Iterable
27
from itertools import islice
28
29
30
31
32

import torch
from torch import nn
from transformers import PretrainedConfig as SeedOssConfig

33
34
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
35
36
37
38
39
40
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
41
42
43
44
45
from vllm.model_executor.layers.linear import (
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
46
47
48
49
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
50
51
52
    ParallelLMHead,
    VocabParallelEmbedding,
)
53
from vllm.model_executor.model_loader.weight_utils import (
54
55
56
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
57
from vllm.sequence import IntermediateTensors
58
from vllm.transformers_utils.config import set_default_rope_theta
59
60

from .interfaces import SupportsLoRA, SupportsPP
61
62
63
64
65
66
67
68
from .utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
69
70
71
72
73
74
75
76
77
78

# Module-level logger. It is not referenced by any of the code shown in this file.
logger = init_logger(__name__)


class SeedOssMLP(nn.Module):
    """Gated feed-forward block used inside each SeedOss decoder layer.

    The gate and up projections are fused into one column-parallel linear
    layer. Its output goes through ``SiluAndMul`` (a SiLU-gated product,
    judging by the name; see that layer for the exact split). The result is
    projected back to ``hidden_size`` by a row-parallel linear layer.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Both projections are bias-free and share the same quantization.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        fused, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused))
        return out


class SeedOssAttention(nn.Module):
    """Self-attention with rotary position embeddings and tensor parallelism.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split across ranks when there are at least as many of them as ranks.
    Otherwise they are replicated, so every rank keeps at least one KV head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        rope_parameters: dict,
        max_position: int = 4096 * 32,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert num_heads % world == 0
        self.num_heads = num_heads // world

        self.total_num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # KV heads are either partitioned (count >= TP size) or replicated
        # (count < TP size). In both cases one count must divide the other.
        if num_kv_heads >= world:
            assert num_kv_heads % world == 0
        else:
            assert world % num_kv_heads == 0
        self.num_kv_heads = max(1, num_kv_heads // world)

        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position,
            rope_parameters=rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        out, _ = self.o_proj(self.attn(q, k, v))
        return out


class SeedOssDecoderLayer(nn.Module):
    """One SeedOss transformer block: normed self-attention, then normed MLP.

    The residual stream is passed into the RMSNorm layers and returned from
    them as a ``(hidden_states, residual)`` pair, so this layer never adds the
    residual explicitly. On the first layer no residual exists yet, so the
    raw input becomes the residual.
    """

    def __init__(
        self,
        config: SeedOssConfig,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        set_default_rope_theta(config, default_theta=1000000)

        # SeedOss is decoder-only, so attention is causal by default. Setting
        # `is_causal=False` in the HF config switches to bidirectional
        # attention, which some embedding models use.
        causal = getattr(config, "is_causal", True)
        attn_type = AttentionType.DECODER if causal else AttentionType.ENCODER_ONLY

        self.self_attn = SeedOssAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            head_dim=config.head_dim,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_parameters=config.rope_parameters,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = SeedOssMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Pre-attention norm. On the first layer the raw input seeds the
        # residual stream.
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)

        attn_out = self.self_attn(positions=positions, hidden_states=hidden_states)

        # Pre-MLP norm, carrying the residual forward.
        hidden_states, residual = self.post_attention_layernorm(attn_out, residual)
        return self.mlp(hidden_states), residual


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
class SeedOssModel(nn.Module):
    """SeedOss decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism. Each rank runs only the layers in
    ``[start_layer, end_layer)``. Ranks other than the first read their input
    from ``intermediate_tensors``. Ranks other than the last return
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        decoder_layer_type: type[nn.Module] = SeedOssDecoderLayer,
    ):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        # TODO (@robertgshaw2): see if this can be moved out
        if cache_config.sliding_window is not None and hasattr(
            config, "max_window_layers"
        ):
            assert config.max_window_layers == config.num_hidden_layers, (
                "Sliding window for some but not all layers is not supported. "
                "This model uses sliding window but `max_window_layers` = "
                f"{config.max_window_layers} is less than `num_hidden_layers` = "
                f"{config.num_hidden_layers}. Please open an issue to discuss "
                "this feature."
            )

        self.config = config
        self.quant_config = quant_config
        self.vocab_size = config.vocab_size

        # The embedding lives on the first PP rank. It is also kept on the last
        # rank when word embeddings are tied, because the LM head reuses it.
        if get_pp_group().is_first_rank or (
            config.tie_word_embeddings and get_pp_group().is_last_rank
        ):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Use the provided decoder layer type. Fall back to SeedOssDecoderLayer
        # if a falsy value (e.g. None) is passed explicitly.
        decoder_layer_type = decoder_layer_type or SeedOssDecoderLayer
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: decoder_layer_type(
                config=config,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            prefix=f"{prefix}.layers",
        )

        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Only the last PP rank applies the final norm.
        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        On the first rank the input is ``inputs_embeds`` if given, otherwise the
        embedding of ``input_ids``. Other ranks read ``hidden_states`` and
        ``residual`` from ``intermediate_tensors``. Non-last ranks return those
        two tensors for the next stage. The last rank returns normed hidden
        states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        Separate q/k/v and gate/up checkpoint tensors are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters via their shard ids. KV-cache
        quantization scales are loaded through the quant config's name mapping.

        Returns:
            The set of parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Non-scalar scales are reduced to their first element.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. Note: this `continue`
                # (and the PP one below) resumes the mapping loop with the
                # renamed `name`. No later entry matches, so control reaches the
                # `else` branch, whose identical checks then skip the weight.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


class SeedOssForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Causal language model: the SeedOss backbone plus an LM head.

    With tied word embeddings, the LM head is the backbone's input embedding
    module, and ``lm_head.`` checkpoint entries are skipped during loading.
    On pipeline ranks other than the last, the LM head is a placeholder.
    """

    # Fused vLLM modules -> the checkpoint sub-modules packed into each one.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_config

        self.model = SeedOssModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if not get_pp_group().is_last_rank:
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            # Share the input embedding as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        # A tied LM head has no weights of its own to load.
        skip = ["lm_head."] if self.config.tie_word_embeddings else None
        return AutoWeightsLoader(self, skip_prefixes=skip).load_weights(weights)