Commit 7d4db7e8 authored by yangql
Browse files

Merge branch 'v0.9.2-dev-ds' into v0.9.2-dev-ds-yql_auto

parents 8943d3db d5b6456a
......@@ -4783,6 +4783,7 @@ class VllmConfig:
mtp_batch_size_capture_list = list(map(lambda x: x * (1 + self.speculative_config.num_lookahead_slots),
batch_size_capture_list))
batch_size_capture_list = sorted(set(batch_size_capture_list + mtp_batch_size_capture_list))
batch_size_capture_list = [i for i in batch_size_capture_list if i == 1 or i % (1 + self.speculative_config.num_lookahead_slots) == 0]
self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
......
......@@ -140,7 +140,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
# This is the DeepEP default. Stick to it till we can establish
# reasonable defaults based on profiling.
self.num_sms = 24#20
self.num_sms = 30
def get_handle(self, kwargs):
raise NotImplementedError
......
......@@ -298,15 +298,6 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) -> Callable | None:
assert self.handle is not None
# fused_expert_output can have 0 tokens - This happens when none of the
# tokens from the all2all reach this EP rank.
# if fused_expert_output.numel() != 0 and apply_weights_and_reduce:
# fused_expert_output = self._apply_weights_and_reduce(
# num_tokens=topk_ids.size(0),
# fused_expert_output=fused_expert_output,
# topk_weights=topk_weights,
# apply_router_weight_on_input=apply_router_weight_on_input,
# output_dtype=output.dtype)
combined_x, _, event = self.buffer.combine(
# HT combine only supports BF16
x=fused_expert_output,
......
......@@ -74,18 +74,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _do_quant(
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Optional[torch.dtype],
per_act_token_quant: bool,
block_shape: Optional[list[int]],
expert_num_tokens: Optional[torch.Tensor] = None,
quant_config: FusedMoEQuantConfig,
expert_num_tokens: Optional[torch.Tensor]= None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
block_k = block_shape[1] if block_shape is not None else None
if self.use_fp8_dispatch:
block_k = (
quant_config.block_shape[1]
if quant_config.block_shape is not None
else None
)
if block_k == DEEPEP_QUANT_BLOCK_SIZE:
# DeepEP kernels did the quantization for us.
x, x_scales = x
......@@ -102,14 +103,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# TODO (varun): Optimization - Use a batched version of quant
if expert_num_tokens is None:
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(x, a1_scale, quant_dtype,
per_act_token_quant,
block_shape, expert_num_tokens)
if expert_num_tokens is None:
x, x_scales = moe_kernel_quantize_input(
x,
a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
expert_num_tokens
)
x = x.view((num_experts, -1, hidden_dim))
if quant_dtype is not None:
if quant_config.quant_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
......@@ -151,7 +155,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch
expert_x, expert_num_tokens, self.handles, _, hook = self.buffer.low_latency_dispatch(
expert_x, expert_num_tokens, self.handle, _, hook = self.buffer.low_latency_dispatch(
a1,
topk_ids,
self.max_tokens_per_rank,
......@@ -181,7 +185,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1_dtype: torch.dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
expert_x, expert_x_scale = self._do_quant(expert_x, a1_dtype, quant_config)
expert_x, expert_x_scale = self._do_quant(expert_x, a1_scale, a1_dtype, quant_config, expert_num_tokens)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None
......
......@@ -12,7 +12,7 @@ import torch
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.utils import cdiv, async_tensor_h2d
from vllm.utils import cdiv
#
# This file defines a set of base classes used to make MoE kernels more modular.
......@@ -112,9 +112,6 @@ class ExpertTokensMetadata:
def make_from_list(
expert_num_tokens_list: list[int], device: str
) -> "ExpertTokensMetadata":
# expert_num_tokens_cpu = torch.tensor(
# expert_num_tokens_list, device="cpu", dtype=torch.int32
# )
expert_num_tokens_cpu = torch.tensor(
expert_num_tokens_list, device="cpu", dtype=torch.int32, pin_memory=True
)
......@@ -813,21 +810,6 @@ class FusedMoEModularKernel(torch.nn.Module):
return output
_aux_stream: torch.cuda.Stream | None = None
def aux_stream() -> torch.cuda.Stream | None:
"""
Ensures aux_stream is initialized only once
"""
global _aux_stream
# TODO: validate this works properly on ROCm platform.
if _aux_stream is None:
_aux_stream = torch.cuda.Stream()
return _aux_stream
@final
class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
"""
......@@ -853,10 +835,6 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self.fused_experts = fused_experts
self.shared_experts = shared_experts
if self.shared_experts is not None:
self.shared_experts_stream = aux_stream()
self.shared_experts_overlap_event = torch.cuda.Event()
# assert prepare_finalize.activation_format == \
# fused_experts.activation_formats[0], (
# f"{prepare_finalize.__class__.__name__}."
......@@ -933,18 +911,6 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
# (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids,
# _expert_topk_weights) = self.prepare_finalize.prepare(
# a1,
# a1_scale,
# a2_scale,
# topk_weights,
# topk_ids,
# global_num_experts,
# expert_map,
# apply_router_weight_on_input,
# self.fused_experts.quant_config,
# )
prepare_ret = self.prepare_finalize.prepare_async(
a1,
a1_scale,
......
......@@ -812,14 +812,8 @@ def deepgemm_moe_permute(
aq_scale_out = torch.empty(
(M_sum, aq_scale.shape[-1]), device=device, dtype=torch.float32
#(M_sum, H // block_k), device=device, dtype=torch.float32
)
# maybe_has_empty_blocks = expert_num_tokens_cpu is None
# expert_ids_init = torch.zeros# if maybe_has_empty_blocks else torch.empty
# expert_ids = expert_ids_init((M_sum), device=device, dtype=torch.int32)
expert_ids = torch.full(
(M_sum,), -1, dtype=torch.int32, device=device
)
......
......@@ -93,8 +93,7 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_auto")
#self.use_deepep_ll = self.use_deepep and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency"
self.use_deepgemm = False
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
......@@ -334,6 +333,12 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
mm1_out = _resize_cache(workspace13, (M_sum, N))
mm2_out = _resize_cache(workspace2, (M_sum, K))
# act_out = _resize_cache(workspace2.view(dtype=torch.int8), (M_sum, N // 2))
# act_out = _resize_cache(
# workspace13.view(dtype=torch.int8), (M_sum, N // 2)
# )
fused_out = _resize_cache(workspace13, fused_out_shape)
a1q_perm = _resize_cache(workspace2.view(dtype=a1q.dtype), (M_sum, K))
......@@ -351,18 +356,12 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
M_sum=M_sum
)
# if expert_map is not None:
# # DeepGemm (Grouped Contiguous) kernel needs a valid B index
# # for all rows of A. To that effect, simply compute with
# # the 0th weight matrix.
# # Note that this relies on the fact that corresponding topk
# # weights would be 0 during weight multiplication.
# expert_ids = torch.where(expert_ids == -1, 0, expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm(
(a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)
#a2q, a2q_scale = fuse_silu_mul_quant(mm1_out, expert_ids=expert_ids)
a2q, a2q_scale = fuse_silu_mul_quant(mm1_out)
#a2q, a2q_scale = fuse_silu_mul_quant(input=mm1_out, output=act_out, expert_ids=expert_ids)
m_grouped_w8a8_gemm_nt_contig_asm(
(a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment