Unverified Commit fb0571b0 authored by Varun Sundar Rabindranath's avatar Varun Sundar Rabindranath Committed by GitHub
Browse files

[GPTOSS][DP/EP][Marlin] Enable GPTOSS Batched DP/EP using Marlin kernels (#25997)


Signed-off-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: default avatarVarun Sundar Rabindranath <vsundarr@redhat.com>
parent 2ed8b6b3
......@@ -8,12 +8,77 @@
#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
namespace vllm {
namespace moe {
namespace batched_moe_align_block_size {
// Note num_threads needs to be 1024 for BlockScan Reduction in the kernel.
static constexpr int32_t num_threads = 1024;
static constexpr int32_t num_blocks = 1;
__global__ void batched_moe_align_block_size_kernel(
int32_t const num_batches, int32_t const max_tokens_per_batch,
int32_t const block_size, int32_t const* __restrict__ batch_num_tokens,
int32_t* __restrict__ sorted_ids, int32_t* __restrict__ block_ids,
int32_t* __restrict__ num_tokens_post_pad) {
// TODO(varun): This is a naive implementation. Could be optimized.
size_t const batch_id = threadIdx.x;
size_t const stride = blockDim.x * gridDim.x;
int32_t const num_blocks_per_batch =
CEILDIV(max_tokens_per_batch, block_size);
int32_t const sorted_ids_size =
num_blocks_per_batch * num_batches * block_size;
int32_t const block_ids_size = sorted_ids_size / block_size;
int32_t const SENTINEL =
num_batches * max_tokens_per_batch; // To denote invalid entries.
// Intialize sorted_ids
for (size_t i = threadIdx.x; i < sorted_ids_size; i += stride) {
sorted_ids[i] = SENTINEL;
}
// Intialize expert_ids with -1
for (size_t i = threadIdx.x; i < block_ids_size; i += stride) {
block_ids[i] = -1;
}
int32_t b_num_tokens = 0;
if (batch_id < num_batches) {
b_num_tokens = batch_num_tokens[batch_id];
}
int32_t const ceil_b_num_tokens =
CEILDIV(b_num_tokens, block_size) * block_size;
// Compute prefix sum over token counts per expert
using BlockScan = cub::BlockScan<int32_t, 1024>;
__shared__ typename BlockScan::TempStorage temp_storage;
int cumsum_val;
BlockScan(temp_storage).ExclusiveSum(ceil_b_num_tokens, cumsum_val);
__syncthreads();
bool const is_last_batch = batch_id == (num_batches - 1);
if (is_last_batch) {
*num_tokens_post_pad = cumsum_val + ceil_b_num_tokens;
}
if (batch_id < num_batches) {
int32_t const batch_offset = batch_id * max_tokens_per_batch;
for (size_t i = 0; i < b_num_tokens; ++i) {
sorted_ids[cumsum_val + i] = batch_offset + i;
}
int32_t const block_start = cumsum_val / block_size;
int32_t const num_blocks = ceil_b_num_tokens / block_size;
for (size_t i = 0; i < num_blocks; ++i) {
block_ids[block_start + i] = batch_id;
}
}
}
} // namespace batched_moe_align_block_size
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
......@@ -280,6 +345,33 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
});
}
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& batch_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor batch_ids,
torch::Tensor num_tokens_post_pad) {
namespace batched_kernel = vllm::moe::batched_moe_align_block_size;
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int32_t const B = batch_num_tokens.size(0);
int32_t const num_blocks_per_batch =
round_to_next_multiple_of(max_tokens_per_batch, block_size) / block_size;
int32_t const num_blocks = num_blocks_per_batch * B;
int64_t const sorted_ids_size = num_blocks * block_size;
TORCH_CHECK(sorted_ids.size(0) == sorted_ids_size);
TORCH_CHECK(batch_ids.size(0) == sorted_ids_size / block_size);
TORCH_CHECK(num_tokens_post_pad.size(0) == 1);
TORCH_CHECK(B <= batched_kernel::num_threads);
batched_kernel::batched_moe_align_block_size_kernel<<<
batched_kernel::num_blocks, batched_kernel::num_threads, 0, stream>>>(
B, max_tokens_per_batch, block_size, batch_num_tokens.data_ptr<int32_t>(),
sorted_ids.data_ptr<int32_t>(), batch_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>());
}
void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
torch::Tensor& output) // [num_tokens, hidden_size]
{
......
......@@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int64_t block_size, torch::Tensor sorted_token_ids,
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);
void batched_moe_align_block_size(int64_t max_tokens_per_batch,
int64_t block_size,
torch::Tensor const& expert_num_tokens,
torch::Tensor sorted_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
......
......@@ -22,6 +22,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" Tensor! num_tokens_post_pad) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
// Aligning the number of tokens to be processed by each expert such
// that it is divisible by the block size, but for the batched case.
m.def(
"batched_moe_align_block_size(int max_tokens_per_batch,"
" int block_size, Tensor expert_num_tokens,"
" Tensor! sorted_token_ids,"
" Tensor! experts_ids,"
" Tensor! num_tokens_post_pad) -> ()");
m.impl("batched_moe_align_block_size", torch::kCUDA,
&batched_moe_align_block_size);
#ifndef USE_ROCM
m.def(
"moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "
......
......@@ -92,8 +92,8 @@ To be used with a particular `FusedMoEPrepareAndFinalize` sub-class, MoE kernels
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | N | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe] |
| marlin experts | standard | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts] |
| marlin | standard | <sup>3</sup> | <sup>3</sup> | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| marlin experts | standard,</br>batched | N/A | N/A | silu,</br>swigluoai | Y | Y | [`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
......@@ -116,5 +116,5 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|----------------------------------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`|
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts`|
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
This diff is collapsed.
......@@ -9,6 +9,7 @@ import pytest
import torch
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
batched_moe_align_block_size,
moe_align_block_size,
)
from vllm.platforms import current_platform
......@@ -300,3 +301,96 @@ def test_moe_align_block_size_deterministic():
assert torch.equal(results[0][2], results[i][2]), (
"num_tokens should be deterministic"
)
@pytest.mark.parametrize("max_tokens_per_batch", [13, 16, 512])
@pytest.mark.parametrize("num_experts", [8, 16, 32, 64])
@pytest.mark.parametrize("block_size", [8, 16, 32, 64])
@pytest.mark.parametrize("simulate_empty_batches", [False, True])
def test_batched_moe_align_block_size(
max_tokens_per_batch: int,
num_experts: int,
block_size: int,
simulate_empty_batches: bool,
):
def ref_outputs(
expert_num_tokens: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
E = expert_num_tokens.size(0)
# Round up so each batch can be split to blocks evenly.
Msum = round_up(max_tokens_per_batch, block_size) * E
ref_sorted_ids = torch.empty((Msum,), dtype=torch.int32)
ref_expert_ids = torch.empty((Msum // block_size,), dtype=torch.int32)
ref_num_tokens_post_pad = torch.empty((1,), dtype=torch.int32)
# Intialize
sentinel = E * max_tokens_per_batch
ref_sorted_ids.fill_(sentinel)
ref_expert_ids.fill_(-1)
# Fill ref_sorted_ids
i = 0
for expert_id, expert_nt in enumerate(expert_num_tokens):
token_offset = expert_id * max_tokens_per_batch
for j in range(expert_nt):
ref_sorted_ids[i] = token_offset + j
i += 1
# round up i to the next block_size
i = round_up(i, block_size)
ref_num_tokens_post_pad[0] = i
# Fill expert_ids
nt_ceil_sum = 0
for expert_id, expert_nt in enumerate(expert_num_tokens):
expert_ids_offset = nt_ceil_sum // block_size
ceil_expert_nt = round_up(int(expert_nt.item()), block_size)
num_blocks = ceil_expert_nt // block_size
for x in range(num_blocks):
ref_expert_ids[expert_ids_offset + x] = expert_id
nt_ceil_sum += ceil_expert_nt
return (
ref_sorted_ids.to("cuda"),
ref_expert_ids.to("cuda"),
ref_num_tokens_post_pad.to("cuda"),
)
# Compute expert_num_tokens
expert_num_tokens = torch.randint(
low=0,
high=max_tokens_per_batch,
size=(num_experts,),
device="cpu",
dtype=torch.int32,
)
if simulate_empty_batches:
# mark half the batches to have 0 tokens
zero_batches = torch.randperm(num_experts)[: num_experts // 2]
expert_num_tokens[zero_batches] = 0
# ref outputs
ref_sorted_ids, ref_expert_ids, ref_num_tokens_post_pad = ref_outputs(
expert_num_tokens
)
# outputs
sorted_ids, expert_ids, num_tokens_post_pad = batched_moe_align_block_size(
max_tokens_per_batch, block_size, expert_num_tokens.to("cuda")
)
assert ref_sorted_ids.size() == sorted_ids.size(), (
f"{ref_sorted_ids.size()} vs {sorted_ids.size()}"
)
assert ref_expert_ids.size() == expert_ids.size(), (
f"{ref_expert_ids.size()} vs {expert_ids.size()}"
)
assert ref_num_tokens_post_pad.size() == num_tokens_post_pad.size(), (
f"{ref_num_tokens_post_pad.size()} vs {num_tokens_post_pad.size()}"
)
torch.testing.assert_close(ref_sorted_ids, sorted_ids, atol=0, rtol=0)
torch.testing.assert_close(ref_expert_ids, expert_ids, atol=0, rtol=0)
torch.testing.assert_close(
ref_num_tokens_post_pad, num_tokens_post_pad, atol=0, rtol=0
)
......@@ -1789,6 +1789,24 @@ def moe_align_block_size(
)
def batched_moe_align_block_size(
max_tokens_per_batch: int,
block_size: int,
expert_num_tokens: torch.Tensor,
sorted_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
torch.ops._moe_C.batched_moe_align_block_size(
max_tokens_per_batch,
block_size,
expert_num_tokens,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
def moe_wna16_gemm(
input: torch.Tensor,
output: torch.Tensor,
......
......@@ -50,7 +50,31 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP low-latency kernels are compiled only for certain
# specific hidden sizes.
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168]
# NOTE: Keep this list sorted, maybe_roundup_layer_hidden_size depends
# on it.
SUPPORTED_HIDDEN_SIZES = [2048, 2560, 4096, 5120, 6144, 7168, 8192]
@staticmethod
def maybe_roundup_layer_hidden_size(hidden_size: int) -> int:
# Round up hidden size to the closest supported hidden size.
_supported_hs = DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES
# Check sorted
num_supported_hs = len(_supported_hs)
assert all(
[
_supported_hs[i] < _supported_hs[i + 1]
for i in range(num_supported_hs - 1)
]
)
for x in _supported_hs:
if x >= hidden_size:
return x
raise ValueError(
f"Hidden Size {hidden_size} is greater than the "
f"maximum supported hidden size {_supported_hs[-1]}"
)
def __init__(
self,
......
......@@ -994,6 +994,11 @@ def maybe_roundup_hidden_size(
hidden_size, act_dtype
)
if moe_parallel_config.use_deepep_ll_kernels:
hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size
)
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import (
......
......@@ -83,3 +83,92 @@ def moe_align_block_size(
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
def batched_moe_align_block_size(
max_tokens_per_batch: int, block_size: int, expert_num_tokens: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given num_batches, max_tokens_per_batch, block_size and the number of
valid-tokens in each batch, prepare sorted_token_ids, expert_ids and
num_tokens_post_pad. sorted_token_ids, expert_ids and num_tokens_post_pad
have the same semantics as in moe_align_block_size.
This function is intended to be a drop in replacement for
moe_align_batch_size for the batched case.
Parameters:
- max_tokens_per_batch (int): Number of tokens in each batch (both
valid and invalid).
- block_size (int): block_size to align the data to.
- expert_num_tokens (torch.Tensor): expert_num_tokens[i], indicates
the number of valid tokens in batch i.
Returns:
- sorted_token_ids (torch.Tensor): Torch tensor of size
(num_batches * max_tokens_per_batch) indicating the token indices for
that block.
- expert_ids (torch.Tensor): Torch tensor of size
ceil((num_batches * max_tokens_per_batch) / block_size) indicating
what expert to use for each block.
- num_tokens_post_pad (torch.Tensor): Torch tensor of size 1
indicating the number of valid blocks with actual data to
process. This is represented in terms of num tokens.
Example:
Let num_batches=5, max_tokens_per_batch=8, block_size=4, and
expert_num_tokens=[2, 3, 0, 6, 8]. This expert_num_tokens tensor
indicates that,
- The first 2 tokens in the 0th batch are valid and the rest 6 are
invalid (i.e. in the 2D hidden_states tensor of shape,
[num_batches * max_tokens_per_batch, K], indices 0, 1 are valid)
- The first 3 tokens in the 1st batch are valid. i.e. indices 8, 9, 10
- 0 tokens in the 2nd batch are valid
- first 6 tokens in the 3rd batch are valid. i.e. indices,
24, 25, 26, 27, 28, 29
- so on ...
In this case,
sorted_token_ids will be [0, 1, 40, 40,
8, 9, 10, 40,
24, 25, 26, 27,
28, 29, 40, 40,
32, 33, 34, 35,
36, 37, 38, 39,
40, 40, 40, 40,
(rest all 40, 40, 40, 40)
...]
Here, 40 represents an invalid index. as there is no token index 40.
The gemm kernel using this sorted_token_ids is expected to skip the
gemm computation when it encounters this invalid index.
expert_ids will be [0, 1, 3, 3, 4, 5, 5, -1, -1, (rest all -1) ...]
Here, -1 represents an invalid expert. The gemm kernel using this
expert_ids is expected to skip the gemm computation when it encounters
an expert of id -1.
num_tokens_post_pad will be 24 as sorted_token_ids has valid entries
until 24.
"""
B = expert_num_tokens.size(0)
device = expert_num_tokens.device
# Round up so each batch can be split to blocks evenly.
max_num_tokens_padded = B * round_up(max_tokens_per_batch, block_size)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
assert max_num_tokens_padded % block_size == 0
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device=device)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=device)
ops.batched_moe_align_block_size(
max_tokens_per_batch,
block_size,
expert_num_tokens,
sorted_ids,
expert_ids,
num_tokens_post_pad,
)
return sorted_ids, expert_ids, num_tokens_post_pad
......@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
BatchedMarlinExperts,
MarlinExperts,
fused_marlin_moe,
)
......@@ -797,8 +798,18 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
prepare_finalize.activation_format
== mk.FusedMoEActivationFormat.BatchedExperts
):
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
assert self.moe_quant_config is not None
return BatchedMarlinExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
raise NotImplementedError(
"Mxfp4 does not support batched experts format for EP"
"Incompatible Mxfp4 backend for EP batched experts format"
)
else:
assert self.moe_quant_config is not None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment