Commit d5538a81 authored by 王敏's avatar 王敏
Browse files

[Feature]w4a8适配低延迟模式

parent aef3c487
......@@ -38,7 +38,7 @@ from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm
from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep
from lightop import fuse_silu_mul_quant_ep, m_grouped_w4a8_gemm_nt_masked
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm():
from deepgemm import m_grouped_w8a8_gemm_nt_masked
......@@ -649,6 +649,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# act_out = self.act_fn(workspace1)
# a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
elif self.quant_config.use_int4_w4a8:
m_grouped_w4a8_gemm_nt_masked((a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w4a8_gemm_nt_masked((a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m)
else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
......@@ -579,6 +579,28 @@ def int8_w8a8_moe_quant_config(
block_shape=block_shape,
)
def int8_w4a8_moe_quant_config(
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
per_act_token_quant: bool = False,
block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for int8 activations and int8 weights.
"""
return FusedMoEQuantConfig.make(
torch.int8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False,
block_shape=block_shape,
)
def gptq_marlin_moe_quant_config(
w1_scale: torch.Tensor,
......
from typing import Any, Callable, Dict, List, Optional
import os
import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
FusedMoeWeightScaleSupported,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoEActivationFormat)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig,
int8_w4a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
......@@ -23,6 +29,8 @@ try:
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
class MarlinMoeWorkspace:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
......@@ -148,10 +156,31 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.quant_config = quant_config
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.moe_mk: Optional[FusedMoEModularKernel] = None
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) :
return None
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
return int8_w4a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
block_shape=[256, 256] if self.use_deepep else None,
)
def create_weights(
self,
......@@ -162,7 +191,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
params_dtype: torch.dtype,
**extra_weight_attrs,
):
tp_size = get_tensor_model_parallel_world_size()
if self.use_deepep:
self.N = 2 * intermediate_size
self.K = hidden_size
# WEIGHTS
w13_weight = torch.nn.Parameter(
......@@ -251,3 +282,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts,
)
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
moe_config=self.moe,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
N=self.N,
K=self.K
)
else:
logger.debug("DeepGemmExperts(%s)", self.__class__.__name__)
return DeepGemmExperts(moe_config=self.moe,
quant_config=self.moe_quant_config,
N=self.N,
K=self.K)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment