Commit f687d53c authored by 王敏's avatar 王敏
Browse files

[fix]解决moe_fused_gate编译错误,去掉mla中mtp部分的修改

parent e0ba5f60
......@@ -852,8 +852,8 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/moe_align_sum_kernels.cu"
"csrc/moe/topk_softmax_kernels.cu")
# "csrc/moe/moe_fused_gate.cu"
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/moe_fused_gate.cu")
if(VLLM_GPU_LANG STREQUAL "CUDA")
......
......@@ -69,7 +69,7 @@ __device__ inline bool cmp_eq(const T& a, const T& b) {
}
// Fixed constants common to both dynamic and static template versions:
//static constexpr int WARP_SIZE = 32;
static constexpr int SIZE_WARP = 32;
static constexpr int WARPS_PER_CTA = 6;
static constexpr int MAX_VPT = 32; // maximum VPT we support, > params.VPT = num_expert / num_expert_group
......@@ -331,10 +331,10 @@ __global__ void moe_fused_gate_kernel(
// Macro to compute compile-time constants and launch the kernel.
#define LAUNCH_MOE_GATE_CONFIG(T, EXPERTS, EXPERT_GROUP) \
do { \
constexpr int VPT = (EXPERTS) / (EXPERT_GROUP); \
int VPT = (EXPERTS) / (EXPERT_GROUP); \
/* If EXPERT_GROUP > WARP_SIZE, fall back to 1 row per warp */ \
constexpr int ROWS_PER_WARP = ((EXPERT_GROUP) <= WARP_SIZE) ? (WARP_SIZE / (EXPERT_GROUP)) : 1; \
constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
int ROWS_PER_WARP = ((EXPERT_GROUP) <= SIZE_WARP) ? (SIZE_WARP / (EXPERT_GROUP)) : 1; \
int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; \
moe_fused_gate_kernel<T, VPT, (EXPERTS), (EXPERT_GROUP), ROWS_PER_WARP, ROWS_PER_CTA, WARPS_PER_CTA> \
<<<num_blocks, block_dim, 0, stream>>>( \
input.data_ptr(), \
......@@ -379,7 +379,7 @@ __global__ void moe_fused_gate_kernel_dynamic(
params.VPT = num_experts / num_expert_group; // e.g., for deepseek v3, this is 256 / 8 = 32
params.THREADS_PER_ROW = num_expert_group; // fixed as num_expert_group, e.g., for deepseek v3, this is 8
params.WARPS_PER_CTA = WARPS_PER_CTA; // fixed as 6
params.ROWS_PER_WARP = std::max<int64_t>(1, WARP_SIZE / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_WARP = std::max<int64_t>(1, SIZE_WARP / num_expert_group); // WARP_SIZE is fixed as 32
params.ROWS_PER_CTA = params.WARPS_PER_CTA * params.ROWS_PER_WARP;
moe_fused_gate_impl<T>(
......@@ -413,11 +413,11 @@ std::vector<at::Tensor> moe_fused_gate(
auto indices = torch::empty({num_rows, topk}, options.dtype(torch::kInt32));
// Compute grid dimensions based on runtime value for num_expert_group.
int64_t rows_per_warp = std::max<int64_t>(1, WARP_SIZE / num_expert_group);
int64_t rows_per_warp = std::max<int64_t>(1, SIZE_WARP / num_expert_group);
int64_t num_warps = (num_rows + rows_per_warp - 1) / rows_per_warp;
int64_t num_blocks = (num_warps + WARPS_PER_CTA - 1) / WARPS_PER_CTA;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
dim3 block_dim(WARP_SIZE, WARPS_PER_CTA);
dim3 block_dim(SIZE_WARP, WARPS_PER_CTA);
// Check 1: Ensure that num_experts is a power of 2.
TORCH_CHECK((num_experts & (num_experts - 1)) == 0, "num_experts must be a power of 2, but got ", num_experts);
......
......@@ -36,11 +36,11 @@ void shuffle_rows(const torch::Tensor& input_tensor,
const torch::Tensor& dst2src_map,
torch::Tensor& output_tensor);
// std::vector<torch::Tensor> moe_fused_gate(
// torch::Tensor& input,
// torch::Tensor& bias,
// int64_t num_expert_group,
// int64_t topk_group,
// int64_t topk,
// int64_t n_share_experts_fusion,
// double routed_scaling_factor);
std::vector<torch::Tensor> moe_fused_gate(
torch::Tensor& input,
torch::Tensor& bias,
int64_t num_expert_group,
int64_t topk_group,
int64_t topk,
int64_t n_share_experts_fusion,
double routed_scaling_factor);
......@@ -25,11 +25,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// m.def(
// "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
// "n_share_experts_fusion, float routed_scaling_factor) -> "
// "(Tensor[])");
// m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
"moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk, int "
"n_share_experts_fusion, float routed_scaling_factor) -> "
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
#ifndef USE_ROCM
......
......@@ -2329,51 +2329,51 @@ def flash_mla_with_kvcache(
# def moe_fused_gate(
# input_tensor,
# bias,
# num_expert_group,
# topk_group,
# topk,
# n_share_experts_fusion=0,
# routed_scaling_factor=0,
# ):
# # This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# # it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# # as the group weight to select exerpt groups and then select topk experts within the selected groups
# # the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# # and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# # for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# # n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# # routed_scaling_factor: if > 0, the last expert will be scaled by this factor
# return torch.ops._moe_C.moe_fused_gate(
# input_tensor,
# bias,
# num_expert_group,
# topk_group,
# topk,
# n_share_experts_fusion,
# routed_scaling_factor,
# )
# if hasattr(torch.ops._moe_C, "moe_fused_gate"):
# @register_fake("_moe_C::moe_fused_gate")
# def moe_fused_gate_fake(
# input_tensor: torch.Tensor,
# bias: torch.Tensor,
# num_expert_group: int,
# topk_group: int,
# topk: int,
# n_share_experts_fusion: int,
# routed_scaling_factor: int,
# ):
# return torch.empty((input_tensor.size(0), topk),
# dtype=input_tensor.dtype,
# device=input_tensor.device), \
# torch.empty((input_tensor.size(0), topk),
# dtype=input_tensor.dtype,
# device=input_tensor.device)
def moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion=0,
routed_scaling_factor=0,
):
# This fused kernel function is used to select topk expert in a hierarchical 2-layer fashion
# it split group of expert into num_expert_group, and use top2 expert weight sum in each group
# as the group weight to select exerpt groups and then select topk experts within the selected groups
# the #experts is decided by the input tensor shape and we currently only support power of 2 #experts
# and #experts should be divisible by num_expert_group. #expert/num_expert_group <= 32 is limitted for now.
# for non-supported case, we suggestion to use the biased_grouped_topk func in sglang.srt.layers.moe.topk
# n_share_experts_fusion: if > 0, the last expert will be replaced with a round-robin shared expert
# routed_scaling_factor: if > 0, the last expert will be scaled by this factor
return torch.ops._moe_C.moe_fused_gate(
input_tensor,
bias,
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
routed_scaling_factor,
)
if hasattr(torch.ops._moe_C, "moe_fused_gate"):
@register_fake("_moe_C::moe_fused_gate")
def moe_fused_gate_fake(
input_tensor: torch.Tensor,
bias: torch.Tensor,
num_expert_group: int,
topk_group: int,
topk: int,
n_share_experts_fusion: int,
routed_scaling_factor: int,
):
return torch.empty((input_tensor.size(0), topk),
dtype=input_tensor.dtype,
device=input_tensor.device), \
torch.empty((input_tensor.size(0), topk),
dtype=input_tensor.dtype,
device=input_tensor.device)
def sm100_cutlass_mla_decode(out: torch.Tensor, lse: torch.Tensor,
......
......@@ -51,10 +51,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
......@@ -65,8 +61,7 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
config.hidden_size,
bias=False)
# self.is_v32 = hasattr(config, "index_topk")
self.is_v32 = False
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
......@@ -88,8 +83,6 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_index: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
assert inputs_embeds is not None
# masking inputs at position 0, as not needed by MTP
inputs_embeds[positions == 0] = 0
......@@ -121,7 +114,10 @@ class DeepSeekMultiTokenPredictor(nn.Module):
for idx in range(self.mtp_start_layer_idx,
self.mtp_start_layer_idx + self.num_mtp_layers)
})
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.logits_processor = LogitsProcessor(config.vocab_size)
def forward(
......@@ -132,6 +128,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
inputs_embeds: Optional[torch.Tensor] = None,
spec_step_idx: int = 0,
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = (spec_step_idx % self.num_mtp_layers)
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
......@@ -325,13 +323,20 @@ class DeepSeekMTP(nn.Module, SupportsPP):
spec_layer_weight_names = [
"embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
]
shared_weight_names = ["embed_tokens"]
spec_layer_weight = False
shared_weight = False
for weight_name in spec_layer_weight_names:
if weight_name in name:
spec_layer_weight = True
if weight_name in shared_weight_names:
shared_weight = True
break
if not spec_layer_weight:
# treat rest weights as weights for transformer layer block
name = name.replace(f"model.layers.{spec_layer}.",
f"model.layers.{spec_layer}.mtp_block.")
elif shared_weight:
# treat shared weights as top level weights
name = name.replace(f"model.layers.{spec_layer}.", "model.")
return name
......@@ -558,14 +558,6 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
device=device,
)
self.block_table = block_table
self.use_spec_decode = False
# support for cudagraph spec docoding
self.spec_decode_block_table_tensor = None
self.spec_decode_seq_lens = None
def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
qo_indptr = prefill.query_start_loc
......@@ -659,31 +651,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
# assert m.num_reqs <= (m.num_actual_tokens *
# self.reorder_batch_threshold), \
# "MLA only supports decode-only full CUDAGraph capture. " \
# "Make sure all cudagraph capture sizes <= max_num_seq."
# assert m.max_query_len <= self.reorder_batch_threshold # decode only
self.use_spec_decode = m.num_speculative_tokens > 0
# support for cudagraph spec docoding
if self.use_spec_decode:
for i in range(m.num_reqs):
self.num_scheduled_tokens_np[i] = m.num_actual_tokens // m.num_reqs
if self.spec_decode_block_table_tensor is None:
max_num_reqs = m.seq_lens.shape[0]
block_table_tensor = self.block_table.get_device_tensor()
tokens_per_seq = 1+m.num_speculative_tokens
self.spec_decode_block_table_tensor = torch.zeros((block_table_tensor.shape[0]*tokens_per_seq,
block_table_tensor.shape[1]),
dtype=block_table_tensor.dtype,
device=m.seq_lens.device)
self.spec_decode_seq_lens = torch.zeros(max_num_reqs * tokens_per_seq,
dtype=m.seq_lens.dtype,
device=m.seq_lens.device)
assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len <= self.reorder_batch_threshold # decode only
return self.build(0, m)
def build(self,
......@@ -699,15 +673,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
block_table = self.block_table
block_table_tensor = common_attn_metadata.block_table_tensor
slot_mapping = common_attn_metadata.slot_mapping
if slot_mapping is None:
block_table.slot_mapping[:num_tokens].copy_(
block_table.slot_mapping_cpu[:num_tokens],
non_blocking=True)
block_table.slot_mapping[num_tokens:].fill_(-1)
slot_mapping = block_table.slot_mapping[:num_tokens]
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
......@@ -873,57 +840,8 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
- prefill_query_start_loc[:-1]
prefill_metadata.cudnn_workspace = self.cudnn_workspace
# TODO @ wangming
decode_metadata = None
# if num_decodes > 0:
# if self.use_spec_decode and not common_attn_metadata.spec_layer_decoding:
# query_lens = self.num_scheduled_tokens_np[:num_decodes]
# cu_num_blocks = np.cumsum(query_lens)
# virtual_batches = cu_num_blocks[-1]
# block_offsets = np.repeat(cu_num_blocks - query_lens, query_lens)
# arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets
# rarange = np.repeat(query_lens, query_lens) - arange - 1
# repeats = torch.from_numpy(query_lens).pin_memory().to(
# block_table_tensor.device, non_blocking=True).contiguous()
# decode_block_table_tensor = torch.repeat_interleave(
# block_table_tensor[:self._num_decodes, ...],
# repeats, dim=0).contiguous()
# decode_seq_lens = torch.repeat_interleave(seq_lens[:self._num_decodes], repeats, dim=0).contiguous()
# seq_lens_minus = torch.from_numpy(rarange).to(torch.int32).pin_memory().to(
# seq_lens.device, non_blocking=True).contiguous()
# decode_seq_lens = decode_seq_lens - seq_lens_minus
# if self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(decode_block_table_tensor)
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(decode_seq_lens)
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=decode_block_table_tensor,
# seq_lens=decode_seq_lens,
# )
# else:
# self._num_decode_tokens = num_decodes
# if self.use_spec_decode and self.spec_decode_block_table_tensor is not None:
# self.spec_decode_block_table_tensor[:self._num_decode_tokens].copy_(block_table_tensor[:self._num_decode_tokens, ...])
# self.spec_decode_seq_lens[:self._num_decode_tokens].copy_(seq_lens[:self._num_decode_tokens])
# decode_metadata = self._build_decode(
# block_table_tensor=self.spec_decode_block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=self.spec_decode_seq_lens[:self._num_decode_tokens],
# )
# else:
# decode_metadata = self._build_decode(
# block_table_tensor=block_table_tensor[:self._num_decode_tokens, ...],
# seq_lens=seq_lens[:self._num_decode_tokens],
# )
if num_decodes > 0:
decode_metadata = self._build_decode(
block_table_tensor=block_table_tensor[:num_decodes, ...],
seq_lens_cpu=seq_lens_cpu[:num_decodes],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment