Commit 4e4db0b4 authored by zhuwenwen
Browse files

Merge branch 'v0.14.1-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.14.1-dev

parents f6653ed9 de21e4d1
...@@ -900,6 +900,7 @@ class ModelConfig: ...@@ -900,6 +900,7 @@ class ModelConfig:
"mxfp4", "mxfp4",
"cpu_awq", "cpu_awq",
"slimquant_w4a8_marlin", "slimquant_w4a8_marlin",
"slimquant_marlin",
"slimquant_compressed_tensors_marlin", "slimquant_compressed_tensors_marlin",
] ]
quantization_methods = [ quantization_methods = [
......
...@@ -344,7 +344,7 @@ class SpeculativeConfig: ...@@ -344,7 +344,7 @@ class SpeculativeConfig:
tokenizer_revision=self.target_model_config.tokenizer_revision, tokenizer_revision=self.target_model_config.tokenizer_revision,
spec_target_max_model_len=self.target_model_config.max_model_len, spec_target_max_model_len=self.target_model_config.max_model_len,
quantization=self.quantization, quantization=self.quantization,
enforce_eager=True if envs.VLLM_SPEC_DECODE_EAGER else self.target_model_config.enforce_eager, enforce_eager=self.target_model_config.enforce_eager,
max_logprobs=self.target_model_config.max_logprobs, max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override, hf_overrides=SpeculativeConfig.hf_config_override,
config_format=self.target_model_config.config_format, config_format=self.target_model_config.config_format,
......
...@@ -259,7 +259,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): ...@@ -259,7 +259,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish # This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling. # reasonable defaults based on profiling.
self.num_sms = 20 self.num_sms = 30
def get_handle(self, kwargs): def get_handle(self, kwargs):
raise NotImplementedError raise NotImplementedError
...@@ -292,16 +292,21 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): ...@@ -292,16 +292,21 @@ class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase):
def _make_all2all_kwargs(self) -> dict[Any, Any]: def _make_all2all_kwargs(self) -> dict[Any, Any]:
# Defaults for internode and intranode are taken from DeepEP tests. # Defaults for internode and intranode are taken from DeepEP tests.
num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 #num_nvl_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_nvl_bytes = int(2e9/2)#1024 * 1024 * 1024
num_rdma_bytes = None num_rdma_bytes = None
num_qps_per_rank = None num_qps_per_rank = None
if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE: if self.internode and not envs.VLLM_DEEPEP_HIGH_THROUGHPUT_FORCE_INTRA_NODE:
num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024 # num_rdma_bytes = envs.VLLM_DEEPEP_BUFFER_SIZE_MB * 1024 * 1024
num_qps_per_rank = self.num_sms // 2 # num_qps_per_rank = self.num_sms // 2
num_rdma_bytes = int(1e9/2) #1024 * 1024 * 1024
num_qps_per_rank = 30 #self.num_sms // 2
self.num_sms = 30
else: else:
num_rdma_bytes = 0 num_rdma_bytes = 0
num_qps_per_rank = 1 num_qps_per_rank = 1
self.num_sms = 60
assert num_rdma_bytes is not None assert num_rdma_bytes is not None
assert num_qps_per_rank is not None assert num_qps_per_rank is not None
......
...@@ -279,6 +279,7 @@ if TYPE_CHECKING: ...@@ -279,6 +279,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_SUM: bool = False VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
VLLM_USE_PIECEWISE: bool = False VLLM_USE_PIECEWISE: bool = False
...@@ -288,6 +289,7 @@ if TYPE_CHECKING: ...@@ -288,6 +289,7 @@ if TYPE_CHECKING:
VLLM_USE_TOPK_RENORM: bool = False VLLM_USE_TOPK_RENORM: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_W8A8_BACKEND: int = 3 VLLM_W8A8_BACKEND: int = 3
VLLM_REJECT_SAMPLE_OPT: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1814,6 +1816,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1814,6 +1816,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# blaslt: 3 (default) # blaslt: 3 (default)
# rocblas: others # rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "1")), "VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "1")),
# If set to "true" or "1", vLLM uses the optimized reject-sample implementation (default: False)
"VLLM_REJECT_SAMPLE_OPT":
lambda: (os.getenv('VLLM_REJECT_SAMPLE_OPT', 'False').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -160,6 +160,8 @@ def maybe_make_prepare_finalize( ...@@ -160,6 +160,8 @@ def maybe_make_prepare_finalize(
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
) )
use_int8_dispatch = quant_config.quant_dtype == torch.int8
prepare_finalize = DeepEPLLPrepareAndFinalize( prepare_finalize = DeepEPLLPrepareAndFinalize(
handle, handle,
max_tokens_per_rank=moe.max_num_tokens, max_tokens_per_rank=moe.max_num_tokens,
...@@ -168,6 +170,7 @@ def maybe_make_prepare_finalize( ...@@ -168,6 +170,7 @@ def maybe_make_prepare_finalize(
global_to_physical=global_to_physical, global_to_physical=global_to_physical,
physical_to_global=physical_to_global, physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids, local_expert_global_ids=local_expert_global_ids,
use_int8_dispatch=use_int8_dispatch,
) )
return prepare_finalize return prepare_finalize
...@@ -22,6 +22,16 @@ from vllm.utils.deep_gemm import ( ...@@ -22,6 +22,16 @@ from vllm.utils.deep_gemm import (
) )
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant_ep
if has_deep_gemm():
from deep_gemm import m_grouped_w8a8_gemm_nt_masked
else:
from lightop import m_grouped_w8a8_gemm_nt_masked
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -256,6 +266,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -256,6 +266,8 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
max_num_tokens: int, max_num_tokens: int,
num_dispatchers: int, num_dispatchers: int,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
N: int = -1,
K: int = -1,
): ):
""" """
max_num_tokens: Maximum number of tokens from a DP Rank max_num_tokens: Maximum number of tokens from a DP Rank
...@@ -263,11 +275,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -263,11 +275,17 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_config: Quantization configuration quant_config: Quantization configuration
""" """
super().__init__(quant_config) super().__init__(quant_config)
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8 if quant_config.use_fp8_w8a8:
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
#assert self.quant_config.use_fp8_w8a8
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
self.N = N
self.K = K
@property @property
def activation_formats( def activation_formats(
self, self,
...@@ -373,12 +391,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -373,12 +391,15 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
a1q = hidden_states a1q = hidden_states
_, N, K = w1.size() _, N, K = w1.size()
assert w2.size(1) == K #assert w2.size(1) == K
E, max_num_tokens, N, K, _ = self.moe_problem_size( E, max_num_tokens, N, K, _ = self.moe_problem_size(
hidden_states, w1, w2, topk_ids hidden_states, w1, w2, topk_ids
) )
if self.N > 0:
N = self.N
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
expected_m = self.estimate_expected_m( expected_m = self.estimate_expected_m(
...@@ -387,25 +408,44 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -387,25 +408,44 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
topk=topk_ids.size(-1), topk=topk_ids.size(-1),
) )
fp8_m_grouped_gemm_nt_masked( if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
(a1q, a1q_scale), fp8_m_grouped_gemm_nt_masked(
(w1, self.w1_scale), (a1q, a1q_scale),
workspace1, (w1, self.w1_scale),
expert_num_tokens, workspace1,
expected_m, expert_num_tokens,
) expected_m,
)
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle() quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant( a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1, workspace1,
expert_num_tokens, expert_num_tokens,
quant_scale_fmt=quant_scale_fmt, quant_scale_fmt=quant_scale_fmt,
) )
fp8_m_grouped_gemm_nt_masked( fp8_m_grouped_gemm_nt_masked(
(a2q, a2q_scale), (a2q, a2q_scale),
(w2, self.w2_scale), (w2, self.w2_scale),
output, output,
expert_num_tokens, expert_num_tokens,
expected_m, expected_m,
) )
elif self.quant_config.use_int8_w8a8:
m_grouped_w8a8_gemm_nt_masked((a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m)
else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
...@@ -87,7 +87,7 @@ def _quant_flags_to_group_shape( ...@@ -87,7 +87,7 @@ def _quant_flags_to_group_shape(
""" """
a_shape: GroupShape | None a_shape: GroupShape | None
w_shape: GroupShape | None w_shape: GroupShape | None
if block_shape is not None: if block_shape is not None and quant_dtype!=torch.int8:
assert not per_act_token_quant assert not per_act_token_quant
assert not per_out_ch_quant assert not per_out_ch_quant
# TODO(bnell): this is not quite right for activations since first # TODO(bnell): this is not quite right for activations since first
...@@ -207,10 +207,10 @@ class FusedMoEQuantConfig: ...@@ -207,10 +207,10 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc _w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc _w2: FusedMoEQuantDesc
def __post_init__(self): # def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, ( # assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization" # "illegal quantization"
) # )
# #
# Convenience accessors for various properties. # Convenience accessors for various properties.
...@@ -242,6 +242,9 @@ class FusedMoEQuantConfig: ...@@ -242,6 +242,9 @@ class FusedMoEQuantConfig:
@property @property
def block_shape(self) -> list[int] | None: def block_shape(self) -> list[int] | None:
if self.use_int8_w8a8:
return [256, 256]
if ( if (
self._a1.shape is not None self._a1.shape is not None
and self._a1.shape != GroupShape.PER_TENSOR and self._a1.shape != GroupShape.PER_TENSOR
...@@ -565,7 +568,7 @@ def int8_w8a8_moe_quant_config( ...@@ -565,7 +568,7 @@ def int8_w8a8_moe_quant_config(
a2_scale=a2_scale, a2_scale=a2_scale,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
per_out_ch_quant=False, per_out_ch_quant=False,
block_shape=None, block_shape=[256, 256],
) )
......
...@@ -33,6 +33,12 @@ from vllm.utils.deep_gemm import ( ...@@ -33,6 +33,12 @@ from vllm.utils.deep_gemm import (
) )
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
from lightop import fuse_silu_mul_quant
if has_deep_gemm():
from deep_gemm import m_grouped_i8_gemm_nt_contiguous
else:
from lightop import m_grouped_w8a8_gemm_nt_contig_asm as m_grouped_i8_gemm_nt_contiguous
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -109,12 +115,19 @@ def _valid_deep_gemm( ...@@ -109,12 +115,19 @@ def _valid_deep_gemm(
class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(self, quant_config: FusedMoEQuantConfig): def __init__(self, quant_config: FusedMoEQuantConfig,
N: int = -1,
K: int = -1,):
super().__init__(quant_config) super().__init__(quant_config)
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn if quant_config.use_fp8_w8a8 or quant_config.use_fp8_w8a16:
assert not quant_config.per_act_token_quant assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert not quant_config.per_out_ch_quant assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
assert not quant_config.per_out_ch_quant
self.N = N
self.K = K
@property @property
def activation_formats( def activation_formats(
...@@ -230,19 +243,24 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -230,19 +243,24 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
assert w2.size(1) == K #assert w2.size(1) == K
if self.N > 0:
N = self.N
K = self.K
use_fp8 = self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8
M_sum = compute_aligned_M( M_sum = compute_aligned_M(
M=topk_ids.size(0), M=topk_ids.size(0),
num_topk=topk_ids.size(1), num_topk=topk_ids.size(1),
local_num_experts=local_num_experts, local_num_experts=local_num_experts,
alignment=get_mk_alignment_for_contiguous_layout()[0], alignment=get_mk_alignment_for_contiguous_layout()[0] if use_fp8 else self.block_shape[0],
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
) )
a1q_perm = _resize_cache( a1q_perm = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K) workspace13.view(dtype=torch.float8_e4m3fn if use_fp8 else a1q.dtype), (M_sum, K)
) )
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute( a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q, aq=a1q,
aq_scale=a1q_scale, aq_scale=a1q_scale,
...@@ -255,22 +273,37 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -255,22 +273,37 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
assert a1q.size(0) == M_sum assert a1q.size(0) == M_sum
mm1_out = _resize_cache(workspace2, (M_sum, N)) mm1_out = _resize_cache(workspace2, (M_sum, N))
m_grouped_fp8_gemm_nt_contiguous(
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
)
activation_out_dim = self.adjust_N_for_activation(N, activation) if use_fp8:
quant_out = _resize_cache( m_grouped_fp8_gemm_nt_contiguous(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim) (a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
) )
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation activation_out_dim = self.adjust_N_for_activation(N, activation)
) quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
)
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation
)
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_fp8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
)
elif self.quant_config.use_int8_w8a8:
m_grouped_i8_gemm_nt_contiguous(
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids)
a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, expert_ids=expert_ids)
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_i8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids)
else:
raise ValueError(f"Unsupported dtype {self.quant_config.quant_dtype}")
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_fp8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
)
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights) topk_weights = torch.ones_like(topk_weights)
......
...@@ -13,6 +13,8 @@ from vllm.triton_utils import tl, triton ...@@ -13,6 +13,8 @@ from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
from vllm.utils.math_utils import round_up from vllm.utils.math_utils import round_up
from lightop import op
def expert_num_tokens_round_up_and_sum( def expert_num_tokens_round_up_and_sum(
expert_num_tokens: torch.Tensor, alignment: int expert_num_tokens: torch.Tensor, alignment: int
...@@ -57,6 +59,12 @@ def round_up_128(x: int) -> int: ...@@ -57,6 +59,12 @@ def round_up_128(x: int) -> int:
return ((x + y - 1) // y) * y return ((x + y - 1) // y) * y
@triton.jit
def round_up_256(x: int) -> int:
    # Round x up to the next multiple of 256; a value that is already a
    # multiple of 256 is returned unchanged (for non-negative x).
    alignment = 256
    return (x + alignment - 1) // alignment * alignment
@triton.jit @triton.jit
def _fwd_kernel_ep_scatter_1( def _fwd_kernel_ep_scatter_1(
num_recv_tokens_per_expert, num_recv_tokens_per_expert,
...@@ -74,26 +82,27 @@ def _fwd_kernel_ep_scatter_1( ...@@ -74,26 +82,27 @@ def _fwd_kernel_ep_scatter_1(
mask=offset_cumsum < num_experts, mask=offset_cumsum < num_experts,
other=0, other=0,
) )
tokens_per_expert = round_up_128(tokens_per_expert) #tokens_per_expert = round_up_128(tokens_per_expert)
tokens_per_expert = round_up_256(tokens_per_expert)
cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
#if cur_expert == 0:
tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts) tl.store(expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts)
tl.debug_barrier()
#cur_expert_start = cumsum[cur_expert]
cur_expert_start = tl.load(expert_start_loc + cur_expert) cur_expert_start = tl.load(expert_start_loc + cur_expert)
cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert) cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)
m_indices_start_ptr = m_indices + cur_expert_start m_indices_start_ptr = m_indices + cur_expert_start
off_expert = tl.arange(0, BLOCK_E) off_expert = tl.arange(0, BLOCK_E)
# any rows in the per-expert aligned region that do not correspond to
# real tokens are left untouched here and should remain initialized to
# -1 so DeepGEMM can skip them
for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4): for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
offs = start_m + off_expert
mask = offs < cur_expert_token_num
tl.store( tl.store(
m_indices_start_ptr + offs, m_indices_start_ptr + start_m + off_expert,
cur_expert, cur_expert,
mask=mask, mask=start_m + off_expert < cur_expert_token_num
) )
...@@ -133,26 +142,32 @@ def _fwd_kernel_ep_scatter_2( ...@@ -133,26 +142,32 @@ def _fwd_kernel_ep_scatter_2(
offset_in = tl.arange(0, HIDDEN_SIZE_PAD) offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
mask = offset_in < HIDDEN_SIZE mask = offset_in < HIDDEN_SIZE
offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD) index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
mask_s = offset_in_s < SCALE_HIDDEN_SIZE mask_s = index_in_s < SCALE_HIDDEN_SIZE
for token_id in range(start_token_id, total_token_num, grid_num): for token_id_int32 in range(start_token_id, total_token_num, grid_num):
token_id = token_id_int32.to(tl.int64)
to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask) to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
to_copy_s = tl.load( to_copy_s = tl.load(
recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s recv_x_scale
+ token_id * recv_x_scale_stride0
+ index_in_s * recv_x_scale_stride1,
mask=mask_s,
) )
for topk_index in tl.range(0, topk_num, 1, num_stages=4): for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
topk_index = topk_idx_int32.to(tl.int64)
expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index) expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
if HAS_EXPERT_MAP: if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map) expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0: if expert_id >= 0:
dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1) dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
dest_token_index = dest_token_index_int32.to(tl.int64)
tl.store( tl.store(
output_index + token_id * output_index_stride0 + topk_index, output_index + token_id * output_index_stride0 + topk_index,
dest_token_index, dest_token_index_int32,
) )
output_tensor_ptr = ( output_tensor_ptr = (
output_tensor + dest_token_index * output_tensor_stride0 output_tensor + dest_token_index * output_tensor_stride0
...@@ -161,7 +176,11 @@ def _fwd_kernel_ep_scatter_2( ...@@ -161,7 +176,11 @@ def _fwd_kernel_ep_scatter_2(
output_tensor_scale + dest_token_index * output_tensor_scale_stride0 output_tensor_scale + dest_token_index * output_tensor_scale_stride0
) )
tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask) tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
tl.store(output_tensor_scale_ptr + offset_in_s, to_copy_s, mask=mask_s) tl.store(
output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
to_copy_s,
mask=mask_s,
)
@torch.no_grad() @torch.no_grad()
...@@ -177,58 +196,71 @@ def ep_scatter( ...@@ -177,58 +196,71 @@ def ep_scatter(
m_indices: torch.Tensor, m_indices: torch.Tensor,
output_index: torch.Tensor, output_index: torch.Tensor,
): ):
BLOCK_E = 128 # token num of per expert is aligned to 128 # BLOCK_E = 128 # token num of per expert is aligned to 128
BLOCK_D = 128 # block size of quantization # BLOCK_D = 128 # block size of quantization
BLOCK_E = 256 # token num of per expert is aligned to 256
num_warps = 8 num_warps = 8
num_experts = num_recv_tokens_per_expert.shape[0] num_experts = num_recv_tokens_per_expert.shape[0]
hidden_size = recv_x.shape[1] hidden_size = recv_x.shape[1]
scale_hidden_size = recv_x_scale.shape[-1]
# grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts) # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
grid = num_experts grid = num_experts
assert m_indices.shape[0] % BLOCK_E == 0 assert m_indices.shape[0] % BLOCK_E == 0
_fwd_kernel_ep_scatter_1[(grid,)]( if hasattr(op, "ep_scatter"):
num_recv_tokens_per_expert, op.ep_scatter(
expert_start_loc, recv_x, recv_x_scale,
m_indices, recv_topk, expert_map,
num_experts=num_experts, num_recv_tokens_per_expert,
num_warps=num_warps, output_tensor, output_tensor_scale, m_indices, output_index,
BLOCK_E=BLOCK_E, num_experts, BLOCK_E
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts), )
) else:
_fwd_kernel_ep_scatter_1[(grid,)](
num_recv_tokens_per_expert,
expert_start_loc,
m_indices,
num_experts=num_experts,
num_warps=num_warps,
BLOCK_E=BLOCK_E,
BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
)
grid = min(recv_topk.shape[0], 1024 * 8) grid = min(recv_topk.shape[0], 1024 * 8)
_fwd_kernel_ep_scatter_2[(grid,)]( _fwd_kernel_ep_scatter_2[(grid,)](
recv_topk.shape[0], recv_topk.shape[0],
expert_start_loc, expert_start_loc,
recv_x, recv_x,
recv_x.stride(0), recv_x.stride(0),
recv_x.stride(1), recv_x.stride(1),
recv_x_scale, recv_x_scale,
recv_x_scale.stride(0), recv_x_scale.stride(0),
recv_x_scale.stride(1), recv_x_scale.stride(1),
recv_topk, recv_topk,
recv_topk.stride(0), recv_topk.stride(0),
recv_topk.stride(1), recv_topk.stride(1),
output_tensor, output_tensor,
output_tensor.stride(0), output_tensor.stride(0),
output_tensor.stride(1), output_tensor.stride(1),
output_tensor_scale, output_tensor_scale,
output_tensor_scale.stride(0), output_tensor_scale.stride(0),
output_tensor_scale.stride(1), output_tensor_scale.stride(1),
output_index, output_index,
output_index.stride(0), output_index.stride(0),
output_index.stride(1), output_index.stride(1),
topk_num=recv_topk.shape[1], topk_num=recv_topk.shape[1],
expert_map=expert_map, expert_map=expert_map,
HAS_EXPERT_MAP=expert_map is not None, HAS_EXPERT_MAP=expert_map is not None,
num_warps=num_warps, num_warps=num_warps,
HIDDEN_SIZE=hidden_size, HIDDEN_SIZE=hidden_size,
HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size), HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D, # SCALE_HIDDEN_SIZE=hidden_size // BLOCK_D,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D), # SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size // BLOCK_D),
) SCALE_HIDDEN_SIZE=scale_hidden_size,
SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
)
return return
...@@ -255,25 +287,34 @@ def _fwd_kernel_ep_gather( ...@@ -255,25 +287,34 @@ def _fwd_kernel_ep_gather(
HAS_EXPERT_MAP: tl.constexpr, HAS_EXPERT_MAP: tl.constexpr,
BLOCK_D: tl.constexpr, BLOCK_D: tl.constexpr,
): ):
cur_block = tl.program_id(0) cur_block_int32 = tl.program_id(0)
start_cur_token = tl.program_id(1) cur_block = cur_block_int32.to(tl.int64)
start_cur_token_int32 = tl.program_id(1)
grid_num = tl.num_programs(1) grid_num = tl.num_programs(1)
for cur_token in range(start_cur_token, total_token_num, grid_num): for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
cur_token = cur_token_int32.to(tl.int64)
off_d = tl.arange(0, BLOCK_D) off_d = tl.arange(0, BLOCK_D)
accumulator = tl.zeros([BLOCK_D], dtype=tl.float32) accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
for topk_index in range(0, topk_num):
for topk_index_int32 in range(0, topk_num):
topk_index = topk_index_int32.to(tl.int64)
expert_id = tl.load( expert_id = tl.load(
recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
) )
if HAS_EXPERT_MAP: if HAS_EXPERT_MAP:
expert_id = apply_expert_map(expert_id, expert_map) expert_id = apply_expert_map(expert_id, expert_map)
if expert_id >= 0: if expert_id >= 0:
source_token_index = tl.load( source_token_index_int32 = tl.load(
input_index + cur_token * input_index_stride0 + topk_index input_index + cur_token * input_index_stride0 + topk_index
) )
source_token_index = source_token_index_int32.to(tl.int64)
acc_weight = tl.load( acc_weight = tl.load(
recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
) )
...@@ -350,7 +391,8 @@ def deepgemm_moe_permute( ...@@ -350,7 +391,8 @@ def deepgemm_moe_permute(
H = aq.size(1) H = aq.size(1)
device = aq.device device = aq.device
block_m, block_k = get_mk_alignment_for_contiguous_layout() #block_m, block_k = get_mk_alignment_for_contiguous_layout()
block_m = 256
M_sum = compute_aligned_M( M_sum = compute_aligned_M(
M=topk_ids.size(0), M=topk_ids.size(0),
...@@ -368,8 +410,11 @@ def deepgemm_moe_permute( ...@@ -368,8 +410,11 @@ def deepgemm_moe_permute(
if aq_out is None: if aq_out is None:
aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype) aq_out = torch.empty((M_sum, H), device=device, dtype=aq.dtype)
# aq_scale_out = torch.empty(
# (M_sum, H // block_k), device=device, dtype=torch.float32
# )
aq_scale_out = torch.empty( aq_scale_out = torch.empty(
(M_sum, H // block_k), device=device, dtype=torch.float32 (M_sum, aq_scale.shape[-1]), device=device, dtype=torch.float32
) )
# DeepGEMM uses negative values in m_indices (here expert_ids) to mark # DeepGEMM uses negative values in m_indices (here expert_ids) to mark
......
...@@ -225,7 +225,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -225,7 +225,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP kernels only support dispatching block-quantized # DeepEP kernels only support dispatching block-quantized
# activation scales. # activation scales.
# Dispatch in bfloat16 and quantize afterwards # Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized: if not quant_config.is_block_quantized and not quant_config.is_per_act_token:
# Quantize after dispatch. # Quantize after dispatch.
expert_x_scale = None expert_x_scale = None
if expert_x.numel() != 0: if expert_x.numel() != 0:
...@@ -266,7 +266,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -266,7 +266,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
a1 = a1 * topk_weights.to(a1.dtype) a1 = a1 * topk_weights.to(a1.dtype)
if quant_config.is_block_quantized: if quant_config.is_block_quantized or quant_config.is_per_act_token:
# Quant and Dispatch # Quant and Dispatch
a1q, a1q_scale = moe_kernel_quantize_input( a1q, a1q_scale = moe_kernel_quantize_input(
a1, a1,
...@@ -345,6 +345,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -345,6 +345,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}" f"Expected fused_expert_output bfloat16, got {fused_expert_output.dtype}"
) )
previous_event = dbo_get_previous_event(self.buffer.capture) previous_event = dbo_get_previous_event(self.buffer.capture)
# Deliberately no torch.cuda.synchronize() / print() before combine: a
# device-wide synchronize here would stall every stream on each call
# (defeating the do_async path and the previous_event dependency captured
# above), is not allowed while a CUDA graph is being captured, and an
# unconditional print would write to stdout on every call.
combined_x, _, event = self.buffer.combine( combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16 # HT combine only supports BF16
x=fused_expert_output, x=fused_expert_output,
...@@ -356,6 +359,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -356,6 +359,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
allocate_on_comm_stream=False, allocate_on_comm_stream=False,
) )
# Deliberately no torch.cuda.synchronize() / print() after combine: the
# returned event is what callers use to order against the combine, and a
# device-wide synchronize here would block the host on every call and is not
# allowed while a CUDA graph is being captured.
dbo_switch_to_compute() dbo_switch_to_compute()
if do_async: if do_async:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable from collections.abc import Callable
from typing import Callable, Optional
import deep_ep import deep_ep
import torch import torch
...@@ -91,12 +92,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -91,12 +92,14 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
global_to_physical: torch.Tensor | None = None, global_to_physical: torch.Tensor | None = None,
physical_to_global: torch.Tensor | None = None, physical_to_global: torch.Tensor | None = None,
local_expert_global_ids: torch.Tensor | None = None, local_expert_global_ids: torch.Tensor | None = None,
use_int8_dispatch: bool = False
): ):
super().__init__() super().__init__()
self.buffer = buffer self.buffer = buffer
self.max_tokens_per_rank = max_tokens_per_rank self.max_tokens_per_rank = max_tokens_per_rank
self.use_fp8_dispatch = use_fp8_dispatch self.use_fp8_dispatch = use_fp8_dispatch
self.use_int8_dispatch = use_int8_dispatch
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. We store the handle here so it is available to the
# combine function. # combine function.
...@@ -168,6 +171,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -168,6 +171,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
a1_dtype: torch.dtype, a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
expert_num_tokens: Optional[torch.Tensor]= None,
) -> tuple[torch.Tensor, torch.Tensor | None]: ) -> tuple[torch.Tensor, torch.Tensor | None]:
if self.use_fp8_dispatch: if self.use_fp8_dispatch:
block_k = ( block_k = (
...@@ -183,6 +187,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -183,6 +187,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dequant to get back the tokens in the datatype we dispatched in. # Dequant to get back the tokens in the datatype we dispatched in.
x_fp8, x_scales = x x_fp8, x_scales = x
x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype) x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)
elif self.use_int8_dispatch:
x, x_scales = x
return x, x_scales
assert isinstance(x, (torch.Tensor, tuple)) assert isinstance(x, (torch.Tensor, tuple))
q_dtype = quant_config.quant_dtype q_dtype = quant_config.quant_dtype
...@@ -214,7 +221,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -214,7 +221,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, max_tokens, hidden_dim = x.size() num_experts, max_tokens, hidden_dim = x.size()
# TODO (varun): Optimization - Use a batched version of quant # TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim)) if expert_num_tokens is None:
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input( x, x_scales = moe_kernel_quantize_input(
x, x,
quant_config.a1_scale, quant_config.a1_scale,
...@@ -294,7 +302,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -294,7 +302,8 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dispatch_topk_ids, dispatch_topk_ids,
self.max_tokens_per_rank, self.max_tokens_per_rank,
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch, use_fp8=self.use_fp8_dispatch or self.use_int8_dispatch,
use_int8=self.use_int8_dispatch,
round_scale=self.use_ue8m0_dispatch, round_scale=self.use_ue8m0_dispatch,
use_ue8m0=self.use_ue8m0_dispatch, use_ue8m0=self.use_ue8m0_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()), **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
...@@ -327,7 +336,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -327,7 +336,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_dtype: torch.dtype, a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType: ) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config) expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config, expert_num_tokens)
expert_tokens_meta = mk.ExpertTokensMetadata( expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
......
...@@ -54,6 +54,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -54,6 +54,8 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
old_quant_method.select_gemm_impl(prepare_finalize, moe_layer), old_quant_method.select_gemm_impl(prepare_finalize, moe_layer),
shared_experts, shared_experts,
moe_parallel_config=moe_layer.moe_parallel_config, moe_parallel_config=moe_layer.moe_parallel_config,
N=old_quant_method.N if hasattr(old_quant_method, "N") else -1,
K=old_quant_method.K if hasattr(old_quant_method, "K") else -1,
), ),
) )
...@@ -95,6 +97,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): ...@@ -95,6 +97,7 @@ class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
router: FusedMoERouter, router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -306,9 +306,8 @@ class FusedMoERouterImpl(FusedMoERouter): ...@@ -306,9 +306,8 @@ class FusedMoERouterImpl(FusedMoERouter):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self.layer._select_experts(hidden_states, router_logits, use_fused_gate) return self.layer._select_experts(hidden_states, router_logits)
# --8<-- [start:fused_moe] # --8<-- [start:fused_moe]
...@@ -1594,7 +1593,6 @@ class FusedMoE(CustomOp): ...@@ -1594,7 +1593,6 @@ class FusedMoE(CustomOp):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_fused_gate: bool | None = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Route the input hidden states to the top-k experts based on the Route the input hidden states to the top-k experts based on the
...@@ -1657,7 +1655,7 @@ class FusedMoE(CustomOp): ...@@ -1657,7 +1655,7 @@ class FusedMoE(CustomOp):
# DeepSeekv2 uses grouped_top_k # DeepSeekv2 uses grouped_top_k
elif self.use_grouped_topk and valid_grouping(): elif self.use_grouped_topk and valid_grouping():
assert self._grouped_topk_impl is not None assert self._grouped_topk_impl is not None
if use_fused_gate: if self.use_fused_gate:
# if envs.VLLM_USE_LIGHTOP: # if envs.VLLM_USE_LIGHTOP:
if False: if False:
topk_weights, topk_ids = op.moe_fused_gate( topk_weights, topk_ids = op.moe_fused_gate(
...@@ -1672,10 +1670,10 @@ class FusedMoE(CustomOp): ...@@ -1672,10 +1670,10 @@ class FusedMoE(CustomOp):
else: else:
topk_weights, topk_ids = ops.moe_fused_gate( topk_weights, topk_ids = ops.moe_fused_gate(
router_logits, router_logits,
e_score_correction_bias=self.e_score_correction_bias, self.e_score_correction_bias,
num_expert_group=self.num_expert_group, self.num_expert_group,
topk_group=self.topk_group, self.topk_group,
topk=self.top_k, self.top_k,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
n_share_experts_fusion=0, n_share_experts_fusion=0,
) )
...@@ -1881,35 +1879,35 @@ class FusedMoE(CustomOp): ...@@ -1881,35 +1879,35 @@ class FusedMoE(CustomOp):
staged_hidden_states.copy_(hidden_states, non_blocking=True) staged_hidden_states.copy_(hidden_states, non_blocking=True)
staged_router_logits.copy_(router_logits, non_blocking=True) staged_router_logits.copy_(router_logits, non_blocking=True)
zero_expert_result = None # zero_expert_result = None
if self.zero_expert_num > 0 and self.zero_expert_type is not None: # if self.zero_expert_num > 0 and self.zero_expert_type is not None:
topk_weights, topk_ids = FusedMoE.select_experts( # topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=staged_hidden_states, # hidden_states=staged_hidden_states,
router_logits=staged_router_logits, # router_logits=staged_router_logits,
use_grouped_topk=self.use_grouped_topk, # use_grouped_topk=self.use_grouped_topk,
top_k=self.top_k, # top_k=self.top_k,
renormalize=self.renormalize, # renormalize=self.renormalize,
topk_group=self.topk_group, # topk_group=self.topk_group,
num_expert_group=self.num_expert_group, # num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function, # custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, # scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor, # routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias, # e_score_correction_bias=self.e_score_correction_bias,
indices_type=self.quant_method.topk_indices_dtype, # indices_type=self.quant_method.topk_indices_dtype,
enable_eplb=self.enable_eplb, # enable_eplb=self.enable_eplb,
expert_map=self.expert_map, # expert_map=self.expert_map,
expert_load_view=self.expert_load_view, # expert_load_view=self.expert_load_view,
logical_to_physical_map=self.logical_to_physical_map, # logical_to_physical_map=self.logical_to_physical_map,
logical_replica_count=self.logical_replica_count) # logical_replica_count=self.logical_replica_count)
# Compute zero_expert_result # # Compute zero_expert_result
zero_expert_result = zero_experts_compute_triton( # zero_expert_result = zero_experts_compute_triton(
expert_indices=topk_ids, # expert_indices=topk_ids,
expert_scales=topk_weights, # expert_scales=topk_weights,
num_experts=self.global_num_experts, # num_experts=self.global_num_experts,
zero_expert_type=self.zero_expert_type, # zero_expert_type=self.zero_expert_type,
hidden_states=staged_hidden_states, # hidden_states=staged_hidden_states,
) # )
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
...@@ -2093,7 +2091,6 @@ class FusedMoE(CustomOp): ...@@ -2093,7 +2091,6 @@ class FusedMoE(CustomOp):
else hidden_states, else hidden_states,
router_logits=router_logits, router_logits=router_logits,
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
) )
if has_separate_shared_experts: if has_separate_shared_experts:
......
...@@ -676,6 +676,21 @@ def _slice_scales( ...@@ -676,6 +676,21 @@ def _slice_scales(
return None return None
_alt_stream: torch.cuda.Stream | None = None
def alt_stream() -> torch.cuda.Stream | None:
    """
    Return the module-wide alternate CUDA stream, creating it on first use.

    The stream is cached in the module global ``_alt_stream``; every call
    after the first returns that same object.
    """
    global _alt_stream
    # TODO: validate this works properly on ROCm platform.
    if _alt_stream is not None:
        return _alt_stream
    _alt_stream = torch.cuda.Stream()
    return _alt_stream
@final @final
class FusedMoEModularKernel(torch.nn.Module): class FusedMoEModularKernel(torch.nn.Module):
""" """
...@@ -696,6 +711,8 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -696,6 +711,8 @@ class FusedMoEModularKernel(torch.nn.Module):
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: torch.nn.Module | None = None, shared_experts: torch.nn.Module | None = None,
moe_parallel_config: FusedMoEParallelConfig | None = None, moe_parallel_config: FusedMoEParallelConfig | None = None,
N: int = -1,
K: int = -1,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
...@@ -722,6 +739,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -722,6 +739,12 @@ class FusedMoEModularKernel(torch.nn.Module):
f"{fused_experts.__class__.__name__}." f"{fused_experts.__class__.__name__}."
f"{fused_experts.activation_formats[0]}" f"{fused_experts.activation_formats[0]}"
) )
self.N = N
self.K = K
if self.shared_experts is not None:
self.alt_stream = alt_stream()
self.alt_event = torch.cuda.Event()
def _post_init_setup(self): def _post_init_setup(self):
""" """
...@@ -1020,11 +1043,13 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1020,11 +1043,13 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map: torch.Tensor | None, expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
expert_tokens_meta: ExpertTokensMetadata | None, expert_tokens_meta: ExpertTokensMetadata | None,
use_nn_moe: bool | None = False,
) -> torch.Tensor: ) -> torch.Tensor:
_, M_full, N, K, top_k = self.fused_experts.moe_problem_size( _, M_full, N, K, top_k = self.fused_experts.moe_problem_size(
a1q, w1, w2, topk_ids a1q, w1, w2, topk_ids
) )
if self.N > 0:
N = self.N
K = self.K
num_chunks, CHUNK_SIZE = self._chunk_info(M_full) num_chunks, CHUNK_SIZE = self._chunk_info(M_full)
...@@ -1096,7 +1121,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1096,7 +1121,6 @@ class FusedMoEModularKernel(torch.nn.Module):
workspace2=workspace2, workspace2=workspace2,
expert_tokens_meta=c_expert_tokens_meta, expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe,
) )
return fused_out return fused_out
...@@ -1130,37 +1154,46 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1130,37 +1154,46 @@ class FusedMoEModularKernel(torch.nn.Module):
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
else: else:
finalize_ret = self.prepare_finalize.finalize_async( self.alt_event.record()
output,
fused_out,
topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if self.shared_experts is not None: if self.shared_experts is not None:
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# TODO(lucas): refactor this in the alternative schedules followup current_stream = torch.cuda.current_stream()
# currently unpack if we have hook + receiver pair or just with torch.cuda.stream(self.alt_stream):
# receiver (see finalize_async docstring) self.alt_stream.wait_event(self.alt_event)
hook, receiver = (
finalize_ret finalize_ret = self.prepare_finalize.finalize_async(
if isinstance(finalize_ret, tuple) output,
else (None, finalize_ret) fused_out,
) topk_weights,
topk_ids,
apply_router_weight_on_input,
self.fused_experts.finalize_weight_and_reduce_impl(),
)
if hook is not None: # TODO(lucas): refactor this in the alternative schedules followup
if dbo_enabled(): # currently unpack if we have hook + receiver pair or just
# If DBO is being used, register the hook with the ubatch # receiver (see finalize_async docstring)
# context and call it in dbo_maybe_run_recv_hook instead of hook, receiver = (
# passing it to the receiver. finalize_ret
dbo_register_recv_hook(hook) if isinstance(finalize_ret, tuple)
dbo_yield() else (None, finalize_ret)
else: )
hook()
if hook is not None:
if dbo_enabled():
# If DBO is being used, register the hook with the ubatch
# context and call it in dbo_maybe_run_recv_hook instead of
# passing it to the receiver.
dbo_register_recv_hook(hook)
dbo_yield()
else:
hook()
receiver()
receiver() self.alt_event.record()
current_stream.wait_event(self.alt_event)
if self.shared_experts is None: if self.shared_experts is None:
return output return output
...@@ -1180,7 +1213,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1180,7 +1213,6 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
...@@ -1242,7 +1274,6 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1242,7 +1274,6 @@ class FusedMoEModularKernel(torch.nn.Module):
expert_map=expert_map, expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
expert_tokens_meta=expert_tokens_meta, expert_tokens_meta=expert_tokens_meta,
use_nn_moe=use_nn_moe,
) )
return self._finalize( return self._finalize(
......
...@@ -316,7 +316,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -316,7 +316,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward( return self.forward(
router=router, router=router,
...@@ -324,7 +323,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -324,7 +323,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x=x, x=x,
router_logits=router_logits, router_logits=router_logits,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
use_fused_gate=use_fused_gate,
) )
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
...@@ -343,12 +341,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -343,12 +341,10 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_fused_gate=use_fused_gate,
) )
result = self.kernel( result = self.kernel(
hidden_states=x, hidden_states=x,
......
...@@ -2,9 +2,11 @@ ...@@ -2,9 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools import functools
from math import prod from math import prod
from typing import Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from triton.language.extra import libdevice
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...@@ -157,11 +159,147 @@ def _fp8_quantize( ...@@ -157,11 +159,147 @@ def _fp8_quantize(
return A, A_scale return A, A_scale
@triton.jit
def _per_token_quant_int8_one_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    T_dim,
    has_tokens_per_expert: tl.constexpr,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    # Per-row symmetric int8 quantization: each program handles exactly the
    # row whose flat index equals its program id along axis 0.
    row_id = tl.program_id(0)

    if has_tokens_per_expert:
        # Split the flat row index into (expert, slot) and leave early,
        # writing nothing, when the slot is at or beyond that expert's count.
        expert_idx = row_id // T_dim
        slot_idx = row_id % T_dim
        valid_count = tl.load(tokens_per_expert_ptr + expert_idx)
        if slot_idx >= valid_count:
            return

    # BLOCK lanes cover the row; lanes at or beyond N are masked and read 0.
    offsets = tl.arange(0, BLOCK)
    in_bounds = offsets < N

    row = tl.load(
        x_ptr + row_id * stride_x + offsets, mask=in_bounds, other=0.0
    ).to(tl.float32)

    # Clamp the row maximum so an all-zero row does not divide by zero.
    row_absmax = tl.maximum(tl.max(tl.abs(row)), 1e-10)
    row_scale = row_absmax / 127
    row_q = row * (127 / row_absmax)
    row_q = libdevice.nearbyint(row_q).to(tl.int8)

    tl.store(xq_ptr + row_id * stride_xq + offsets, row_q, mask=in_bounds)
    # One scale per row, stored at the row's flat index.
    tl.store(scale_ptr + row_id, row_scale)
@triton.jit
def _per_token_quant_int8_kernel_opt(
    x_ptr,
    xq_ptr,
    scale_ptr,
    stride_x,
    stride_xq,
    N,
    E_dim,
    T_dim,
    has_tokens_per_expert: tl.constexpr,
    tokens_per_expert_ptr,
    BLOCK: tl.constexpr,
):
    # Per-row symmetric int8 quantization where one program handles several
    # rows: it starts at its own program id and advances by the number of
    # launched programs until all E_dim * T_dim rows are visited.
    first_row = tl.program_id(0)
    row_step = tl.num_programs(0)
    total_rows = E_dim * T_dim

    for token_idx in range(first_row, total_rows, row_step):
        row_is_live = True
        if has_tokens_per_expert:
            # Rows whose slot is at or beyond the expert's token count are
            # skipped; nothing is written for them.
            expert_idx = token_idx // T_dim
            slot_idx = token_idx % T_dim
            valid_count = tl.load(tokens_per_expert_ptr + expert_idx)
            if slot_idx >= valid_count:
                row_is_live = False

        if row_is_live:
            # BLOCK lanes cover the row; lanes at or beyond N are masked.
            offsets = tl.arange(0, BLOCK)
            in_bounds = offsets < N

            row = tl.load(
                x_ptr + token_idx * stride_x + offsets,
                mask=in_bounds,
                other=0.0,
            ).to(tl.float32)

            # Clamp the row maximum so an all-zero row does not divide by zero.
            row_absmax = tl.maximum(tl.max(tl.abs(row)), 1e-10)
            row_scale = row_absmax / 127
            row_q = row * (127 / row_absmax)
            row_q = libdevice.nearbyint(row_q).to(tl.int8)

            tl.store(
                xq_ptr + token_idx * stride_xq + offsets, row_q, mask=in_bounds
            )
            tl.store(scale_ptr + token_idx, row_scale)
def per_token_quant_int8_triton_opt(
    x: torch.Tensor,
    tokens_per_expert: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 3D activation tensor to int8 with one scale per token row.

    Args:
        x: Input of shape ``[E, T, H]``. Non-contiguous inputs are copied
            to a contiguous layout before the kernel launch.
        tokens_per_expert: Optional per-expert counts, indexed by the first
            dimension of ``x``. When given, rows ``t >= tokens_per_expert[e]``
            are skipped by the kernels.

    Returns:
        ``(x_q, scales)`` where ``x_q`` is int8 with the shape of ``x`` and
        ``scales`` is float32 with shape ``[E, T, 1]``. Both are allocated
        uninitialized, so entries for skipped rows hold arbitrary values.

    Raises:
        ValueError: If ``x`` is not 3-dimensional.
    """
    if x.dim() != 3:
        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")

    # Both kernels address row r at `r * stride(-2)` and read its columns
    # with unit stride. That flat addressing is only valid for a contiguous
    # [E, T, H] layout, so normalize the layout before launching.
    if not x.is_contiguous():
        x = x.contiguous()

    E, T, H = x.shape
    N = H
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    scales = torch.empty(x.shape[:-1] + (1, ),
                         device=x.device,
                         dtype=torch.float32)

    BLOCK = triton.next_power_of_2(N)
    # Launch heuristics carried over unchanged; the rationale for the
    # specific thresholds is not visible here (presumably tuned empirically).
    num_warps = min(max(BLOCK // 256, 1), 8)
    if T >= 4096:
        num_warps = 1
    num_tokens = E * T

    # Arguments common to both kernel variants.
    launch_kwargs = dict(
        stride_x=x.stride(-2),
        stride_xq=x_q.stride(-2),
        N=N,
        T_dim=T,
        has_tokens_per_expert=tokens_per_expert is not None,
        tokens_per_expert_ptr=tokens_per_expert,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )

    if E == 16 and T >= 1024:
        # Fewer programs than rows: the looping kernel walks several rows
        # per program.
        grid_opt = max(1, num_tokens // (T // 256))
        _per_token_quant_int8_kernel_opt[(grid_opt, )](
            x,
            x_q,
            scales,
            E_dim=E,
            **launch_kwargs,
        )
    else:
        # One program per row.
        _per_token_quant_int8_one_kernel_opt[(num_tokens, )](
            x,
            x_q,
            scales,
            **launch_kwargs,
        )

    return x_q, scales
def _int8_quantize( def _int8_quantize(
A: torch.Tensor, A: torch.Tensor,
A_scale: torch.Tensor | None, A_scale: torch.Tensor | None,
per_act_token: bool, per_act_token: bool,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
expert_num_tokens: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Perform int8 quantization on the inputs. If a block_shape Perform int8 quantization on the inputs. If a block_shape
...@@ -171,9 +309,12 @@ def _int8_quantize( ...@@ -171,9 +309,12 @@ def _int8_quantize(
# If weights are per-channel (per_channel_quant=True), then # If weights are per-channel (per_channel_quant=True), then
# activations apply per-token quantization. Otherwise, assume # activations apply per-token quantization. Otherwise, assume
# activation tensor-wise fp8/int8 quantization, dynamic or static # activation tensor-wise fp8/int8 quantization, dynamic or static
if block_shape is None: if block_shape is None or per_act_token:
assert per_act_token, "int8 quantization only supports block or channel-wise" assert per_act_token, "int8 quantization only supports block or channel-wise"
A, A_scale = per_token_quant_int8(A) if expert_num_tokens is None:
A, A_scale = per_token_quant_int8(A)
else:
A, A_scale = per_token_quant_int8_triton_opt(A, expert_num_tokens)
else: else:
assert not per_act_token assert not per_act_token
assert len(block_shape) == 2 assert len(block_shape) == 2
......
...@@ -42,6 +42,7 @@ QuantizationMethods = Literal[ ...@@ -42,6 +42,7 @@ QuantizationMethods = Literal[
"blockwise_int8", "blockwise_int8",
"slimquant_w4a8", "slimquant_w4a8",
"slimquant_w4a8_marlin", "slimquant_w4a8_marlin",
"slimquant_marlin",
"slimquant_compressed_tensors_marlin", "slimquant_compressed_tensors_marlin",
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -192,6 +193,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -192,6 +193,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"slimquant_w4a8":SlimQuantW4A8Int8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig, "slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
"slimquant_marlin":SlimQuantCompressedTensorsMarlinConfig,
"slimquant_compressed_tensors_marlin":SlimQuantCompressedTensorsMarlinConfig, "slimquant_compressed_tensors_marlin":SlimQuantCompressedTensorsMarlinConfig,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
......
...@@ -782,15 +782,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -782,15 +782,13 @@ class AWQMarlinMoEMethod(FusedMoEMethodBase):
router: FusedMoERouter, router: FusedMoERouter,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert layer.activation == "silu", "Only SiLU activation is supported." assert layer.activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_fused_gate=use_fused_gate,
) )
return fused_marlin_moe( return fused_marlin_moe(
......
...@@ -471,7 +471,6 @@ class BlockInt8MoEMethod: ...@@ -471,7 +471,6 @@ class BlockInt8MoEMethod:
enable_eplb: bool = False, enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -491,7 +490,6 @@ class BlockInt8MoEMethod: ...@@ -491,7 +490,6 @@ class BlockInt8MoEMethod:
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
) )
# Expert fusion with INT8 quantization # Expert fusion with INT8 quantization
......
...@@ -979,12 +979,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -979,12 +979,10 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
use_nn_moe: bool | None = False, use_nn_moe: bool | None = False,
use_fused_gate: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts( topk_weights, topk_ids = router.select_experts(
hidden_states=x, hidden_states=x,
router_logits=router_logits, router_logits=router_logits,
use_fused_gate=use_fused_gate,
) )
assert self.kernel is not None assert self.kernel is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment