Commit d5538a81 authored by 王敏
Browse files

[Feature]w4a8适配低延迟模式

parent aef3c487
...@@ -38,7 +38,7 @@ from vllm.utils.math_utils import cdiv, round_up ...@@ -38,7 +38,7 @@ from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep from lightop import fuse_silu_mul_quant_ep, m_grouped_w4a8_gemm_nt_masked
from lmslim.layers.gemm.int8_utils import per_token_quant_int8 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm(): if has_deep_gemm():
from deepgemm import m_grouped_w8a8_gemm_nt_masked from deepgemm import m_grouped_w8a8_gemm_nt_masked
...@@ -649,6 +649,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -649,6 +649,21 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
# act_out = self.act_fn(workspace1) # act_out = self.act_fn(workspace1)
# a2q, a2q_scale = per_token_quant_int8(act_out) # a2q, a2q_scale = per_token_quant_int8(act_out)
# moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output) # moe_grouped_gemm(a2q, w2, a2q_scale, self.w2_scale, expert_num_tokens, output)
elif self.quant_config.use_int4_w4a8:
m_grouped_w4a8_gemm_nt_masked((a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w4a8_gemm_nt_masked((a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m)
else: else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}") raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
...@@ -579,6 +579,28 @@ def int8_w8a8_moe_quant_config( ...@@ -579,6 +579,28 @@ def int8_w8a8_moe_quant_config(
block_shape=block_shape, block_shape=block_shape,
) )
def int8_w4a8_moe_quant_config(
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    a1_scale: torch.Tensor | None,
    a2_scale: torch.Tensor | None,
    per_act_token_quant: bool = False,
    block_shape: list[int] | None = None,
) -> FusedMoEQuantConfig:
    """
    Construct a quant config for the W4A8 MoE path.

    As the function name indicates, this helper is meant for int8
    activations combined with 4-bit weights (not int8 weights).

    Args:
        w1_scale: Scale tensor forwarded as ``w1_scale``.
        w2_scale: Scale tensor forwarded as ``w2_scale``.
        a1_scale: Optional scale tensor forwarded as ``a1_scale``.
        a2_scale: Optional scale tensor forwarded as ``a2_scale``.
        per_act_token_quant: Forwarded unchanged to
            ``FusedMoEQuantConfig.make``.
        block_shape: Optional block shape forwarded unchanged to
            ``FusedMoEQuantConfig.make``.

    Returns:
        The ``FusedMoEQuantConfig`` produced by ``FusedMoEQuantConfig.make``
        with ``per_out_ch_quant`` fixed to ``False``.
    """
    # NOTE(review): only ``torch.int8`` is passed as the dtype argument and
    # no separate weight dtype is given, so this call looks identical to the
    # int8 w8a8 helper. Confirm that ``quant_config.use_int4_w4a8`` can still
    # tell the two configurations apart.
    return FusedMoEQuantConfig.make(
        torch.int8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        per_act_token_quant=per_act_token_quant,
        per_out_ch_quant=False,
        block_shape=block_shape,
    )
def gptq_marlin_moe_quant_config( def gptq_marlin_moe_quant_config(
w1_scale: torch.Tensor, w1_scale: torch.Tensor,
......
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import os import os
import torch import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.config import get_current_vllm_config
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase) QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported,
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
FusedMoEActivationFormat)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig,
int8_w4a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel) FusedMoEModularKernel)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
...@@ -23,6 +29,8 @@ try: ...@@ -23,6 +29,8 @@ try:
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
class MarlinMoeWorkspace: class MarlinMoeWorkspace:
""" """
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE. Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
...@@ -148,10 +156,31 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -148,10 +156,31 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.quant_config = quant_config self.quant_config = quant_config
self.moe_quant_config: Optional[FusedMoEQuantConfig] = None self.moe_quant_config: Optional[FusedMoEQuantConfig] = None
self.moe_mk: Optional[FusedMoEModularKernel] = None self.moe_mk: Optional[FusedMoEModularKernel] = None
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """
    Build the W4A8 MoE quant config from the scales stored on ``layer``.

    Reads ``w13_weight_scale``, ``w2_weight_scale``, ``w13_input_scale`` and
    ``w2_input_scale`` from ``layer`` and passes them to
    ``int8_w4a8_moe_quant_config`` with per-token activation quantization
    enabled. A ``[256, 256]`` block shape is requested only when
    ``self.use_deepep`` is truthy; otherwise no block shape is passed.
    """
    weight_scale_w13 = layer.w13_weight_scale
    weight_scale_w2 = layer.w2_weight_scale
    input_scale_w13 = layer.w13_input_scale
    input_scale_w2 = layer.w2_input_scale

    if self.use_deepep:
        quant_block_shape = [256, 256]
    else:
        quant_block_shape = None

    return int8_w4a8_moe_quant_config(
        w1_scale=weight_scale_w13,
        w2_scale=weight_scale_w2,
        a1_scale=input_scale_w13,
        a2_scale=input_scale_w2,
        per_act_token_quant=True,
        block_shape=quant_block_shape,
    )
def create_weights( def create_weights(
self, self,
...@@ -162,7 +191,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -162,7 +191,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
): ):
tp_size = get_tensor_model_parallel_world_size() if self.use_deepep:
self.N = 2 * intermediate_size
self.K = hidden_size
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
...@@ -251,3 +282,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -251,3 +282,39 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
    """
    Pick the DeepGEMM experts implementation matching ``prepare_finalize``.

    Returns a ``BatchedDeepGemmExperts`` when the prepare/finalize object
    reports the ``BatchedExperts`` activation format, and a
    ``DeepGemmExperts`` otherwise. ``layer`` is accepted but not used here.

    Both variants are built from ``self.moe``, ``self.moe_quant_config``,
    ``self.N`` and ``self.K``. In the visible part of ``create_weights``,
    ``self.N`` and ``self.K`` are assigned only when ``self.use_deepep`` is
    set.
    """
    # Imported inside the method, as in the original implementation.
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
        DeepGemmExperts,
    )

    uses_batched_format = (
        prepare_finalize.activation_format
        == FusedMoEActivationFormat.BatchedExperts
    )

    # Guard clause: every non-batched format takes the DeepGemmExperts path.
    if not uses_batched_format:
        logger.debug("DeepGemmExperts(%s)", self.__class__.__name__)
        return DeepGemmExperts(
            moe_config=self.moe,
            quant_config=self.moe_quant_config,
            N=self.N,
            K=self.K,
        )

    # The batched variant requires a per-rank token limit.
    tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert tokens_per_rank is not None
    logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
    return BatchedDeepGemmExperts(
        moe_config=self.moe,
        max_num_tokens=tokens_per_rank,
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=self.moe_quant_config,
        N=self.N,
        K=self.K,
    )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment