Commit 38b14e8b authored by wujl5, committed by zhangqha
Browse files

clean code for Deepseek V2.

parent cca00f5c
......@@ -1613,6 +1613,8 @@ def fused_experts(
quant_config: FusedMoEQuantConfig | None = None,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None # TODO:wjl
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......
......@@ -335,7 +335,7 @@ class FusedRMSNormQuant(nn.Module):
quant_dtype: torch.dtype = torch.int8,
update_input: Optional[bool] = True
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
i_q, i_s = torch.ops.vllm.fused_rmsquant(input=x,
i_q, i_s = torch.ops.vllm.fused_rmsquant_customer_impl(input=x,
weight=self.weight,
epsilon=self.variance_epsilon,
quant_dtype=quant_dtype,
......@@ -383,9 +383,9 @@ def fused_rmsquant_fake(
# customer_lib = Library("customer_", "FRAGMENT")
# Register the fused RMSNorm + quantization kernel as a custom op.
# A call may pass each keyword only once, so a single `op_name` and a single
# `mutates_args` are kept: the "fused_rmsquant_customer_impl" name is the one
# looked up via torch.ops.vllm by the caller of this op.
# NOTE(review): mutates_args declares that the op writes to "input" and
# "residual" in place -- confirm both are parameter names of
# fused_rmsquant_impl.
direct_register_custom_op(
    op_name="fused_rmsquant_customer_impl",
    op_func=fused_rmsquant_impl,
    mutates_args=["input", "residual"],
    fake_impl=fused_rmsquant_fake,
)
......
......@@ -711,32 +711,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
disable_tp: If true, all weights matrix won't be sharded, this layer
will be treated as a "Replicated" MergedLinear.
"""
def forward(
    self,
    input_,
    *,
    iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor | tuple[torch.Tensor, Parameter | None]:
    """Run the layer's quant method on ``input_`` and optionally gather.

    Args:
        input_: Activation handed to ``self.quant_method.apply``.
        iqis: Optional pair of tensors. When it is given and
            ``envs.USE_FUSED_RMS_QUANT`` is truthy, it is forwarded to the
            quant method as ``input_quant_args`` (presumably the
            pre-quantized input and its scales -- confirm with the caller).

    Returns:
        Only the output tensor when ``self.return_bias`` is falsy.
        Otherwise ``(output, bias)``, where ``bias`` is ``self.bias`` if
        ``self.skip_bias_add`` is set and ``None`` if it is not.
    """
    # With skip_bias_add the bias is not passed to the quant method here; it
    # is handed back to the caller in the return value below instead.
    fused_bias = None if self.skip_bias_add else self.bias

    assert self.quant_method is not None

    # Matrix multiply; the pre-computed quantization inputs are passed along
    # only when the fused RMS-quant path is switched on and they exist.
    extra_kwargs = {}
    if envs.USE_FUSED_RMS_QUANT and iqis is not None:
        extra_kwargs["input_quant_args"] = iqis
    result = self.quant_method.apply(self, input_, fused_bias, **extra_kwargs)

    # All-gather across the partitions when requested and actually sharded.
    if self.gather_output and self.tp_size > 1:
        result = tensor_model_parallel_all_gather(result)

    if not self.return_bias:
        return result
    return result, (self.bias if self.skip_bias_add else None)
def __init__(
self,
input_size: int,
......
......@@ -1256,7 +1256,7 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
......
......@@ -307,6 +307,8 @@ class SlimQuantW4A8Int8MoEMethod:
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
......
......@@ -224,6 +224,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
topk_ids: torch.Tensor,
use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
i_q: torch.Tensor | None = None,
i_s: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment