Commit 68972532 authored by wujl5, committed by zhuwenwen
Browse files

[perf]: DS_v2_w8a8模型融掉moe.quant (fuse away the separate moe.quant step for the DeepSeek-V2 w8a8 model)

parent 7ff48a6c
......@@ -1394,13 +1394,15 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> None:
routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,use_int4_w4a8,
per_channel_quant, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape, use_nn_moe, shared_output, routed_scaling_factor)
block_shape, use_nn_moe, shared_output, routed_scaling_factor, i_q=i_q, i_s=i_s)
def inplace_fused_experts_fake(
......@@ -1428,7 +1430,9 @@ def inplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> None:
routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> None:
pass
......@@ -1466,7 +1470,9 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16,
......@@ -1500,7 +1506,9 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1559,7 +1567,9 @@ def fused_experts(
allow_cutlass_block_scaled_grouped_gemm: bool = False,
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0) -> torch.Tensor:
routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_) -> torch.Tensor:
# For now, disable DeepGemm for small N (<= 512) until better
# permute/unpermute ops are available.
N = w1.size(1)
......@@ -1594,6 +1604,7 @@ def fused_experts(
topk_weights=topk_weights,
topk_ids=topk_ids)
else:
# Fused MoE quantization is currently supported only for DeepSeek (DS) w8a8.
return dispatch_fused_experts_func(inplace)(
hidden_states=hidden_states,
w1=w1,
......@@ -1619,7 +1630,9 @@ def fused_experts(
block_shape=block_shape,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
def fused_experts_impl(
......@@ -1649,6 +1662,8 @@ def fused_experts_impl(
use_nn_moe: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = 1.0,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor:
num_tokens = hidden_states.size(0)
if use_nn_moe:
......@@ -1695,8 +1710,9 @@ def fused_experts_impl(
block_shape=block_shape,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
)
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
elif use_int4_w4a8 is True:
return fused_experts_impl_w4a8(hidden_states=hidden_states,
w1=w1,
......
......@@ -8,7 +8,7 @@ import importlib
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum
from typing import Callable, Literal, Optional, overload
from typing import Callable, Literal, Optional, overload, Tuple, List
import torch
import torch.nn.functional as F
......@@ -1435,14 +1435,19 @@ class FusedMoE(torch.nn.Module):
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
):
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
assert i_q is None and i_s is None, "moe.quant fused not support TPU now"
return self.forward_impl(hidden_states, router_logits)
else:
return torch.ops.vllm.moe_forward(hidden_states, router_logits,
self.layer_name, shared_output)
self.layer_name, shared_output,
i_q, i_s)
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
......@@ -1522,7 +1527,9 @@ class FusedMoE(torch.nn.Module):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_):
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
......@@ -1559,7 +1566,9 @@ class FusedMoE(torch.nn.Module):
shared_output=shared_output,
use_nn_moe=self.use_nn_moe,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate
use_fused_gate=self.use_fused_gate,
i_q=i_q,
i_s=i_s
)
if do_naive_dispatch_combine:
......@@ -1630,16 +1639,20 @@ class FusedMoE(torch.nn.Module):
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
layer_name: str, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits, shared_output)
return self.forward_impl(hidden_states, router_logits, shared_output, i_q, i_s)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str, shared_output: Optional[torch.Tensor] = None) -> torch.Tensor:
layer_name: str, shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(hidden_states)
......
......@@ -406,7 +406,9 @@ class ReplicatedLinear(LinearBase):
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter], list[torch.Tensor]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None:
input_quant_args = quant_args
......@@ -601,7 +603,9 @@ class ColumnParallelLinear(LinearBase):
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert rms_weight is not None
......@@ -680,7 +684,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True,
xqxs: Optional[tuple] = None
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
) -> Union[torch.Tensor,
tuple[torch.Tensor, Optional[Parameter]],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[Parameter]],
]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
......@@ -707,7 +714,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
return output, new_residual, i_q, _scales, output_bias
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
bias = self.bias if not self.skip_bias_add else None
......
......@@ -670,7 +670,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None):
silu_quant_args: Optional[list[torch.Tensor]] = None, **_):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
......
......@@ -1096,6 +1096,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......@@ -1137,7 +1139,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
......@@ -3,7 +3,7 @@
import enum
from enum import Enum
from typing import Callable, Optional
from typing import Callable, Optional, List
import torch
from compressed_tensors.quantization import (QuantizationStrategy)
from vllm.logger import init_logger
......@@ -163,6 +163,8 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None, **_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
......@@ -203,5 +205,6 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
......@@ -113,7 +113,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor],
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
silu_quant_args: Optional[list[torch.Tensor]] = None, **_
) -> torch.Tensor:
# return self.kernel.apply_weights(layer, x, bias)
......
......@@ -156,7 +156,7 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
silu_quant_args: Optional[list[torch.Tensor]] = None, **_
):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
......
......@@ -407,7 +407,7 @@ def apply_int8_linear(
bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None,
silu_quant_args: Optional[list[torch.Tensor]] = None
silu_quant_args: Optional[list[torch.Tensor]] = None, **_
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
......
......@@ -97,16 +97,18 @@ class DeepseekV2MLP(nn.Module):
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = False,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None):
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT:
gate_up, new_resi, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
gate_up, new_resi, i_q, _scales, _ = self.gate_up_proj(x, rms_weight, residual, update_hd=update_hd)
if envs.USE_FUSED_SILU_MUL_QUANT:
x, _ = self.down_proj(gate_up, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x, new_resi
return x, new_resi, i_q, _scales
elif envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
gate_up, _ = self.gate_up_proj(x, xqxs=xqxs)
if envs.USE_FUSED_SILU_MUL_QUANT:
......@@ -210,7 +212,8 @@ class DeepseekV2MoE(nn.Module):
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
xqxs: Optional[tuple[torch.Tensor, torch.Tensor]] = None
) -> torch.Tensor:
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT and xqxs is not None:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
......@@ -255,9 +258,10 @@ class DeepseekV2MoE(nn.Module):
else:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
i_q, i_s = None, None
if self.n_shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
shared_output, new_resi, i_q, i_s = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
......@@ -268,15 +272,18 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
shared_output=shared_output)
shared_output=shared_output,
i_q=i_q, i_s=i_s)
else:
if hidden_states.dtype != torch.float16:
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits) * self.routed_scaling_factor
router_logits=router_logits,
i_q=i_q, i_s=i_s) * self.routed_scaling_factor
else:
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# In fp16 mode the fused quantization inputs (i_q, i_s) are not passed.
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
......@@ -298,7 +305,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states))
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states.view(num_tokens, hidden_dim), new_resi
return final_hidden_states.view(num_tokens, hidden_dim), new_resi, i_q, i_s
else:
return final_hidden_states.view(num_tokens, hidden_dim)
......@@ -614,8 +621,7 @@ class DeepseekV2MLAAttention(nn.Module):
update_input: Optional[bool] = True
) -> Union[torch.Tensor,
Tuple[torch.Tensor, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
]:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
if self.q_lora_rank is not None:
q_c, new_residual, _, input_quant_args = self.q_a_proj(hidden_states, rms_weight=rms_weight, residual=residual, update_hd=False)
......@@ -816,7 +822,10 @@ class DeepseekV2DecoderLayer(nn.Module):
# first layer.
residual *= 1. / self.routed_scaling_factor
hidden_states, new_resi = self.mlp(hidden_states, self.post_attention_layernorm.weight.data, residual)
hidden_states, new_resi, _i_q, _scales = self.mlp(hidden_states,
rms_weight=self.post_attention_layernorm.weight.data,
residual=residual,
)
if isinstance(self.mlp,
DeepseekV2MLP) and hidden_states.dtype == torch.float16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment