Commit ba6f2101 authored by chenhw5's avatar chenhw5 Committed by zhangzbb
Browse files

[FEATURE] GLM5 FP8 EP适配

parent e7dee10f
......@@ -157,10 +157,7 @@ def maybe_make_prepare_finalize(
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
)
use_fp8_dispatch = quant_config.quant_dtype == current_platform.fp8_dtype()
use_int8_dispatch = quant_config.quant_dtype == torch.int8
......
......@@ -38,10 +38,10 @@ from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm
from vllm.model_executor.layers.activation import SiluAndMul
from lightop import fuse_silu_mul_quant_ep
from lightop import fuse_silu_mul_quant_ep, fuse_silu_mul_fp8_quant_ep
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
if has_deep_gemm():
from deepgemm import m_grouped_w8a8_gemm_nt_masked
from deepgemm import m_grouped_w8a8_gemm_nt_masked, m_grouped_fp8_gemm_nt_masked
else:
from lightop import m_grouped_w8a8_gemm_nt_masked
......@@ -452,8 +452,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
num_dispatchers=num_dispatchers,
)
if quant_config.use_fp8_w8a8:
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
#if quant_config.use_fp8_w8a8:
#assert self.block_shape == get_mk_alignment_for_contiguous_layout()
self.N = N
self.K = K
......@@ -606,7 +606,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m = self.get_expected_m()
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked(
m_grouped_fp8_gemm_nt_masked(
(a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
......@@ -614,14 +614,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
expected_m,
)
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1,
expert_num_tokens,
quant_scale_fmt=quant_scale_fmt,
a2q, a2q_scale = fuse_silu_mul_fp8_quant_ep(
input=workspace1,
fp8type=0,
tokens_per_expert=expert_num_tokens,
)
fp8_m_grouped_gemm_nt_masked(
m_grouped_fp8_gemm_nt_masked(
(a2q, a2q_scale),
(w2, self.w2_scale),
output,
......
......@@ -87,14 +87,14 @@ def _quant_flags_to_group_shape(
"""
a_shape: GroupShape | None
w_shape: GroupShape | None
if block_shape is not None and quant_dtype!=torch.int8:
if block_shape is not None and quant_dtype!=torch.int8 and quant_dtype!=current_platform.fp8_dtype():
assert not per_act_token_quant
assert not per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first
# dim should be 1.
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
elif block_shape is not None and quant_dtype == torch.int8:
elif block_shape is not None and (quant_dtype == torch.int8 or quant_dtype == current_platform.fp8_dtype()):
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
else:
......@@ -518,7 +518,7 @@ class FusedMoEQuantConfig:
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
)
if quant_dtype != torch.int8:
if quant_dtype != torch.int8 and quant_dtype != current_platform.fp8_dtype():
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
assert quant_config.block_shape == block_shape
......
......@@ -22,7 +22,7 @@ from vllm.v1.worker.ubatching import (
dbo_enabled,
dbo_maybe_run_recv_hook,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
# DeepEP kernels quantize dispatch inputs in 128 element chunks.
......@@ -179,7 +179,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
if quant_config.block_shape is not None
else None
)
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
if block_k == DEEPEP_QUANT_BLOCK_SIZE or (isinstance(x, tuple) and x[0].dtype == current_platform.fp8_dtype()):
# DeepEP kernels did the quantization for us.
x, x_scales = x
return x, x_scales
......
......@@ -18,7 +18,7 @@ from vllm.model_executor.layers.fused_moe import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight, weight8bit_nt_kpack2_marlin1)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe.config import (FusedMoEQuantConfig, int8_w8a8_moe_quant_config, fp8_w8a8_moe_quant_config)
from vllm.model_executor.layers.fused_moe import (
FusedMoE,
FusedMoEMethodBase,
......@@ -120,14 +120,38 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
"dynamic per token quantization. Found static input scales.")
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def get_fused_moe_quant_config(
self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
return None
return fp8_w8a8_moe_quant_config(
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
per_act_token_quant=True,
per_out_ch_quant=False,
block_shape=[256, 256] if self.use_deepep else None,
)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.float8_e4m3fn
# WEIGHTS
......@@ -200,7 +224,10 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
else:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepep:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin)
......@@ -208,7 +235,10 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepep:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = weight8bit_nt_kpack2_marlin1(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin)
......@@ -328,6 +358,42 @@ class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
routed_scaling_factor=routed_scaling_factor,
shared_output=shared_output, )
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
layer: torch.nn.Module,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
DeepGemmExperts,
)
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
moe_config=self.moe,
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
N=self.N,
K=self.K
)
else:
logger.debug("DeepGemmExperts(%s)", self.__class__.__name__)
return DeepGemmExperts(moe_config=self.moe,
quant_config=self.moe_quant_config,
N=self.N,
K=self.K)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment