Commit 38b14e8b authored by wujl5, committed by zhangqha
Browse files

Clean up code for DeepSeek-V2.

parent cca00f5c
...@@ -1613,6 +1613,8 @@ def fused_experts( ...@@ -1613,6 +1613,8 @@ def fused_experts(
quant_config: FusedMoEQuantConfig | None = None, quant_config: FusedMoEQuantConfig | None = None,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None # TODO:wjl
) -> torch.Tensor: ) -> torch.Tensor:
if quant_config is None: if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......
...@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module): ...@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module):
quant_dtype: torch.dtype = torch.int8, quant_dtype: torch.dtype = torch.int8,
update_input: Optional[bool] = True update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = torch.ops.vllm.fused_rmsquant(input=x, i_q, i_s = torch.ops.vllm.fused_rmsquant_customer_impl(input=x,
weight=self.weight, weight=self.weight,
epsilon=self.variance_epsilon, epsilon=self.variance_epsilon,
quant_dtype=quant_dtype, quant_dtype=quant_dtype,
...@@ -383,9 +383,9 @@ def fused_rmsquant_fake( ...@@ -383,9 +383,9 @@ def fused_rmsquant_fake(
# customer_lib = Library("customer_", "FRAGMENT") # customer_lib = Library("customer_", "FRAGMENT")
direct_register_custom_op( direct_register_custom_op(
op_name="fused_rmsquant", op_name="fused_rmsquant_customer_impl",
op_func=fused_rmsquant_impl, op_func=fused_rmsquant_impl,
mutates_args=[], mutates_args=["input", "residual"],
fake_impl=fused_rmsquant_fake, fake_impl=fused_rmsquant_fake,
) )
......
...@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear. will be treated as a "Replicated" MergedLinear.
""" """
def forward(
    self,
    input_,
    *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    """Run the layer's quant method on ``input_`` and optionally gather.

    Args:
        input_: Activation handed to ``self.quant_method.apply``.
        iqis: Optional pair forwarded to the quant method as
            ``input_quant_args``; only used when
            ``envs.USE_FUSED_RMS_QUANT`` is truthy.
            NOTE(review): presumably the pre-quantized input and its
            scales produced by a fused RMSNorm+quant op -- confirm
            against the caller.

    Returns:
        ``output`` alone when ``self.return_bias`` is false; otherwise
        ``(output, output_bias)``, where ``output_bias`` is ``self.bias``
        if ``self.skip_bias_add`` is set and ``None`` if it is not.
    """
    # The bias is applied inside the matmul unless the caller opted to
    # receive it separately (skip_bias_add).
    bias = None if self.skip_bias_add else self.bias

    assert self.quant_method is not None

    # Only pass the extra keyword when the fused path is enabled and the
    # caller actually supplied the pair; otherwise the call is unchanged.
    extra_kwargs = {}
    if envs.USE_FUSED_RMS_QUANT and iqis is not None:
        extra_kwargs["input_quant_args"] = iqis
    output_parallel = self.quant_method.apply(
        self, input_, bias, **extra_kwargs
    )

    # All-gather across the partitions when requested and TP is active.
    output = output_parallel
    if self.gather_output and self.tp_size > 1:
        output = tensor_model_parallel_all_gather(output_parallel)

    if self.return_bias:
        return output, (self.bias if self.skip_bias_add else None)
    return output
def __init__( def __init__(
self, self,
input_size: int, input_size: int,
......
...@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None, i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None, i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
......
...@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod: ...@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts( return fused_experts(
......
...@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False, use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment