Commit 92761bde authored by 王敏's avatar 王敏
Browse files

[feat]w4a8和w8a8适配deepep低延迟

parent 11b83133
...@@ -175,7 +175,6 @@ if TYPE_CHECKING: ...@@ -175,7 +175,6 @@ if TYPE_CHECKING:
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000 VLLM_P2P_BUF_TOKENS: int = 30000
VLLM_ENABLE_MOE_GROUP_GEMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1151,12 +1150,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1151,12 +1150,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# pd separation p2p async buf tokens # pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS": "VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")), lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# pd separation p2p async buf tokens
"VLLM_ENABLE_MOE_GROUP_GEMM":
lambda: (os.environ.get("VLLM_ENABLE_MOE_GROUP_GEMM", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -54,7 +54,7 @@ def get_config_quant_dtype( ...@@ -54,7 +54,7 @@ def get_config_quant_dtype(
) -> Optional[torch.dtype]: ) -> Optional[torch.dtype]:
if use_fp8_w8a8: if use_fp8_w8a8:
return torch.float8_e4m3fn return torch.float8_e4m3fn
elif use_int8_w8a8: elif use_int8_w8a8 or use_int4_w4a8:
return torch.int8 return torch.int8
return None return None
......
...@@ -78,6 +78,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -78,6 +78,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_dtype: Optional[torch.dtype], quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[list[int]], block_shape: Optional[list[int]],
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None block_k = block_shape[1] if block_shape is not None else None
...@@ -96,11 +97,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -96,11 +97,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, max_tokens, hidden_dim = x.size() num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant # TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim)) if expert_num_tokens is None:
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype, x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant, per_act_token_quant,
block_shape) block_shape, expert_num_tokens)
x = x.view((num_experts, -1, hidden_dim))
if expert_num_tokens is None:
x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None: if quant_dtype is not None:
assert x_scales is not None assert x_scales is not None
...@@ -156,7 +160,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -156,7 +160,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_x, expert_x_scale = self._do_quant( expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape, expert_num_tokens)
return (expert_x, expert_x_scale, expert_num_tokens, None, None) return (expert_x, expert_x_scale, expert_num_tokens, None, None)
......
...@@ -29,7 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import ( ...@@ -29,7 +29,7 @@ from vllm.model_executor.layers.fused_moe.config import (
# yapf: enable # yapf: enable
from vllm.model_executor.layers.fused_moe.modular_kernel import ( from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel, FusedMoEActivationFormat, FusedMoEModularKernel,
DeepGemmBannedFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, DeepGemmDisabledFusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize) FusedMoEPrepareAndFinalize)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled) # is_rocm_aiter_moe_enabled)
...@@ -192,7 +192,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -192,7 +192,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
experts, experts,
) )
else: else:
self.fused_experts = DeepGemmBannedFusedMoEModularKernel( self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
) )
...@@ -1539,14 +1539,15 @@ class FusedMoE(torch.nn.Module): ...@@ -1539,14 +1539,15 @@ class FusedMoE(torch.nn.Module):
router_logits: torch.Tensor, router_logits: torch.Tensor,
shared_output: Optional[torch.Tensor] = None): shared_output: Optional[torch.Tensor] = None):
assert self.quant_method is not None assert self.quant_method is not None
if (self.moe_parallel_config.use_pplx_kernels if (self.moe_parallel_config.use_pplx_kernels):
or self.moe_parallel_config.use_deepep_ll_kernels): #or self.moe_parallel_config.use_deepep_ll_kernels):
return self.forward_impl_chunked(hidden_states, router_logits) return self.forward_impl_chunked(hidden_states, router_logits)
do_naive_dispatch_combine: bool = ( do_naive_dispatch_combine: bool = (
self.dp_size > 1 self.dp_size > 1
and self.ep_size > 1 and self.ep_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels) and envs.VLLM_ALL2ALL_BACKEND == 'naive')
#and not self.moe_parallel_config.use_deepep_ht_kernels)
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
......
...@@ -762,7 +762,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -762,7 +762,7 @@ class FusedMoEModularKernel(torch.nn.Module):
@final @final
class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module): class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
""" """
This class combines a FusedMoEPrepareAndFinalize instance and This class combines a FusedMoEPrepareAndFinalize instance and
a FusedMoEPermuteExpertsUnpermute to provide an interface that a FusedMoEPermuteExpertsUnpermute to provide an interface that
...@@ -783,12 +783,12 @@ class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module): ...@@ -783,12 +783,12 @@ class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
assert prepare_finalize.activation_format == \ # assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], ( # fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}." # f"{prepare_finalize.__class__.__name__}."
f"{prepare_finalize.activation_format} == " # f"{prepare_finalize.activation_format} == "
f"{fused_experts.__class__.__name__}." # f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}") # f"{fused_experts.activation_formats[0]}")
def forward( def forward(
self, self,
...@@ -875,6 +875,7 @@ class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module): ...@@ -875,6 +875,7 @@ class DeepGemmBannedFusedMoEModularKernel(torch.nn.Module):
fused_out = self.fused_experts.apply( fused_out = self.fused_experts.apply(
None, None,
a1,
a1q, a1q,
w1, w1,
w2, w2,
......
...@@ -18,15 +18,16 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute): ...@@ -18,15 +18,16 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
allow_group_gemm: bool = False,
fused_experts = None fused_experts = None
): ):
super().__init__( super().__init__(
FusedMoEQuantConfig.make( FusedMoEQuantConfig.make(
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int4_w4a8=use_int4_w4a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
...@@ -64,6 +65,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute): ...@@ -64,6 +65,7 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
self, self,
output: torch.Tensor, output: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
q_hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
...@@ -103,5 +105,6 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute): ...@@ -103,5 +105,6 @@ class TritonOrGroupGemmExperts(mk.CustomizedFusedMoEPermuteExpertsUnpermute):
expert_num_tokens=expert_num_tokens, expert_num_tokens=expert_num_tokens,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor routed_scaling_factor=routed_scaling_factor,
q_x=q_hidden_states,
) )
...@@ -4,6 +4,9 @@ from math import prod ...@@ -4,6 +4,9 @@ from math import prod
from typing import Optional from typing import Optional
import torch import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -48,12 +51,221 @@ def _fp8_quantize( ...@@ -48,12 +51,221 @@ def _fp8_quantize(
return A, A_scale return A, A_scale
# @triton.jit
# def _per_token_quant_int8_kernel_opt(
# x_ptr,
# xq_ptr,
# scale_ptr,
# stride_x,
# stride_xq,
# N,
# T_dim,
# has_tokens_per_expert: tl.constexpr,
# tokens_per_expert_ptr,
# BLOCK: tl.constexpr
# ):
# row_id = tl.program_id(0)
# if has_tokens_per_expert:
# e = row_id // T_dim
# t = row_id % T_dim
# num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
# if t >= num_valid_tokens_for_e:
# return
# cols = tl.arange(0, BLOCK)
# mask = cols < N
# x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
# other=0.0).to(tl.float32)
# absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
# scale_x = absmax / 127
# x_q = x * (127 / absmax)
# x_q = libdevice.nearbyint(x_q).to(tl.int8)
# tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
# tl.store(scale_ptr + row_id, scale_x)
# def per_token_quant_int8_triton_opt(x: torch.Tensor,
# tokens_per_expert: Optional[torch.Tensor] = None):
# """
# Python wrapper for the Triton kernel.
# """
# if x.dim() != 3:
# raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
# E, T, H = x.shape
# M = E * T
# N = H
# x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
# scales = torch.empty(x.shape[:-1] + (1, ),
# device=x.device,
# dtype=torch.float32)
# BLOCK = triton.next_power_of_2(N)
# num_warps = min(max(BLOCK // 256, 1), 8)
# grid_opt = M
# _per_token_quant_int8_kernel_opt[(grid_opt, )](
# x,
# x_q,
# scales,
# stride_x=x.stride(-2),
# stride_xq=x_q.stride(-2),
# N=N,
# T_dim=T,
# has_tokens_per_expert= tokens_per_expert is not None,
# tokens_per_expert_ptr=tokens_per_expert,
# BLOCK=BLOCK,
# num_warps=num_warps,
# num_stages=1,
# )
# return x_q, scales
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
T_dim,
has_tokens_per_expert: tl.constexpr,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
row_id = tl.program_id(0)
if has_tokens_per_expert:
e = row_id // T_dim
t = row_id % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
return
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
@triton.jit
def _per_token_quant_int8_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
E_dim,
T_dim,
has_tokens_per_expert: tl.constexpr,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
token_idx_start = tl.program_id(0)
grid_size = tl.num_programs(0)
num_total_tokens = E_dim * T_dim
for token_idx in range(token_idx_start, num_total_tokens, grid_size):
is_valid_token = True
if has_tokens_per_expert:
e = token_idx // T_dim
t = token_idx % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
is_valid_token = False
if is_valid_token:
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + token_idx, scale_x)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None):
if x.dim() != 3:
raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
E, T, H = x.shape
N = H
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)
if T >= 4096:
num_warps = 1
num_tokens = E * T
grid_opt = num_tokens
if E == 16 and T >= 1024 :
grid_opt = max(1, num_tokens // (T // 256))
_per_token_quant_int8_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
E_dim=E,
T_dim=T,
has_tokens_per_expert=tokens_per_expert is not None,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_per_token_quant_int8_one_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
T_dim=T,
has_tokens_per_expert=tokens_per_expert is not None,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
def _int8_quantize( def _int8_quantize(
A: torch.Tensor, A: torch.Tensor,
A_scale: Optional[torch.Tensor], A_scale: Optional[torch.Tensor],
per_act_token: bool, per_act_token: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Perform int8 quantization on the inputs. If a block_shape Perform int8 quantization on the inputs. If a block_shape
...@@ -66,7 +278,10 @@ def _int8_quantize( ...@@ -66,7 +278,10 @@ def _int8_quantize(
if block_shape is None: if block_shape is None:
assert per_act_token, \ assert per_act_token, \
"int8 quantization only supports block or channel-wise" "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A) if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A)
else:
A, A_scale = per_token_quant_int8_triton_opt(A, expert_num_tokens)
else: else:
assert not per_act_token assert not per_act_token
assert len(block_shape) == 2 assert len(block_shape) == 2
...@@ -83,11 +298,12 @@ def moe_kernel_quantize_input( ...@@ -83,11 +298,12 @@ def moe_kernel_quantize_input(
quant_dtype: Optional[torch.dtype], quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
expert_num_tokens: Optional[torch.Tensor]= None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if quant_dtype == torch.float8_e4m3fn: if quant_dtype == torch.float8_e4m3fn:
return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape) return _fp8_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == torch.int8: elif quant_dtype == torch.int8:
return _int8_quantize(A, A_scale, per_act_token_quant, block_shape) return _int8_quantize(A, A_scale, per_act_token_quant, block_shape, expert_num_tokens)
else: else:
return A, A_scale return A, A_scale
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import enum import enum
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from math import prod
import torch import torch
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
...@@ -13,6 +14,8 @@ from compressed_tensors.quantization import (ActivationOrdering, ...@@ -13,6 +14,8 @@ from compressed_tensors.quantization import (ActivationOrdering,
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase, FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
...@@ -32,10 +35,16 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ...@@ -32,10 +35,16 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -999,11 +1008,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -999,11 +1008,27 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"For INT8 Fused MoE layers, we require channelwise, " "For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.") "dynamic per token quantization. Found static input scales.")
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.int8 params_dtype = torch.int8
...@@ -1102,8 +1127,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1102,8 +1127,6 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for " "EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.") "`CompressedTensorsW8A8Int8MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
...@@ -1116,28 +1139,204 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1116,28 +1139,204 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate, use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,)
return fused_experts(
hidden_states=x, return self.fused_experts(
w1=layer.w13_weight, x,
w2=layer.w2_weight, layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True, inplace=True,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
w1_scale=layer.w13_weight_scale, w1_scale=(layer.w13_weight_scale),
w2_scale=layer.w2_weight_scale, w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=False, use_nn_moe=use_nn_moe,
shared_output=shared_output, shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor) routed_scaling_factor=routed_scaling_factor,
)
def fused_moe_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
**_ ):
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def w8a8_groupgemm_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
local_num_experts = w1.size(0)
E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
q_x, w1, w2, topk_ids)
N, K = self.N, self.K
(workspace13_shape, workspace2_shape, fused_out_shape,
workspace_dtype) = self.groupgemm_workspace_shapes(
x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
local_num_experts)
workspace13 = torch.empty(prod(workspace13_shape),
device=x.device,
dtype=workspace_dtype)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
fused_out = _resize_cache(workspace13, fused_out_shape)
# (from deepgemm docs) : A value hint (which is a value on CPU)
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens // 2
# print("##########################gemm1 workspace1 shape:{} q_x shape:{} " \
# "a1_scale shape:{} w1 shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# q_x.shape,
# a1_scale.shape,
# w1.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# print("##########################gemm2 workspace1 shape:{} a2q shape:{} " \
# "a2q_scale shape:{} w2 shape:{} fused_out shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# a2q.shape,
# a2q_scale.shape,
# w2.shape,
# fused_out.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def select_gemm_impl(
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
from vllm.model_executor.layers.fused_moe import (
TritonOrGroupGemmExperts)
if (prepare_finalize.activation_format ==
FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank())
assert max_num_tokens_per_rank is not None
self.max_num_tokens_per_rank = max_num_tokens_per_rank
logger.debug(
"TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank,
None, True)
return TritonOrGroupGemmExperts(
use_int8_w8a8=True,
per_act_token_quant=True,
fused_experts=self.w8a8_groupgemm_forward
)
else:
logger.debug(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, None,
False)
return TritonOrGroupGemmExperts(
fused_experts=self.fused_moe_forward
)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import os import os
from math import prod
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -7,9 +8,10 @@ from torch.nn.parameter import Parameter ...@@ -7,9 +8,10 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size, get_dp_group from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase) from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig, from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
...@@ -17,6 +19,7 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon ...@@ -17,6 +19,7 @@ from vllm.model_executor.layers.quantization.base_config import (QuantizationCon
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_weight_repack_impl
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter) ModelWeightParameter)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
...@@ -27,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import ( ...@@ -27,6 +30,7 @@ from vllm.model_executor.layers.fused_moe import (
try: try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
from lightop import m_grouped_w4a8_gemm_nt_masked, fuse_silu_mul_quant, fuse_silu_mul_quant_ep
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -127,10 +131,6 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig): ...@@ -127,10 +131,6 @@ class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
@property
def weight_block_size(self):
return [128,128]
class SlimQuantW4A8Int8MarlinMoEMethod: class SlimQuantW4A8Int8MarlinMoEMethod:
...@@ -160,14 +160,17 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -160,14 +160,17 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
def __init__(self, quant_config): def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
self.fused_experts = self.w4a8_marlin_forward self.fused_experts = self.w4a8_fused_moe_marlin_forward
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config parallel_config = vllm_config.parallel_config
self.use_deepep = parallel_config.enable_expert_parallel and \ self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \ (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
self.enable_moe_group_gemm = parallel_config.enable_expert_parallel and envs.VLLM_ENABLE_MOE_GROUP_GEMM if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights( def create_weights(
...@@ -181,6 +184,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -181,6 +184,10 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
): ):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
if self.use_deepep:
self.N = 2 * intermediate_size
self.K = hidden_size
# WEIGHTS # WEIGHTS
w13_weight = torch.nn.Parameter( w13_weight = torch.nn.Parameter(
torch.empty( torch.empty(
...@@ -233,7 +240,77 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -233,7 +240,77 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False) layer.w13_weight = Parameter(w4a8_weight_repack_impl(layer.w13_weight), requires_grad=False)
layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False) layer.w2_weight = Parameter(w4a8_weight_repack_impl(layer.w2_weight), requires_grad=False)
def w4a8_marlin_forward(self, def w4a8_fused_moe_marlin_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ):
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
w1,
w2,
topk_ids=topk_ids,
topk_weights=topk_weights,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def groupgemm_workspace_shapes(self,
a: torch.Tensor,
aq: torch.Tensor,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,):
assert a.dim() == 2
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = a.size(
0) if self.max_num_tokens_per_rank is None else self.max_num_tokens_per_rank
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output, a.dtype)
def w4a8_groupgemm_marlin_forward(self,
x: torch.Tensor, x: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
...@@ -251,35 +328,69 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -251,35 +328,69 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
q_x: Optional[torch.Tensor] = None,
**_ ): **_ ):
if not self.enable_moe_group_gemm:
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers() import vllm.model_executor.layers.fused_moe.modular_kernel as mk
return fused_experts_impl_w4a8_marlin( local_num_experts = w1.size(0)
x,
w1, E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
w2, q_x, w1, w2, topk_ids)
topk_ids=topk_ids,
topk_weights=topk_weights, N, K = self.N, self.K
workspace=workspace, (workspace13_shape, workspace2_shape, fused_out_shape,
global_reduce_buffer=global_reduce_buffer, workspace_dtype) = self.groupgemm_workspace_shapes(
inplace=True, x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
use_int4_w4a8=True, local_num_experts)
per_channel_quant=True, workspace13 = torch.empty(prod(workspace13_shape),
activation=activation, device=x.device,
expert_map=expert_map, dtype=workspace_dtype)
apply_router_weight_on_input=apply_router_weight_on_input, workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
global_num_experts=global_num_experts, fused_out = _resize_cache(workspace13, fused_out_shape)
w1_scale=w1_scale,
w2_scale=w2_scale, # (from deepgemm docs) : A value hint (which is a value on CPU)
a1_scale=a1_scale, # for the M expectation of each batch, correctly setting this value
a2_scale=a2_scale, # may lead to better performance.
use_nn_moe=use_nn_moe, expected_m = max_num_tokens
shared_output=shared_output, # forward_context = get_forward_context()
routed_scaling_factor=routed_scaling_factor, # expected_m = forward_context.dp_metadata.max_tokens_across_dp_cpu * self.num_dispatchers
)
else: # print("##########################gemm1 workspace1 shape:{} q_x shape:{} " \
# TODO: # "a1_scale shape:{} w1 shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
return None # q_x.shape,
# a1_scale.shape,
# w1.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# print("##########################gemm2 workspace1 shape:{} a2q shape:{} " \
# "a2q_scale shape:{} w2 shape:{} fused_out shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# a2q.shape,
# a2q_scale.shape,
# w2.shape,
# fused_out.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w4a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
expert_num_tokens,
expected_m)
return fused_out
def apply_mori_ep( def apply_mori_ep(
self, self,
...@@ -347,9 +458,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -347,9 +458,9 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
shared_output: Optional[torch.Tensor] = None, shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
if enable_eplb: # if enable_eplb:
raise NotImplementedError( # raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.") # "EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -398,22 +509,25 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -398,22 +509,25 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
FusedMoEActivationFormat.BatchedExperts): FusedMoEActivationFormat.BatchedExperts):
max_num_tokens_per_rank = ( max_num_tokens_per_rank = (
prepare_finalize.max_num_tokens_per_rank()) prepare_finalize.max_num_tokens_per_rank())
self.max_num_tokens_per_rank = max_num_tokens_per_rank
assert max_num_tokens_per_rank is not None assert max_num_tokens_per_rank is not None
logger.debug( logger.info(
"BatchedGroupedGemmExperts(%s): " "TritonOrGroupGemmExperts(%s): "
"max_tokens_per_rank=%s, block_size=%s, per_act_token=%s", "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
self.__class__.__name__, max_num_tokens_per_rank, self.__class__.__name__, max_num_tokens_per_rank,
self.quant_config.weight_block_size, False) None, True)
return None return TritonOrGroupGemmExperts(
use_int4_w4a8=True,
per_act_token_quant=True,
fused_experts=self.w4a8_groupgemm_marlin_forward
)
else: else:
logger.debug( logger.info(
"TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s", "TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
self.__class__.__name__, self.quant_config.weight_block_size, self.__class__.__name__, None,
False) False)
return TritonOrGroupGemmExperts( return TritonOrGroupGemmExperts(
use_fp8_w8a8=False, fused_experts=self.w4a8_fused_moe_marlin_forward
block_shape=self.quant_config.weight_block_size,
allow_group_gemm=False,
fused_experts=self.w4a8_marlin_forward
) )
...@@ -151,6 +151,10 @@ class Scheduler(SchedulerInterface): ...@@ -151,6 +151,10 @@ class Scheduler(SchedulerInterface):
self.use_eagle = True self.use_eagle = True
self.num_lookahead_tokens = self.num_spec_tokens self.num_lookahead_tokens = self.num_spec_tokens
self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla
# Create the KV cache manager. # Create the KV cache manager.
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
...@@ -1007,7 +1011,7 @@ class Scheduler(SchedulerInterface): ...@@ -1007,7 +1011,7 @@ class Scheduler(SchedulerInterface):
return scheduler_output return scheduler_output
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
if self.num_spec_tokens > 0: if self.full_cuda_graph and self.use_mla and self.num_spec_tokens > 0:
return self.schedule_split_pd() return self.schedule_split_pd()
else: else:
return self.schedule_default() return self.schedule_default()
......
...@@ -1852,6 +1852,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1852,6 +1852,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
time_after_load - time_before_load) time_after_load - time_before_load)
prepare_communication_buffer_for_model(self.model) prepare_communication_buffer_for_model(self.model)
if hasattr(self, "drafter"):
prepare_communication_buffer_for_model(self.drafter.model)
if is_mixture_of_experts( if is_mixture_of_experts(
self.model) and self.parallel_config.enable_eplb: self.model) and self.parallel_config.enable_eplb:
logger.info("EPLB is enabled for model %s.", logger.info("EPLB is enabled for model %s.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment