import logging

import torch
import triton

from sglang.srt.utils import ceil_div, is_cuda

logger = logging.getLogger(__name__)

_is_cuda = is_cuda()
if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
    )

import triton.language as tl


@triton.jit
def deepep_permute_triton_kernel(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk

    src_ptr = input_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)

        for idx in range(topk):
            dst_idx = tl.load(src2dst_ptr + idx)
            if dst_idx >= 0:
                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
                tl.store(dst_ptr + offset, in_data, mask=mask)


@triton.jit
def deepep_post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx = tl.program_id(0)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    store_ptr = output_ptr + src_idx * hidden_size
    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + tl.arange(0, BLOCK_SIZE)
        mask = offset < hidden_size
        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            dst_idx = tl.load(src2dst_ptr + idx)
            if dst_idx >= 0:
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)


@triton.jit
def compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def deepep_compute_src2dst_triton_kernel(
    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    num_invalid = tl.load(num_minus_one)
    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)


def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.empty(
        num_experts + 1, device=topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)

    # Find offset
    expert_ids = torch.arange(
        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
    )
    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
    num_minus_one = seg_indptr[0]
    seg_indptr = seg_indptr - num_minus_one

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    deepep_compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
    )
    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
    return reorder_topk_ids, src2dst, seg_indptr
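
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a pure-PyTorch mirror
# of what deepep_run_moe_deep_preprocess computes, useful for reasoning about
# the src2dst / seg_indptr layout without launching the Triton kernel. It
# assumes dropped slots in topk_ids are encoded as -1, as in the DeepEP path.
# ---------------------------------------------------------------------------
def _reference_deepep_preprocess_sketch(topk_ids: torch.Tensor, num_experts: int):
    flat = topk_ids.view(-1)
    reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)
    expert_ids = torch.arange(
        num_experts + 1, device=flat.device, dtype=reorder_topk_ids.dtype
    )
    # seg_indptr[e] = number of (token, k) slots routed to experts < e,
    # after discarding the leading -1 entries.
    seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)
    num_minus_one = seg_indptr[0].clone()
    seg_indptr = seg_indptr - num_minus_one
    # src2dst[s] = position of source slot s in the expert-sorted buffer.
    src2dst = torch.empty_like(flat)
    src2dst[reorder_ids] = (
        torch.arange(flat.numel(), device=flat.device) - num_minus_one
    )
    return reorder_topk_ids[num_minus_one:], src2dst, seg_indptr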
@triton.jit
def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
    expert_id_minus_1 = tl.program_id(0) - 1
    low = 0
    high = num_toks - 1
    target_location = -1
    while low <= high:
        mid = (low + high) // 2

        if tl.load(reorder_topk_ids + mid) > expert_id_minus_1:
            high = mid - 1
        else:
            low = mid + 1
            target_location = mid
    tl.store(seg_indptr + expert_id_minus_1 + 1, target_location + 1)


def run_moe_ep_preproess(topk_ids: torch.Tensor, num_local_experts: int):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)

    seg_indptr = torch.zeros(
        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)

    compute_seg_indptr_triton_kernel[(num_local_experts,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    BLOCK_SIZE = 512
    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
    compute_src2dst_triton_kernel[grid](
        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
    )
    return reorder_topk_ids, src2dst, seg_indptr


@triton.jit
def pre_reorder_triton_kernel_for_cutlass_moe(
    input_ptr,
    gateup_input_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    a1_scales_ptr,
    num_local_experts,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    OutDtype = gateup_input_ptr.dtype.element_ty

    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size

    vec = tl.arange(0, BLOCK_SIZE)

    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id != num_local_experts:
            if a1_scales_ptr is not None:
                scale = 1.0 / tl.load(a1_scales_ptr)
            else:
                scale = 1.0

            dst_idx_int32 = tl.load(src2dst_ptr + idx)
            dst_idx = dst_idx_int32.to(tl.int64)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
                out_data = (in_data * scale).to(OutDtype)
                tl.store(dst_ptr + offset, out_data, mask=mask)
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
    input_ptr,
    stride_input_0,
    stride_input_1,
    stride_input_2,
    output_ptr,
    stride_output_0,
    stride_output_1,
    stride_output_2,
    output_scale_ptr,
    stride_output_scale_0,
    stride_output_scale_1,
    stride_output_scale_2,
    masked_m_ptr,
    size_n,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
    SCALE_UE8M0: tl.constexpr,
):
    expert_id = tl.program_id(2)
    token_id = tl.program_id(1)
    hidden_dim_block_index = tl.program_id(0)

    block_num_per_expert = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)

    stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
    stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
    stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
    stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)

    offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
    input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
    output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
    output_scale_offs = (
        output_scale_ptr
        + expert_id * stride_output_scale_0
        + hidden_dim_block_index * stride_output_scale_2
    )

    for token_index in tl.range(
        token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE
    ):
        gate = tl.load(
            input_ptr_offs + token_index * stride_input_1,
            mask=offs_in_d < size_n,
            other=0.0,
        ).to(tl.float32)
        up = tl.load(
            input_ptr_offs + token_index * stride_input_1 + size_n,
            mask=offs_in_d < size_n,
            other=0.0,
        )
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)
        gate_up = up * gate
        _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
        output_s = _absmax / fp8_max
        if SCALE_UE8M0:
            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
        output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
            output_ptr.dtype.element_ty
        )
        tl.store(
            output_ptr_offs + token_index * stride_output_1,
            output_q,
            mask=offs_in_d < size_n,
        )
        tl.store(
            output_scale_offs + token_index * stride_output_scale_1,
            output_s,
        )


def silu_and_mul_masked_post_quant_fwd(
    input: torch.Tensor,
    output: torch.Tensor,
    output_scale: torch.Tensor,
    quant_group_size: int,
    masked_m: torch.Tensor,
    scale_ue8m0: bool = False,
):
    """
    input shape [expert_num, token_num_padded, hidden_dim]
    output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8
    output_scale shape [expert_num, token_num_padded, hidden_dim // 2 // 128], dtype float32
    quant_group_size int
    masked_m shape [expert_num]
    """
    assert input.is_contiguous()
    assert output.dtype == torch.float8_e4m3fn
    assert output.is_contiguous()
    assert len(input.shape) == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0

    size_n = input.shape[-1] // 2
    assert size_n % quant_group_size == 0

    expert_num = len(masked_m)

    if expert_num < 4:
        BLOCK_NUM_PER_EXPERT = 64
    else:
        BLOCK_NUM_PER_EXPERT = 32

    BLOCK_N = quant_group_size
    num_warps = 1
    NUM_STAGES = 6
    hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N)
    assert BLOCK_N % quant_group_size == 0

    grid = (
        hidden_dim_split_block_num,
        BLOCK_NUM_PER_EXPERT,
        expert_num,
    )

    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    _silu_and_mul_post_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        output_scale,
        *output_scale.stride(),
        masked_m,
        size_n,
        fp8_max,
        fp8_min,
        BLOCK_N=BLOCK_N,
        NUM_STAGE=NUM_STAGES,
        num_warps=num_warps,
        SCALE_UE8M0=scale_ue8m0,
    )
    return


@triton.jit
def post_reorder_triton_kernel_for_cutlass_moe(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    num_local_experts,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    store_ptr = output_ptr + src_idx * hidden_size

    vec = tl.arange(0, BLOCK_SIZE)

    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + vec
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id != num_local_experts:
                dst_idx_int32 = tl.load(src2dst_ptr + idx)
                dst_idx = dst_idx_int32.to(tl.int64)
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)
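
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module), assuming a CUDA
# device and the shape contract documented in silu_and_mul_masked_post_quant_fwd
# above: input [E, T_pad, H] in bf16, output [E, T_pad, H // 2] in fp8, one
# scale per 128-wide group. The tensor sizes below are arbitrary example values.
# ---------------------------------------------------------------------------
def _example_silu_and_mul_masked_post_quant(e=2, t_pad=128, hidden=512):
    group = 128
    x = torch.randn(e, t_pad, hidden, device="cuda", dtype=torch.bfloat16)
    out = torch.empty(
        e, t_pad, hidden // 2, device="cuda", dtype=torch.float8_e4m3fn
    )
    out_scale = torch.empty(
        e, t_pad, hidden // 2 // group, device="cuda", dtype=torch.float32
    )
    # Only the first masked_m[e] rows of expert e are processed.
    masked_m = torch.full((e,), t_pad, device="cuda", dtype=torch.int32)
    silu_and_mul_masked_post_quant_fwd(x, out, out_scale, group, masked_m)
    return out, out_scale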
@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
    output_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk_weights_ptr,
    topk,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
):
    InDtype = down_output_ptr.dtype.element_ty

    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    topk_weights_ptr = topk_weights_ptr + src_idx * topk

    store_ptr = output_ptr + src_idx * hidden_size

    vec = tl.arange(0, BLOCK_SIZE)

    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
        offset = start_offset + vec
        mask = offset < hidden_size

        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
        for idx in range(topk):
            expert_id = tl.load(topk_ids_ptr + idx)
            if expert_id > 0:
                dst_idx_int32 = tl.load(src2dst_ptr + idx)
                dst_idx = dst_idx_int32.to(tl.int64)
                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
                load_ptr = down_output_ptr + dst_idx * hidden_size
                in_data = tl.load(load_ptr + offset, mask=mask)
                sum_vec += in_data * weigh_scale
        tl.store(store_ptr + offset, sum_vec, mask=mask)


@triton.jit
def _fwd_kernel_ep_scatter_1(
    num_recv_tokens_per_expert,
    expert_start_loc,
    m_indices,
    num_experts: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_EXPERT_NUM: tl.constexpr,
):
    cur_expert = tl.program_id(0)

    offset_cumsum = tl.arange(0, BLOCK_EXPERT_NUM)
    tokens_per_expert = tl.load(
        num_recv_tokens_per_expert + offset_cumsum,
        mask=offset_cumsum < num_experts,
        other=0,
    )
    cumsum = tl.cumsum(tokens_per_expert) - tokens_per_expert
    tl.store(
        expert_start_loc + offset_cumsum, cumsum, mask=offset_cumsum < num_experts
    )

    cur_expert_start = tl.load(expert_start_loc + cur_expert)
    cur_expert_token_num = tl.load(num_recv_tokens_per_expert + cur_expert)

    m_indices_start_ptr = m_indices + cur_expert_start
    off_expert = tl.arange(0, BLOCK_E)

    for start_m in tl.range(0, cur_expert_token_num, BLOCK_E, num_stages=4):
        tl.store(
            m_indices_start_ptr + start_m + off_expert,
            cur_expert,
        )


@triton.jit
def _fwd_kernel_ep_scatter_2(
    total_token_num,
    expert_start_loc,
    recv_x,
    recv_x_stride0,
    recv_x_stride1,
    recv_x_scale,
    recv_x_scale_stride0,
    recv_x_scale_stride1,
    recv_topk,
    recv_topk_stride0,
    recv_topk_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    output_tensor_scale,
    output_tensor_scale_stride0,
    output_tensor_scale_stride1,
    output_index,
    output_index_stride0,
    output_index_stride1,
    topk_num: tl.constexpr,
    HIDDEN_SIZE: tl.constexpr,
    HIDDEN_SIZE_PAD: tl.constexpr,
    SCALE_HIDDEN_SIZE: tl.constexpr,
    SCALE_HIDDEN_SIZE_PAD: tl.constexpr,
):
    start_token_id = tl.program_id(0)
    grid_num = tl.num_programs(0)

    offset_in = tl.arange(0, HIDDEN_SIZE_PAD)
    mask = offset_in < HIDDEN_SIZE

    index_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
    mask_s = index_in_s < SCALE_HIDDEN_SIZE

    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
        token_id = token_id_int32.to(tl.int64)
        to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
        to_copy_s = tl.load(
            recv_x_scale
            + token_id * recv_x_scale_stride0
            + index_in_s * recv_x_scale_stride1,
            mask=mask_s,
        )

        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
            topk_index = topk_idx_int32.to(tl.int64)
            expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
            if expert_id >= 0:
                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
                dest_token_index = dest_token_index_int32.to(tl.int64)

                tl.store(
                    output_index + token_id * output_index_stride0 + topk_index,
                    dest_token_index_int32,
                )
                output_tensor_ptr = (
                    output_tensor + dest_token_index * output_tensor_stride0
                )
                output_tensor_scale_ptr = (
                    output_tensor_scale + dest_token_index * output_tensor_scale_stride0
                )
                tl.store(output_tensor_ptr + offset_in, to_copy, mask=mask)
                tl.store(
                    output_tensor_scale_ptr + index_in_s * output_tensor_scale_stride1,
                    to_copy_s,
                    mask=mask_s,
                )
# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/deepep_scatter_gather.py
@torch.no_grad()
def ep_scatter(
    recv_x: torch.Tensor,
    recv_x_scale: torch.Tensor,
    recv_topk: torch.Tensor,
    num_recv_tokens_per_expert: torch.Tensor,
    expert_start_loc: torch.Tensor,
    output_tensor: torch.Tensor,
    output_tensor_scale: torch.Tensor,
    m_indices: torch.Tensor,
    output_index: torch.Tensor,
    scale_ue8m0: bool = False,
):
    BLOCK_E = 128  # token num of per expert is aligned to 128
    BLOCK_D = 128  # block size of quantization
    num_warps = 8
    num_experts = num_recv_tokens_per_expert.shape[0]
    hidden_size = recv_x.shape[1]
    # grid = (triton.cdiv(hidden_size, BLOCK_D), num_experts)
    grid = num_experts

    scale_hidden_size = hidden_size // BLOCK_D
    if scale_ue8m0:
        # ue8m0 scales are packed here (4 scales per int32),
        # hence the effective size of this dimension is divided by 4.
        scale_hidden_size = ceil_div(scale_hidden_size, 4)

    assert m_indices.shape[0] % BLOCK_E == 0
    assert (
        recv_x_scale.dtype == output_tensor_scale.dtype
    ), f"recv_x_scale.dtype: {recv_x_scale.dtype}, output_tensor_scale.dtype: {output_tensor_scale.dtype}"
    assert recv_x_scale.shape[1] == output_tensor_scale.shape[1] == scale_hidden_size

    _fwd_kernel_ep_scatter_1[(grid,)](
        num_recv_tokens_per_expert,
        expert_start_loc,
        m_indices,
        num_experts=num_experts,
        num_warps=num_warps,
        BLOCK_E=BLOCK_E,
        BLOCK_EXPERT_NUM=triton.next_power_of_2(num_experts),
    )

    grid = min(recv_topk.shape[0], 1024 * 8)

    _fwd_kernel_ep_scatter_2[(grid,)](
        recv_topk.shape[0],
        expert_start_loc,
        recv_x,
        recv_x.stride(0),
        recv_x.stride(1),
        recv_x_scale,
        recv_x_scale.stride(0),
        recv_x_scale.stride(1),
        recv_topk,
        recv_topk.stride(0),
        recv_topk.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor_scale,
        output_tensor_scale.stride(0),
        output_tensor_scale.stride(1),
        output_index,
        output_index.stride(0),
        output_index.stride(1),
        topk_num=recv_topk.shape[1],
        num_warps=num_warps,
        HIDDEN_SIZE=hidden_size,
        HIDDEN_SIZE_PAD=triton.next_power_of_2(hidden_size),
        SCALE_HIDDEN_SIZE=scale_hidden_size,
        SCALE_HIDDEN_SIZE_PAD=triton.next_power_of_2(scale_hidden_size),
    )
    return
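
# ---------------------------------------------------------------------------
# Illustrative shape sketch (not part of the original module) for ep_scatter
# above, assuming a CUDA device, hidden size 256 and 128-wide quant groups:
# recv_x is FP8 [T, H] with per-group scales [T, H // 128]; outputs are laid
# out per expert, with per-expert row budgets aligned to 128. Example sizes
# and random routing only, meant to document the expected tensor contract.
# ---------------------------------------------------------------------------
def _example_ep_scatter(num_tokens=4, topk=2, hidden=256, num_experts=2):
    rows_per_expert = 128  # aligned per-expert row budget in this sketch
    total_rows = num_experts * rows_per_expert
    scale_cols = hidden // 128

    recv_x = torch.randn(num_tokens, hidden, device="cuda").to(torch.float8_e4m3fn)
    recv_x_scale = torch.rand(
        num_tokens, scale_cols, device="cuda", dtype=torch.float32
    )
    recv_topk = torch.randint(
        0, num_experts, (num_tokens, topk), device="cuda", dtype=torch.int32
    )
    num_recv_tokens_per_expert = torch.full(
        (num_experts,), rows_per_expert, device="cuda", dtype=torch.int32
    )

    expert_start_loc = torch.empty(num_experts, device="cuda", dtype=torch.int32)
    output_tensor = torch.empty(
        total_rows, hidden, device="cuda", dtype=torch.float8_e4m3fn
    )
    output_tensor_scale = torch.empty(
        total_rows, scale_cols, device="cuda", dtype=torch.float32
    )
    m_indices = torch.empty(total_rows, device="cuda", dtype=torch.int32)
    output_index = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

    ep_scatter(
        recv_x,
        recv_x_scale,
        recv_topk,
        num_recv_tokens_per_expert,
        expert_start_loc,
        output_tensor,
        output_tensor_scale,
        m_indices,
        output_index,
    )
    return output_tensor, output_tensor_scale, m_indices, output_index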
@triton.jit
def _fwd_kernel_ep_gather(
    total_token_num,
    input_tensor,
    input_tensor_stride0,
    input_tensor_stride1,
    recv_topk_ids,
    recv_topk_ids_stride0,
    recv_topk_ids_stride1,
    recv_topk_weight,
    recv_topk_weight_stride0,
    recv_topk_weight_stride1,
    input_index,
    input_index_stride0,
    input_index_stride1,
    output_tensor,
    output_tensor_stride0,
    output_tensor_stride1,
    topk_num: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    cur_block_int32 = tl.program_id(0)
    cur_block = cur_block_int32.to(tl.int64)

    start_cur_token_int32 = tl.program_id(1)

    grid_num = tl.num_programs(1)

    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
        cur_token = cur_token_int32.to(tl.int64)
        off_d = tl.arange(0, BLOCK_D)
        accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
        for topk_index_int32 in range(0, topk_num):
            topk_index = topk_index_int32.to(tl.int64)
            expert_id = tl.load(
                recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
            )
            if expert_id >= 0:
                source_token_index_int32 = tl.load(
                    input_index + cur_token * input_index_stride0 + topk_index
                )
                source_token_index = source_token_index_int32.to(tl.int64)
                acc_weight = tl.load(
                    recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                )
                tmp = tl.load(
                    input_tensor
                    + source_token_index * input_tensor_stride0
                    + cur_block * BLOCK_D
                    + off_d
                )
                accumulator += tmp.to(tl.float32) * acc_weight

        tl.store(
            output_tensor
            + cur_token * output_tensor_stride0
            + cur_block * BLOCK_D
            + off_d,
            accumulator.to(output_tensor.dtype.element_ty),
        )


@torch.no_grad()
def ep_gather(
    input_tensor: torch.Tensor,
    recv_topk_ids: torch.Tensor,
    recv_topk_weight: torch.Tensor,
    input_index: torch.Tensor,
    output_tensor: torch.Tensor,
):
    num_warps = 2
    num_tokens = output_tensor.shape[0]
    hidden_size = input_tensor.shape[1]
    BLOCK_D = 128 if hidden_size % 1024 != 0 else 1024  # block size of quantization
    assert hidden_size % BLOCK_D == 0
    grid = (triton.cdiv(hidden_size, BLOCK_D), min(num_tokens, 1024))
    _fwd_kernel_ep_gather[grid](
        num_tokens,
        input_tensor,
        input_tensor.stride(0),
        input_tensor.stride(1),
        recv_topk_ids,
        recv_topk_ids.stride(0),
        recv_topk_ids.stride(1),
        recv_topk_weight,
        recv_topk_weight.stride(0),
        recv_topk_weight.stride(1),
        input_index,
        input_index.stride(0),
        input_index.stride(1),
        output_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        topk_num=recv_topk_ids.shape[1],
        num_warps=num_warps,
        BLOCK_D=BLOCK_D,
    )
    return


# copy from
# https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Global memory address of TMA must be 16-byte aligned.
    Since we use column-major layout for the LHS scaling tensor,
    the M-axis of the LHS scaling tensor needs to be padded to a multiple of 16 bytes.

    Arguments:
        x: original M-axis shape of the LHS scaling tensor.
        element_size: element size of the LHS scaling tensor.

    Returns:
        M-axis shape of the LHS scaling tensor after padding.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    alignment = tma_alignment_bytes // element_size
    return ceil_div(x, alignment) * alignment


@triton.jit
def _tma_align_input_scale_kernel(
    input_scale_ptr,
    output_ptr,
    m,
    k_div_block_size,
    input_scale_stride_m,
    input_scale_stride_k,
    output_stride_m,
    output_stride_k,
    BLOCK_SIZE_K: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    grid_m = tl.num_programs(0)
    k_offsets = tl.arange(0, BLOCK_SIZE_K)

    for m_base in range(pid_m, m, grid_m):
        input_offset = (
            input_scale_ptr
            + m_base * input_scale_stride_m
            + k_offsets * input_scale_stride_k
        )
        input_data = tl.load(input_offset, mask=k_offsets < k_div_block_size)

        output_offset = (
            output_ptr + k_offsets * output_stride_k + m_base * output_stride_m
        )
        tl.store(output_offset, input_data, mask=k_offsets < k_div_block_size)


# copy from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py
def tma_align_input_scale(input_scale: torch.Tensor):
    assert input_scale.dim() == 2
    m, k_div_block_size = input_scale.shape
    padd_m = get_tma_aligned_size(m, input_scale.element_size())
    output = torch.empty(
        (k_div_block_size, padd_m), dtype=input_scale.dtype, device=input_scale.device
    )

    grid_m = min(m, 8192)
    BLOCK_SIZE_K = triton.next_power_of_2(k_div_block_size)

    _tma_align_input_scale_kernel[(grid_m,)](
        input_scale_ptr=input_scale,
        output_ptr=output,
        m=m,
        k_div_block_size=k_div_block_size,
        input_scale_stride_m=input_scale.stride(0),
        input_scale_stride_k=input_scale.stride(1),
        output_stride_m=output.stride(1),  # Note: these are swapped
        output_stride_k=output.stride(0),  # for column-major
        BLOCK_SIZE_K=BLOCK_SIZE_K,
    )
    return output.t()[:m]


@triton.jit
def compute_masked_m_triton_kernel(seg_indptr, masked_m):
    expert_id = tl.program_id(0)
    start = tl.load(seg_indptr + expert_id)
    end = tl.load(seg_indptr + expert_id + 1)
    tl.store(masked_m + expert_id, (end - start))
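
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): for float32 scales the
# TMA alignment is 16 / 4 = 4 rows, so an M of 130 is padded up to 132, and
# tma_align_input_scale returns a column-major view trimmed back to M rows.
# Example sizes only; assumes a CUDA device.
# ---------------------------------------------------------------------------
def _example_tma_align_input_scale(m=130, k_blocks=56):
    assert get_tma_aligned_size(130, 4) == 132  # ceil(130 / 4) * 4
    scales = torch.randn(m, k_blocks, device="cuda", dtype=torch.float32)
    aligned = tma_align_input_scale(scales)
    # Same logical shape as the input, but backed by padded, column-major storage.
    assert aligned.shape == (m, k_blocks)
    return aligned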
@triton.jit
def deepgemm_compute_src2dst_triton_kernel(
    topk_ids,
    reorder_ids,
    seg_indptr,
    src2dst,
    m_max,
    num_toks,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = dst_id < num_toks
    src_id = tl.load(reorder_ids + dst_id, mask=mask)
    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
    expert_dst_start = tl.load(seg_indptr + expert_id, mask=(expert_id >= 0))
    expert_dst_offset = dst_id - expert_dst_start
    dst_id = expert_id * m_max + expert_dst_offset
    tl.store(src2dst + src_id, dst_id, mask=mask)


@triton.jit
def fill_gateup_input_triton_kernel(
    input_ptr,
    scale_ptr,
    gateup_input_ptr,
    gateup_input_scale_ptr,
    src2dst_ptr,
    topk_ids_ptr,
    topk,
    hidden_size,
    scale_size,
    BLOCK_SIZE: tl.constexpr,
):
    src_idx_int32 = tl.program_id(0)
    src_idx = src_idx_int32.to(tl.int64)
    src2dst_ptr = src2dst_ptr + src_idx * topk
    topk_ids_ptr = topk_ids_ptr + src_idx * topk
    src_ptr = input_ptr + src_idx * hidden_size
    scale_src_ptr = scale_ptr + src_idx * scale_size

    vec = tl.arange(0, BLOCK_SIZE)
    for idx in range(topk):
        expert_id = tl.load(topk_ids_ptr + idx)
        if expert_id >= 0:
            dst_idx_int32 = tl.load(src2dst_ptr + idx)
            dst_idx = dst_idx_int32.to(tl.int64)
            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < hidden_size
                in_data = tl.load(src_ptr + offset, mask=mask)
                tl.store(dst_ptr + offset, in_data, mask=mask)
            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
                offset = start_offset + vec
                mask = offset < scale_size
                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
def moe_ep_deepgemm_preprocess(
    topk_ids: torch.Tensor,
    num_local_experts: int,
    hidden_states: torch.Tensor,
    top_k: int,
    block_shape,
    output_dtype: torch.dtype = torch.float8_e4m3fn,
):
    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
    seg_indptr = torch.zeros(
        num_local_experts + 1, device=topk_ids.device, dtype=torch.int64
    )
    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
    masked_m = torch.empty(
        num_local_experts, device=topk_ids.device, dtype=torch.int32
    )

    compute_seg_indptr_triton_kernel[(num_local_experts + 1,)](
        reorder_topk_ids, seg_indptr, topk_ids.numel()
    )

    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
    compute_masked_m_triton_kernel[(num_local_experts,)](seg_indptr, masked_m)

    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})
    # https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
    m_max = (hidden_states.size(0) // 256 + 1) * 256
    expected_m = (topk_ids.numel() - 1) // num_local_experts + 1
    gateup_input = torch.empty(
        (num_local_experts, m_max, hidden_states.size(1)),
        device=hidden_states.device,
        dtype=output_dtype,
    )

    deepgemm_compute_src2dst_triton_kernel[grid](
        topk_ids,
        reorder_ids,
        seg_indptr,
        src2dst,
        m_max,
        topk_ids.numel(),
        BLOCK_SIZE=256,
    )

    if block_shape is None:
        block_shape = [128, 128]
    assert len(block_shape) == 2
    block_n, block_k = block_shape[0], block_shape[1]

    # TODO: fuse this with the preprocess
    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)

    gateup_input_scale = torch.empty(
        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
        device=hidden_states.device,
        dtype=scale.dtype,
    )

    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
        hidden_states,
        scale,
        gateup_input,
        gateup_input_scale,
        src2dst,
        topk_ids,
        top_k,
        hidden_states.size(1),
        scale.size(1),
        BLOCK_SIZE=1024,
    )

    return (
        masked_m,
        expected_m,
        src2dst,
        gateup_input,
        gateup_input_scale,
    )


@triton.jit
def compute_identity_kernel(
    top_k,
    hidden_states_ptr,
    expert_scales_ptr,
    num_tokens,
    output_ptr,
    hidden_dim,
    scales_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)

    batch_id = pid // (hidden_dim // BLOCK_SIZE)
    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE

    if batch_id >= num_tokens or dim_offset >= hidden_dim:
        return

    h = tl.load(
        hidden_states_ptr
        + batch_id * hidden_dim
        + dim_offset
        + tl.arange(0, BLOCK_SIZE),
        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
    )

    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for i in range(top_k):
        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
        result += h * scale

    tl.store(
        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
        result,
        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
    )


def zero_experts_compute_triton(
    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
):
    N = expert_indices.numel()
    top_k = expert_indices.size(-1)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)

    if zero_expert_type == "identity":
        zero_expert_mask = expert_indices < num_experts
        zero_expert_scales = expert_scales.clone()
        zero_expert_scales[zero_expert_mask] = 0.0

    normal_expert_mask = expert_indices >= num_experts
    expert_indices[normal_expert_mask] = -1
    expert_scales[normal_expert_mask] = 0.0

    output = torch.zeros_like(hidden_states).to(hidden_states.device)
    hidden_dim = hidden_states.size(-1)
    num_tokens = hidden_states.size(0)

    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
    compute_identity_kernel[grid](
        top_k,
        hidden_states,
        zero_expert_scales,
        num_tokens,
        output,
        hidden_dim,
        zero_expert_scales.stride(0),
        BLOCK_SIZE=256,
    )

    return output


@triton.jit
def compute_problem_sizes_w4a8_kernel(
    masked_m_ptr,
    problem_sizes1_ptr,
    problem_sizes2_ptr,
    n,
    k,
    num_experts,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = pid < num_experts
    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)

    ps1_idx_0 = pid * 3
    ps1_idx_1 = ps1_idx_0 + 1
    ps1_idx_2 = ps1_idx_0 + 2
    ps2_idx_0 = pid * 3
    ps2_idx_1 = ps2_idx_0 + 1
    ps2_idx_2 = ps2_idx_0 + 2

    ps1_mask_0 = ps1_idx_0 < num_experts * 3
    ps1_mask_1 = ps1_idx_1 < num_experts * 3
    ps1_mask_2 = ps1_idx_2 < num_experts * 3
    ps2_mask_0 = ps2_idx_0 < num_experts * 3
    ps2_mask_1 = ps2_idx_1 < num_experts * 3
    ps2_mask_2 = ps2_idx_2 < num_experts * 3

    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)

    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)


def compute_problem_sizes_w4a8(
    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
):
    BLOCK_SIZE = 256
    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
    compute_problem_sizes_w4a8_kernel[grid](
        masked_m,
        problem_sizes1,
        problem_sizes2,
        n,
        k,
        num_experts,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return problem_sizes1, problem_sizes2
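
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): each expert e
# contributes the row (2 * n, masked_m[e], k) to problem_sizes1 and
# (k, masked_m[e], n) to problem_sizes2, matching the two grouped GEMMs of the
# w4a8 MoE path. Example sizes only; assumes a CUDA device.
# ---------------------------------------------------------------------------
def _example_compute_problem_sizes_w4a8(n=512, k=256):
    masked_m = torch.tensor([3, 0, 7], device="cuda", dtype=torch.int32)
    num_experts = masked_m.numel()
    problem_sizes1 = torch.empty(num_experts, 3, device="cuda", dtype=torch.int32)
    problem_sizes2 = torch.empty(num_experts, 3, device="cuda", dtype=torch.int32)
    compute_problem_sizes_w4a8(
        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
    )
    # problem_sizes1[e] == (2 * n, masked_m[e], k)
    # problem_sizes2[e] == (k, masked_m[e], n)
    return problem_sizes1, problem_sizes2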
def deepep_ll_get_cutlass_w4a8_moe_mm_data(
    masked_m,
    problem_sizes1,
    problem_sizes2,
    num_experts,
    n,
    k,
):
    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
    )
    return (
        problem_sizes1.to(torch.int32),
        problem_sizes2.to(torch.int32),
    )


@triton.jit
def _silu_and_mul_post_per_tensor_quant_kernel(
    input_ptr,
    stride_input_expert,
    stride_input_token,
    stride_input_dim,
    output_ptr,
    stride_output_expert,
    stride_output_token,
    stride_output_dim,
    scale_ptr,
    masked_m_ptr,
    inner_dim,
    fp8_max,
    fp8_min,
    BLOCK_N: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """
    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.

    Shape:
        input: [E, T_padded, 2*D] -> gate: [:, :, :D], up: [:, :, D:]
        output: [E, T_padded, D], dtype=float8_e4m3fn
    """
    expert_id = tl.program_id(2)
    block_id_token = tl.program_id(1)
    block_id_dim = tl.program_id(0)

    num_token_blocks = tl.num_programs(1)

    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)

    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
    stride_input_token = tl.cast(stride_input_token, tl.int32)
    stride_output_token = tl.cast(stride_output_token, tl.int32)

    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_d = offset_d < inner_dim

    # base pointers for current expert and dim block
    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d

    for token_idx in tl.range(
        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
    ):
        gate_ptr = input_base_offs + token_idx * stride_input_token
        up_ptr = gate_ptr + inner_dim

        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)

        # SiLU: x * sigmoid(x)
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)
        gate_up = up * gate

        scaled = gate_up * scale
        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)

        out_ptr = output_base_offs + token_idx * stride_output_token
        tl.store(out_ptr, output_q, mask=mask_d)


def silu_and_mul_masked_post_per_tensor_quant_fwd(
    input: torch.Tensor,
    output: torch.Tensor,
    masked_m: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """
    Fused SiLU + Mul + Per-Tensor Quantization to FP8.

    Args:
        input: [expert_num, token_num_padded, 2 * inner_dim]
        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
        masked_m: [expert_num], actual token count for each expert
        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)

    Returns:
        output tensor
    """
    assert input.is_contiguous()
    assert output.is_contiguous()
    assert output.dtype == torch.float8_e4m3fn
    assert input.ndim == 3
    assert input.shape[0] == masked_m.shape[0]
    assert input.shape[-1] % 2 == 0
    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]

    expert_num = input.shape[0]  # 3584
    inner_dim = input.shape[-1] // 2

    BLOCK_N = 256
    BLOCK_M = 64 if expert_num < 4 else 32
    NUM_STAGES = 3
    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)

    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)

    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_max = finfo.max
    fp8_min = -fp8_max

    _silu_and_mul_post_per_tensor_quant_kernel[grid](
        input,
        *input.stride(),
        output,
        *output.stride(),
        scale,
        masked_m,
        inner_dim,
        fp8_max,
        fp8_min,
        BLOCK_N=BLOCK_N,
        NUM_STAGE=NUM_STAGES,
    )
    return output
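
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module), assuming a CUDA
# device: SiLU(gate) * up followed by per-tensor FP8 quantization with a single
# scalar scale. Example sizes only.
# ---------------------------------------------------------------------------
def _example_silu_and_mul_per_tensor_quant(e=2, t_pad=64, inner_dim=512):
    x = torch.randn(e, t_pad, 2 * inner_dim, device="cuda", dtype=torch.bfloat16)
    out = torch.empty(e, t_pad, inner_dim, device="cuda", dtype=torch.float8_e4m3fn)
    # Only the first masked_m[e] rows of expert e are processed.
    masked_m = torch.full((e,), t_pad, device="cuda", dtype=torch.int32)
    scale = torch.tensor([1.0], device="cuda", dtype=torch.float32)
    return silu_and_mul_masked_post_per_tensor_quant_fwd(x, out, masked_m, scale)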