Commit 9aadeed6 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch '092dev_DS_V2_w8a8_reduce_Moe_quant' into 'v0.9.2-dev'

[pref]: DS_v2_w8a8模型融掉moe.quant

See merge request dcutoolkit/deeplearing/vllm!268
parents a363d2c3 68972532
...@@ -1394,13 +1394,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1394,13 +1394,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> None: routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8, activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8, use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map, per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe, shared_output, routed_scaling_factor) block_shape, use_nn_moe, shared_output, routed_scaling_factor, i_q=i_q, i_s=i_s)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1428,7 +1430,9 @@ def inplace_fused_experts_fake( ...@@ -1428,7 +1430,9 @@ def inplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> None: routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> None:
pass pass
...@@ -1466,7 +1470,9 @@ def outplace_fused_experts( ...@@ -1466,7 +1470,9 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input, False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
...@@ -1500,7 +1506,9 @@ def outplace_fused_experts_fake( ...@@ -1500,7 +1506,9 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1559,7 +1567,9 @@ def fused_experts( ...@@ -1559,7 +1567,9 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False, allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor: routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better # For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available. # permute/unpermute ops are available.
N = w1.size(1) N = w1.size(1)
...@@ -1594,6 +1604,7 @@ def fused_experts( ...@@ -1594,6 +1604,7 @@ def fused_experts(
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids) topk_ids=topk_ids)
else: else:
# Fused MoE quantization only 4 DS w8a8 now
return dispatch_fused_experts_func(inplace)( return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
...@@ -1619,7 +1630,9 @@ def fused_experts( ...@@ -1619,7 +1630,9 @@ def fused_experts(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor) routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
def fused_experts_impl( def fused_experts_impl(
...@@ -1649,6 +1662,8 @@ def fused_experts_impl( ...@@ -1649,6 +1662,8 @@ def fused_experts_impl(
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0, routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens = hidden_states.size(0) num_tokens = hidden_states.size(0)
if use_nn_moe: if use_nn_moe:
...@@ -1695,8 +1710,9 @@ def fused_experts_impl( ...@@ -1695,8 +1710,9 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=False, use_nn_moe=False,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor routed_scaling_factor=routed_scaling_factor,
) i_q=i_q,
i_s=i_s)
elif use_int4_w4a8 is True: elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states, return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1, w1=w1,
......
...@@ -8,7 +8,7 @@ import importlib ...@@ -8,7 +8,7 @@ import importlib
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from enum import Enum from enum import Enum
from typing import Callable, Literal, Optional, overload from typing import Callable, Literal, Optional, overload, Tuple, List
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -1435,14 +1435,19 @@ class FusedMoE(torch.nn.Module): ...@@ -1435,14 +1435,19 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None): shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
):
# TODO: Once the OOM issue for the TPU backend is resolved, we will # TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. # switch to using the moe_forward custom op.
if current_platform.is_tpu(): if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
else: else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits, return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output) self.layer_name, shared_output,
i_q, i_s)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor): full_router_logits: torch.Tensor):
...@@ -1522,7 +1527,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1522,7 +1527,9 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None): shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels): or self.moe_parallel_config.use_deepep_ll_kernels):
...@@ -1559,7 +1566,9 @@ class FusedMoE(torch.nn.Module): ...@@ -1559,7 +1566,9 @@ class FusedMoE(torch.nn.Module):
shared_output=shared_output, shared_output=shared_output,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate use_fused_gate=self.use_fused_gate,
i_q=i_q,
i_s=i_s
) )
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
...@@ -1630,16 +1639,20 @@ class FusedMoE(torch.nn.Module): ...@@ -1630,16 +1639,20 @@ class FusedMoE(torch.nn.Module):
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, shared_output) return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor: layer_name: str, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
......
...@@ -406,7 +406,9 @@ class ReplicatedLinear(LinearBase): ...@@ -406,7 +406,9 @@ class ReplicatedLinear(LinearBase):
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None, quant_args: Optional[list] = None,
update_hd: Optional[bool] = True update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None): if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None: if quant_args is not None:
input_quant_args = quant_args input_quant_args = quant_args
...@@ -601,7 +603,9 @@ class ColumnParallelLinear(LinearBase): ...@@ -601,7 +603,9 @@ class ColumnParallelLinear(LinearBase):
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None input_quant_args = None
assert rms_weight is not None assert rms_weight is not None
...@@ -680,7 +684,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -680,7 +684,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True, update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]: ) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None input_quant_args = None
assert residual is not None and rms_weight is not None assert residual is not None and rms_weight is not None
...@@ -707,7 +714,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -707,7 +714,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias: if not self.return_bias:
return output return output
return output, new_residual, output_bias return output, new_residual, i_q, _scales, output_bias
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
......
...@@ -670,7 +670,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -670,7 +670,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None): silu_quant_args: Optional[list[torch.Tensor]] = None, **_):
""" """
Use the output of create_weights and the CompressedTensorsScheme Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the associated with the layer to apply the forward pass with the
......
...@@ -1096,6 +1096,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1096,6 +1096,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -1137,7 +1139,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1137,7 +1139,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor) routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import enum import enum
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional, List
import torch import torch
from compressed_tensors.quantization import (QuantizationStrategy) from compressed_tensors.quantization import (QuantizationStrategy)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -163,6 +163,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -163,6 +163,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
...@@ -203,5 +205,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod) ...@@ -203,5 +205,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=False,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor) routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
...@@ -113,7 +113,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -113,7 +113,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor], bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None silu_quant_args: Optional[list[torch.Tensor]] = None, **_
) -> torch.Tensor: ) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias) # return self.kernel.apply_weights(layer, x, bias)
......
...@@ -156,7 +156,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -156,7 +156,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None silu_quant_args: Optional[list[torch.Tensor]] = None, **_
): ):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None: if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2 assert len(input_quant_args) == 2
......
...@@ -407,7 +407,7 @@ def apply_int8_linear( ...@@ -407,7 +407,7 @@ def apply_int8_linear(
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0, w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None, input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None silu_quant_args: Optional[list[torch.Tensor]] = None, **_
): ):
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
......
...@@ -97,16 +97,18 @@ class DeepseekV2MLP(nn.Module): ...@@ -97,16 +97,18 @@ class DeepseekV2MLP(nn.Module):
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False, update_hd: Optional[bool] = False,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None): xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd) gate_up, new_resi, i_q, _scales, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True) x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else: else:
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
return x, new_resi return x, new_resi, i_q, _scales
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs) gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
if envs.USE_FUSED_SILU_MUL_QUANT: if envs.USE_FUSED_SILU_MUL_QUANT:
...@@ -210,7 +212,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -210,7 +212,8 @@ class DeepseekV2MoE(nn.Module):
rms_weight: Optional[torch.Tensor] = None, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor: ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None: if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -255,9 +258,10 @@ class DeepseekV2MoE(nn.Module): ...@@ -255,9 +258,10 @@ class DeepseekV2MoE(nn.Module):
else: else:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
i_q, i_s = None, None
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True) shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else: else:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
...@@ -268,15 +272,18 @@ class DeepseekV2MoE(nn.Module): ...@@ -268,15 +272,18 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
shared_output=shared_output) shared_output=shared_output,
i_q=i_q, i_s=i_s)
else: else:
if hidden_states.dtype != torch.float16: if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts( final_hidden_states = self.experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else: else:
# Fix FP16 overflow # Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
...@@ -298,7 +305,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -298,7 +305,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states)) final_hidden_states))
if envs.USE_FUSED_RMS_QUANT: if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else: else:
return final_hidden_states.view(num_tokens, hidden_dim) return final_hidden_states.view(num_tokens, hidden_dim)
...@@ -614,8 +621,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -614,8 +621,7 @@ class DeepseekV2MLAAttention(nn.Module):
update_input: Optional[bool] = True update_input: Optional[bool] = True
) -> Union[torch.Tensor, ) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None: if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if self.q_lora_rank is not None: if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False) q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
...@@ -816,7 +822,10 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -816,7 +822,10 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer. # first layer.
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual) hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states,
rms_weight=self.post_attention_layernorm.weight.data,
residual=residual,
)
if isinstance(self.mlp, if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16: DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment