Commit 3a58da2c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-fth' into 'v0.9.2-dev'

fuse_moe_fp8接入marlin算子

See merge request dcutoolkit/deeplearing/vllm!399
parents a3fb334b 3c283de3
...@@ -29,6 +29,7 @@ from vllm.utils import round_up ...@@ -29,6 +29,7 @@ from vllm.utils import round_up
try: try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep, fuse_silu_mul_quant from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep, fuse_silu_mul_quant
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
from lmslim.layers.fused_moe.fuse_moe_fp8_marlin import fused_experts_impl_fp8_marlin
from lightop import m_grouped_w8a8_gemm_nt_contig_asm from lightop import m_grouped_w8a8_gemm_nt_contig_asm
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -37,8 +38,27 @@ logger = init_logger(__name__) ...@@ -37,8 +38,27 @@ logger = init_logger(__name__)
__all__ = [ __all__ = [
"CompressedTensorsW8A8Int8MarlinMoEMethod", "CompressedTensorsW8A8Int8MarlinMoEMethod",
"CompressedTensorsW8A8FP8MarlinMoEMethod",
] ]
def fp32_to_fp8_e4m3fn(t: torch.Tensor) -> torch.Tensor:
"""更合理的FP32到Float8_e4m3fn转换,使用最近值而不是简单舍弃尾数"""
# torch.float8_e4m3fn的数值范围约[-448, 448]
fp8_min, fp8_max = -448.0, 448.0
t_clamped = t.clamp(min=fp8_min, max=fp8_max)
# 保证不会下溢到0
# 转换前到float16再转fp8可能提升精度(float8实现本身通常通过float16做rounding)
t_fp16 = t_clamped.to(torch.float16)
return t_fp16.to(torch.float8_e4m3fn)
def w8a8_fp8_nt_kpack2_marlin_weight(w8a8_w, # [size_n, size_k// 2 ]
k_tile=16,
n_tile=16, ):
size_n, size_k = w8a8_w.shape
assert size_n % k_tile == 0 and size_k % n_tile == 0, "k_tile / n_tile 必须能整除对应维度"
w8a8_w = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
w8a8_w = w8a8_w.permute((0, 2, 1, 3)).contiguous()
w8a8_w = w8a8_w.reshape((size_n // k_tile, size_k * k_tile))
return w8a8_w
class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase): class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
@staticmethod @staticmethod
...@@ -46,17 +66,492 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase): ...@@ -46,17 +66,492 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501 quant_config: "SlimQuantCompressedTensorsMarlinConfig", # type: ignore # noqa E501
layer: torch.nn.Module, layer: torch.nn.Module,
) -> "CompressedTensorsMarlinMoEMethod": ) -> "CompressedTensorsMarlinMoEMethod":
# are supported + check if the layer is being ignored. # are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights") weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get( input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations") "input_activations")
if quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): if quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8FP8MarlinMoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config) return CompressedTensorsW8A8Int8MarlinMoEMethod(quant_config)
else: else:
raise RuntimeError( raise RuntimeError(
f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}") f"Slimquant_marlin does not support the FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8FP8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsMarlinConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.dp_size = get_dp_group().world_size
self.ep_size = get_ep_group().world_size
self.use_deepep = self.dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
self.use_deepgemm = False
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
self.block_shape = [256, 256]
self.use_deepgemm = envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM or envs.VLLM_ALL2ALL_BACKEND == "deepep_auto"
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w1_marlin_list = []
for ii in range(layer.w13_weight.shape[0]):
if not self.use_deepgemm:
w1_marlin_in = get_w8a8_int8_marlin_weights(layer.w13_weight[ii])
else:
w1_marlin_in = w8a8_fp8_nt_kpack2_marlin_weight(layer.w13_weight[ii])
w1_marlin_list.append(w1_marlin_in.float() if w1_marlin_in.dtype == torch.float8_e4m3fn else w1_marlin_in)
w1_marlin = torch.stack(w1_marlin_list, dim=0)
w1_marlin = fp32_to_fp8_e4m3fn(w1_marlin)
del w1_marlin_list
w2_marlin_list = []
for ii in range(layer.w2_weight.shape[0]):
if not self.use_deepgemm:
w2_marlin_in = get_w8a8_int8_marlin_weights(layer.w2_weight[ii])
else:
w2_marlin_in = w8a8_fp8_nt_kpack2_marlin_weight(layer.w2_weight[ii])
w2_marlin_list.append(w2_marlin_in.float() if w2_marlin_in.dtype == torch.float8_e4m3fn else w2_marlin_in)
w2_marlin = torch.stack(w2_marlin_list, dim=0)
w2_marlin = fp32_to_fp8_e4m3fn(w2_marlin)
layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def masked_groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int, ):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def contiguous_groupgemm_workspace_shapes(
self, a: torch.Tensor, aq: torch.Tensor, M: int, N: int, K: int,
topk: int, global_num_experts: int, local_num_experts: int,
expert_num_tokens_cpu: torch.Tensor
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
assert self.block_shape is not None
# We use global_num_experts due to how moe_align_block_size handles
# expert_maps.
block_m = self.block_shape[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_num_tokens_cpu
)
assert M_sum % block_m == 0
workspace1 = (M_sum, max(N, K))
workspace2 = (M_sum, max(N // 2, K))
output = (M, K)
return (workspace1, workspace2, output, a.dtype, M_sum)
def w8a8_groupgemm_masked_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.masked_groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
# expected_m = max_num_tokens
ori_bs = x.shape[0]
expected_m = ori_bs * self.ep_size
# expected_m = (
# x.shape[0] * self.dp_size * topk_ids.shape[1]
# + global_num_experts
# ) // global_num_experts
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def w8a8_groupgemm_contiguous_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
a1q = q_x
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype, M_sum) = self.contiguous_groupgemm_workspace_shapes(
x, q_x, topk_ids.size(0), N, K, topk_ids.size(1), global_num_experts,
local_num_experts, expert_num_tokens_cpu)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace2 = torch.empty(prod(workspace2_shape),
device=x.device,
dtype=workspace_dtype)
mm1_out = _resize_cache(workspace13, (M_sum, N))
mm2_out = _resize_cache(workspace2, (M_sum, K))
act_out = _resize_cache(workspace2, (M_sum, N // 2))
quant_out = _resize_cache(
workspace13.view(dtype=a1q.dtype), (M_sum, N // 2)
)
fused_out = _resize_cache(workspace13, fused_out_shape)
a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
block_shape=self.block_shape,
expert_num_tokens=expert_num_tokens,
expert_num_tokens_cpu=expert_num_tokens_cpu,
aq_out=a1q_perm,
M_sum=M_sum
)
m_grouped_w8a8_gemm_nt_contig_asm(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
# a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
deepgemm_unpermute_and_reduce(
a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=fused_out,
)
return fused_out
def fused_moe_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
expert_num_tokens_cpu: Optional[torch.Tensor] = None,
**_):
return fused_experts_impl_fp8_marlin(
hidden_states=x if q_x is None else q_x,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
shared_output: Optional[torch.Tensor] = None,
i_q: Optional[torch.Tensor] = None,
i_s: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for "
"`CompressedTensorsW8A8FP8MoEMethod` yet.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None, )
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
i_q=i_q,
i_s=i_s
)
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
self.max_num_tokens_per_rank = max_num_tokens_per_rank
logger.debug(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
None, True)
return TritonOrGroupGemmExperts(
use_fp8_w8a8=True,
per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_masked_forward
)
else:
logger.debug(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
use_fp8_w8a8=True if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else False,
per_act_token_quant=True if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else False,
fused_experts=self.w8a8_groupgemm_contiguous_forward if envs.VLLM_ENABLE_DEEPEP_HT_DEEPGEMM else self.fused_moe_forward
)
class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod): class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod):
def __init__( def __init__(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment