Commit 92761bde authored by 王敏's avatar 王敏
Browse files

[feat]w4a8和w8a8适配deepep低延迟

parent 11b83133
......@@ -175,7 +175,6 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1151,12 +1150,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# pd separation p2p async buf tokens
"VLLM_ENABLE_MOE_GROUP_GEMM":
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")),
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -54,7 +54,7 @@ def get_config_quant_dtype(
) -> Optional[torch.dtype]:
if use_fp8_w8a8:
return torch.float8_e4m3fn
elif use_int8_w8a8:
elif use_int8_w8a8 or use_int4_w4a8:
return torch.int8
return None
......
......@@ -78,6 +78,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
......@@ -96,11 +97,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
if expert_num_tokens is None:
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant,
block_shape)
x = x.view((num_experts, -1, hidden_dim))
block_shape, expert_num_tokens)
if expert_num_tokens is None:
x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None:
assert x_scales is not None
......@@ -156,7 +160,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
quant_config.per_act_token_quant, quant_config.block_shape, expert_num_tokens)
return (expert_x, expert_x_scale, expert_num_tokens, None, None)
......
......@@ -29,7 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import (
# yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
DeepGemmBannedFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled)
......@@ -192,7 +192,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
experts,
)
else:
self.fused_experts = DeepGemmBannedFusedMoEModularKernel(
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
experts,
)
......@@ -1539,14 +1539,15 @@ class FusedMoE(torch.nn.Module):
router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None):
assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels):
if (self.moe_parallel_config.use_pplx_kernels):
#or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = (
self.dp_size > 1
and self.ep_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels)
and envs.VLLM_ALL2ALL_BACKEND == 'naive')
#and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
......
......@@ -762,7 +762,7 @@ class FusedMoEModularKernel(torch.nn.Module):
@final
class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module):
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
"""
This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that
......@@ -783,12 +783,12 @@ class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}")
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
# f"{prepare_finalize.activation_format} == "
# f"{fused_experts.__class__.__name__}."
# f"{fused_experts.activation_formats[0]}")
def forward(
self,
......@@ -875,6 +875,7 @@ class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module):
fused_out = self.fused_experts.apply(
None,
a1,
a1q,
w1,
w2,
......
......@@ -18,15 +18,16 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None,
allow_group_gemm: bool = False,
fused_experts = None
):
super().__init__(
FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int4_w4a8=use_int4_w4a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant,
......@@ -64,6 +65,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
q_hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
......@@ -103,5 +105,6 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor
routed_scaling_factor=routed_scaling_factor,
q_x=q_hidden_states,
)
......@@ -4,6 +4,9 @@ from math import prod
from typing import Optional
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
......@@ -48,12 +51,221 @@ def _fp8_quantize(
return A, A_scale
# @triton.jit
# def _per_token_quant_int8_kernel_opt(
# x_ptr,
# xq_ptr,
# scale_ptr,
# stride_x,
# stride_xq,
# N,
# T_dim,
# has_tokens_per_expert: tl.constexpr,
# tokens_per_expert_ptr,
# BLOCK: tl.constexpr
# ):
# row_id = tl.program_id(0)
# if has_tokens_per_expert:
# e = row_id // T_dim
# t = row_id % T_dim
# num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
# if t >= num_valid_tokens_for_e:
# return
# cols = tl.arange(0, BLOCK)
# mask = cols < N
# x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
# other=0.0).to(tl.float32)
# absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
# scale_x = absmax / 127
# x_q = x * (127 / absmax)
# x_q = libdevice.nearbyint(x_q).to(tl.int8)
# tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
# tl.store(scale_ptr + row_id, scale_x)
# def per_token_quant_int8_triton_opt(x: torch.Tensor,
# tokens_per_expert: Optional[torch.Tensor] = None):
# """
# Python wrapper for the Triton kernel.
# """
# if x.dim() != 3:
# raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
# E, T, H = x.shape
# M = E * T
# N = H
# x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
# scales = torch.empty(x.shape[:-1] + (1, ),
# device=x.device,
# dtype=torch.float32)
# BLOCK = triton.next_power_of_2(N)
# num_warps = min(max(BLOCK // 256, 1), 8)
# grid_opt = M
# _per_token_quant_int8_kernel_opt[(grid_opt, )](
# x,
# x_q,
# scales,
# stride_x=x.stride(-2),
# stride_xq=x_q.stride(-2),
# N=N,
# T_dim=T,
# has_tokens_per_expert= tokens_per_expert is not None,
# tokens_per_expert_ptr=tokens_per_expert,
# BLOCK=BLOCK,
# num_warps=num_warps,
# num_stages=1,
# )
# return x_q, scales
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
T_dim,
has_tokens_per_expert: tl.constexpr,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
row_id = tl.program_id(0)
if has_tokens_per_expert:
e = row_id // T_dim
t = row_id % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
return
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
@triton.jit
def _per_token_quant_int8_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
E_dim,
T_dim,
has_tokens_per_expert: tl.constexpr,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
token_idx_start = tl.program_id(0)
grid_size = tl.num_programs(0)
num_total_tokens = E_dim * T_dim
for token_idx in range(token_idx_start, num_total_tokens, grid_size):
is_valid_token = True
if has_tokens_per_expert:
e = token_idx // T_dim
t = token_idx % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
is_valid_token = False
if is_valid_token:
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + token_idx, scale_x)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None):
if x.dim() != 3:
raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
E, T, H = x.shape
N = H
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)
if T >= 4096:
num_warps = 1
num_tokens = E * T
grid_opt = num_tokens
if E == 16 and T >= 1024 :
grid_opt = max(1, num_tokens // (T // 256))
_per_token_quant_int8_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
E_dim=E,
T_dim=T,
has_tokens_per_expert=tokens_per_expert is not None,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_per_token_quant_int8_one_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
T_dim=T,
has_tokens_per_expert=tokens_per_expert is not None,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
def _int8_quantize(
A: torch.Tensor,
A_scale: Optional[torch.Tensor],
per_act_token: bool,
block_shape: Optional[list[int]] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Perform int8 quantization on the inputs. If a block_shape
......@@ -66,7 +278,10 @@ def _int8_quantize(
if block_shape is None:
assert per_act_token, \
"int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A)
if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A)
else:
A, A_scale = per_token_quant_int8_triton_opt(A, expert_num_tokens)
else:
assert not per_act_token
assert len(block_shape) == 2
......@@ -83,11 +298,12 @@ def moe_kernel_quantize_input(
quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
block_shape: Optional[list[int]] = None,
expert_num_tokens: Optional[torch.Tensor]= None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape)
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape, expert_num_tokens)
else:
return A, A_scale
......
......@@ -4,6 +4,7 @@
import enum
from enum import Enum
from typing import Callable, Optional
from math import prod
import torch
from compressed_tensors import CompressionFormat
......@@ -13,6 +14,8 @@ from compressed_tensors.quantization import (ActivationOrdering,
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
......@@ -32,10 +35,16 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import W8a8GetCacheJSON
try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
......@@ -999,11 +1008,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
self.tritonsingleton= W8a8GetCacheJSON()
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.int8
......@@ -1102,8 +1127,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -1116,28 +1139,204 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,)
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
routed_scaling_factor=routed_scaling_factor,
)
def fused_moe_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
**_ ):
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def w8a8_groupgemm_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens // 2
# print("##########################gemm1 workspace1 shape:{} q_x shape:{} " \
# "a1_scale shape:{} w1 shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# q_x.shape,
# a1_scale.shape,
# w1.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# print("##########################gemm2 workspace1 shape:{} a2q shape:{} " \
# "a2q_scale shape:{} w2 shape:{} fused_out shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# a2q.shape,
# a2q_scale.shape,
# w2.shape,
# fused_out.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
self.max_num_tokens_per_rank = max_num_tokens_per_rank
logger.debug(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
None, True)
return TritonOrGroupGemmExperts(
use_int8_w8a8=True,
per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_forward
)
else:
logger.debug(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
fused_experts=self.fused_moe_forward
)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
from typing import Any, Callable, Dict, List, Optional
import os
from math import prod
import torch
from torch.nn.parameter import Parameter
......@@ -7,9 +8,10 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size, get_dp_group
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
......@@ -17,6 +19,7 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
......@@ -27,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import (
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant, fuse_silu_mul_quant_ep
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
......@@ -127,10 +131,6 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def get_scaled_act_names(self) -> List[str]:
return []
@property
def weight_block_size(self):
return [128,128]
class SlimQuantW4A8Int8MarlinMoEMethod:
......@@ -160,14 +160,17 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def __init__(self, quant_config):
self.quant_config = quant_config
self.fused_experts = self.w4a8_marlin_forward
self.fused_experts = self.w4a8_fused_moe_marlin_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_moe_group_gemm = parallel_config.enable_expert_parallel and envs.VLLM_ENABLE_MOE_GROUP_GEMM
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights(
......@@ -181,6 +184,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
):
tp_size = get_tensor_model_parallel_world_size()
if self.use_deepep:
self.N = 2 * intermediate_size
self.K = hidden_size
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
......@@ -233,7 +240,77 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def w4a8_marlin_forward(self,
def w4a8_fused_moe_marlin_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def w4a8_groupgemm_marlin_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
......@@ -251,35 +328,69 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
if not self.enable_moe_group_gemm:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
else:
# TODO:
return None
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
# forward_context = get_forward_context()
# expected_m = forward_context.dp_metadata.max_tokens_across_dp_cpu * self.num_dispatchers
# print("##########################gemm1 workspace1 shape:{} q_x shape:{} " \
# "a1_scale shape:{} w1 shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# q_x.shape,
# a1_scale.shape,
# w1.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# print("##########################gemm2 workspace1 shape:{} a2q shape:{} " \
# "a2q_scale shape:{} w2 shape:{} fused_out shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# a2q.shape,
# a2q_scale.shape,
# w2.shape,
# fused_out.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w4a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def apply_mori_ep(
self,
......@@ -347,9 +458,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# if enable_eplb:
# raise NotImplementedError(
# "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -398,22 +509,25 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
self.max_num_tokens_per_rank = max_num_tokens_per_rank
assert max_num_tokens_per_rank is not None
logger.debug(
"BatchedGroupedGemmExperts(%s): "
logger.info(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False)
return None
None, True)
return TritonOrGroupGemmExperts(
use_int4_w4a8=True,
per_act_token_quant=True,
fused_experts=self.w4a8_groupgemm_marlin_forward
)
else:
logger.debug(
logger.info(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size,
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
use_fp8_w8a8=False,
block_shape=self.quant_config.weight_block_size,
allow_group_gemm=False,
fused_experts=self.w4a8_marlin_forward
fused_experts=self.w4a8_fused_moe_marlin_forward
)
......@@ -151,6 +151,10 @@ class Scheduler(SchedulerInterface):
self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens
self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
......@@ -1007,7 +1011,7 @@ class Scheduler(SchedulerInterface):
return scheduler_output
def schedule(self) -> SchedulerOutput:
if self.num_spec_tokens > 0:
if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0:
return self.schedule_split_pd()
else:
return self.schedule_default()
......
......@@ -1852,6 +1852,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model)
if hasattr(self, "drafter"):
prepare_communication_buffer_for_model(self.drafter.model)
if is_mixture_of_experts(
self.model) and self.parallel_config.enable_eplb:
logger.info("EPLB is enabled for model %s.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment