# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Inference-only DeepSeek NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda, is_hip

# Platform flags, evaluated once at import time.
_is_hip = is_hip()
_is_cuda = is_cuda()

# Pick the AWQ dequantization kernel for the current platform: `sgl_kernel`
# when running on CUDA, vLLM's custom ops otherwise.
if _is_cuda:
    from sgl_kernel import awq_dequantize
else:
    from vllm._custom_ops import awq_dequantize

logger = logging.getLogger(__name__)


class DeepseekModelNextN(nn.Module):
    """NextN draft module: one DeepSeek decoder layer fed by fused inputs.

    The token embedding and the hidden states carried in
    ``forward_batch.spec_info.hidden_states`` are each RMS-normalized,
    concatenated on the last dimension, projected back to ``hidden_size`` by
    ``eh_proj`` and passed through a single ``DeepseekV2DecoderLayer``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, the fusion projection and the decoder layer.

        Args:
            config: Model configuration; ``vocab_size``, ``hidden_size`` and
                ``rms_norm_eps`` are read here.
            quant_config: Optional quantization configuration forwarded to the
                decoder layer.
            prefix: Parameter-name prefix used when naming sub-modules.
        """
        super().__init__()
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            # Tensor-parallel embedding is switched off under DP attention.
            enable_tp=not global_server_args_dict["enable_dp_attention"],
            prefix=add_prefix("embed_tokens", prefix),
        )

        # Separate norms for the token embedding (e) and the incoming hidden
        # states (h) that are fused in `forward`.
        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Maps the concatenated [embedding, hidden] vector back to hidden_size.
        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

        self.decoder = DeepseekV2DecoderLayer(
            config,
            0,
            quant_config=quant_config,
            is_nextn=True,
            prefix=add_prefix("decoder", prefix),
        )

        # Plain container so the final norm is registered as `shared_head.norm`.
        self.shared_head = nn.Module()
        self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the NextN layer and return its output hidden states.

        Args:
            input_ids: Token ids; embedded with ``embed_tokens`` when
                ``input_embeds`` is not given.
            positions: Position tensor forwarded to the decoder layer.
            forward_batch: Batch metadata; ``spec_info.hidden_states`` and
                ``forward_mode`` are read here.
            input_embeds: Optional precomputed embeddings used instead of
                looking up ``input_ids``.

        Returns:
            The decoder output, passed through ``shared_head.norm`` together
            with the residual unless the forward mode is idle.
        """
        # Small scratch allocator handed to the decoder layer; it lives on the
        # same device as the inputs.
        # NOTE(review): the meaning of buffer_size=2 is defined by
        # BumpAllocator / the decoder layer, not visible here.
        zero_allocator = BumpAllocator(
            buffer_size=2,
            dtype=torch.float32,
            device=(
                input_embeds.device if input_embeds is not None else input_ids.device
            ),
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Fuse the normalized embedding with the normalized hidden states
        # supplied through the speculative-decoding info.
        hidden_states = self.eh_proj(
            torch.cat(
                (
                    self.enorm(hidden_states),
                    self.hnorm(forward_batch.spec_info.hidden_states),
                ),
                dim=-1,
            )
        )

        residual = None
        hidden_states, residual = self.decoder(
            positions, hidden_states, forward_batch, residual, zero_allocator
        )

        # The final norm is skipped for idle batches.
        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
        return hidden_states


class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
    """Causal-LM wrapper around ``DeepseekModelNextN`` for NextN decoding.

    Holds the NextN model, an LM head and a logits processor, and knows how
    to load the NextN layer's weights from a checkpoint.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the NextN model, LM head and logits processor.

        Args:
            config: Model configuration.
            quant_config: Optional quantization configuration.
            prefix: Parameter-name prefix used when naming sub-modules.
        """
        # Initialise only nn.Module: the parent class' __init__ is bypassed on
        # purpose and the attributes needed here are set up manually.
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")

        self.model = DeepseekModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
        )

        if global_server_args_dict["enable_dp_attention"]:
            # Under DP attention the head is replicated and the logits
            # processor is told to skip its all-gather.
            self.lm_head = ReplicatedLinear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
                prefix=add_prefix("model.shared_head.head", prefix),
            )
            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("model.shared_head.head", prefix),
            )
            self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run the NextN model and hand its hidden states to the logits processor."""
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the NextN layer's weights and prepare the attention projections.

        Only checkpoint entries whose name starts with ``model.layers.0`` are
        consumed; ``shared_head.head`` and ``embed_tokens`` entries are skipped.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            ValueError: If the config has no ``num_nextn_predict_layers``.
            AssertionError: If the number of NextN layers is not 1 or differs
                from ``config.num_hidden_layers``.
        """
        if not hasattr(self.config, "num_nextn_predict_layers"):
            raise ValueError("num_nextn_predict_layers is not in the config")
        num_nextn_layers = self.config.num_nextn_predict_layers
        assert num_nextn_layers == 1, "Only 1 nextn layer is supportted"
        assert num_nextn_layers == self.config.num_hidden_layers

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        if self.n_share_experts_fusion > 0:
            weights = self._nextn_fuse_shared_experts(weights)

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        nextn_layer_prefix = "model.layers.0"
        nextn_spec_weight_names = (
            "shared_head.norm",
            "eh_proj",
            "enorm",
            "hnorm",
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not name.startswith(nextn_layer_prefix):
                continue

            # Use shared head and embed weights from target model
            if "shared_head.head" in name or "embed_tokens" in name:
                continue

            # NextN-specific weights live directly under `model`; everything
            # else in the layer belongs to the decoder sub-module.
            is_nextn_specific = any(
                weight_name in name for weight_name in nextn_spec_weight_names
            )
            name = name.replace(
                nextn_layer_prefix,
                "model" if is_nextn_specific else "model.decoder",
            )

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        self._nextn_post_load_kv_b_proj()

    def _nextn_fuse_shared_experts(self, weights):
        """Re-expose the shared expert's tensors as extra routed experts.

        Each shared-expert tensor of layer 0 is appended once per replica under
        the name of routed expert ``n_routed_experts + replica_index``, and the
        original shared-expert entries are dropped from the returned list.
        """
        logger.info(
            f"Cloning {self.n_share_experts_fusion} "
            "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
        )
        weights_list = list(weights)
        weights_dict = dict(weights_list)

        # The name of the scale tensor depends on the quantization scheme.
        if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
            scale_suffix = "weight_scale"
        else:
            scale_suffix = "weight_scale_inv"
        suffix_list = [
            f"{proj}.{kind}"
            for proj in ("down_proj", "gate_proj", "up_proj")
            for kind in ("weight", scale_suffix)
        ]

        shared_prefix = "model.layers.0.mlp.shared_experts"
        # A set gives O(1) membership tests when filtering below.
        names_to_remove = {f"{shared_prefix}.{suffix}" for suffix in suffix_list}

        for num_repeat in range(self.n_share_experts_fusion):
            expert_index = self.config.n_routed_experts + num_repeat
            for suffix in suffix_list:
                # NOTE(review): a checkpoint lacking one of these tensors
                # raises KeyError here.
                weights_list.append(
                    (
                        f"model.layers.0.mlp.experts.{expert_index}.{suffix}",
                        weights_dict[f"{shared_prefix}.{suffix}"],
                    )
                )
        return [w for w in weights_list if w[0] not in names_to_remove]

    def _nextn_post_load_kv_b_proj(self):
        """Derive ``w_kc`` / ``w_vc`` (and ``w_scale``) from ``kv_b_proj``.

        The projection weight is dequantized or requantized as required by its
        storage format, then split per head into its ``qk_nope_head_dim`` and
        ``v_head_dim`` parts, which are stored on the attention module.
        """
        self_attn = self.model.decoder.self_attn
        if hasattr(self_attn.kv_b_proj, "qweight"):
            # AWQ compatible
            if _is_cuda:
                w = awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                ).T
            else:
                # The non-CUDA implementation takes three extra positional
                # arguments.
                w = awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                    0,
                    0,
                    0,
                ).T
        else:
            w = self_attn.kv_b_proj.weight
        # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
        # This may affect the accuracy of fp8 model.
        if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
        ):
            weight_block_size = self.quant_config.weight_block_size
            if weight_block_size is not None:
                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                if _is_hip:
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=w,
                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                        input_scale=None,
                    )
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale_inv

                w, scale = block_quant_to_tensor_quant(
                    weight, weight_scale, weight_block_size
                )
                self_attn.w_scale = scale
        if w.dtype == torch.int8:
            if hasattr(self.quant_config, "weight_block_size"):
                # block-wise int8 need it
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
                    w = int8_block_dequant(weight, weight_scale, weight_block_size).to(
                        torch.bfloat16
                    )
            else:
                # channel-wise int8 need it
                assert hasattr(self_attn.kv_b_proj, "weight_scale")
                w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                    torch.bfloat16
                )
        # Group rows per head, then split each head's rows into the
        # qk_nope_head_dim part and the v_head_dim part.
        w_kc, w_vc = w.unflatten(
            0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
        ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
        self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
        self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
        if hasattr(self_attn.kv_b_proj, "weight_scale") and self_attn.w_scale is None:
            self_attn.w_scale = self_attn.kv_b_proj.weight_scale
            if _is_hip:
                # NOTE(review): in-place multiply, so kv_b_proj.weight_scale
                # itself is doubled too; the factor 2.0 is presumably tied to
                # the fnuz fp8 format on HIP — confirm.
                self_attn.w_scale *= 2.0


# Model classes exported by this module.
# NOTE(review): presumably the model loader discovers implementations through
# this `EntryClass` list — confirm against the loader.
EntryClass = [DeepseekV3ForCausalLMNextN]