Unverified Commit 0df6765c authored by Pavani Majety's avatar Pavani Majety Committed by GitHub
Browse files

[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of...


[CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped Gems Metadata (#6887)
Signed-off-by: default avatarPavani Majety <pmajety@nvidia.com>
parent 35b65cf0
...@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple ...@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda
_is_cuda = is_cuda() _is_cuda = is_cuda()
...@@ -18,11 +19,12 @@ if _is_cuda: ...@@ -18,11 +19,12 @@ if _is_cuda:
fp8_blockwise_scaled_grouped_mm, fp8_blockwise_scaled_grouped_mm,
prepare_moe_input, prepare_moe_input,
scaled_fp4_experts_quant, scaled_fp4_experts_quant,
shuffle_rows,
silu_and_mul, silu_and_mul,
) )
def cutlass_fused_experts( def cutlass_fused_experts_fp8(
a: torch.Tensor, a: torch.Tensor,
w1_q: torch.Tensor, w1_q: torch.Tensor,
w2_q: torch.Tensor, w2_q: torch.Tensor,
...@@ -223,17 +225,10 @@ def cutlass_moe_fp4( ...@@ -223,17 +225,10 @@ def cutlass_moe_fp4(
w2_fp4: torch.Tensor, w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor, w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor, w2_alphas: torch.Tensor,
ab_strides_13: torch.Tensor,
ab_strides_2: torch.Tensor,
c_strides_13: torch.Tensor,
c_strides_2: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
m: int, params: CutlassMoEParams,
n: int, apply_router_weight_on_input: bool = False,
k: int,
e: int,
device: torch.device,
): ):
""" """
MoE implementation for FP4 Inputs MoE implementation for FP4 Inputs
...@@ -291,77 +286,70 @@ def cutlass_moe_fp4( ...@@ -291,77 +286,70 @@ def cutlass_moe_fp4(
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert e_w1 == e_w2 and e_w1 == e, ( assert e_w1 == e_w2 and e_w1 == params.num_experts, (
"Number of experts must match", "Number of experts must match",
" between weights.", " between weights.",
) )
assert ( assert (
k_a // 2 == half_k_w1 and k == k_w2 k_a // 2 == half_k_w1 and params.hidden_size == k_w2
), "Hidden size mismatch between a, w1 and w2" ), "Hidden size mismatch between a, w1 and w2"
assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`" assert (
assert m == m_a, "input shape mismatch" nx2_w1 == params.intermediate_size_per_partition * 2
and half_n_w2 == params.intermediate_size_per_partition // 2
), ("mismatch in " "expected `n`")
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1" assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (
topk_weights.shape[0] == m and topk_ids.shape[0] == m
), "topk must be provided for each row of a"
out_dtype = a.dtype out_dtype = a.dtype
num_topk = topk_ids.shape[1] num_topk = topk_ids.shape[1]
device = a.device
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device) c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
prepare_moe_input( prepare_moe_input(
topk_ids, topk_ids,
expert_offsets, params.expert_offsets,
problem_sizes1, params.problem_sizes1,
problem_sizes2, params.problem_sizes2,
a_map, a_map,
c_map, c_map,
e, params.num_experts,
n, params.intermediate_size_per_partition,
k, params.hidden_size,
blockscale_offsets, params.blockscale_offsets,
) )
rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant( rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map a,
a1_gscale,
params.expert_offsets,
params.blockscale_offsets,
num_topk,
expert_map=a_map,
) )
c1 = cutlass_fp4_group_mm( c1 = cutlass_fp4_group_mm(
rep_a_fp4, rep_a_fp4,
w1_fp4, w1_fp4,
rep_a_blockscale, rep_a_blockscale,
w1_blockscale, w1_blockscale,
w1_alphas, w1_alphas,
ab_strides_13,
c_strides_13,
problem_sizes1,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype, out_dtype,
device, device,
params.to_gemm1_args(),
) )
del rep_a_fp4, rep_a_blockscale del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split to one halfpytho sized tensor. # hidden size dimension is split to one halfpytho sized tensor.
intermediate = torch.empty( intermediate = torch.empty(
(m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype (m_a * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
) )
silu_and_mul(c1, intermediate) silu_and_mul(c1, intermediate)
int_fp4, int_blockscale = scaled_fp4_experts_quant( int_fp4, int_blockscale = scaled_fp4_experts_quant(
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk intermediate,
a2_gscale,
params.expert_offsets,
params.blockscale_offsets,
num_topk,
) )
c2 = cutlass_fp4_group_mm( c2 = cutlass_fp4_group_mm(
int_fp4, int_fp4,
...@@ -369,16 +357,13 @@ def cutlass_moe_fp4( ...@@ -369,16 +357,13 @@ def cutlass_moe_fp4(
int_blockscale, int_blockscale,
w2_blockscale, w2_blockscale,
w2_alphas, w2_alphas,
ab_strides_2,
c_strides_2,
problem_sizes2,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype, out_dtype,
device, device,
params.to_gemm2_args(),
) )
del int_fp4, int_blockscale del int_fp4, int_blockscale
out = ( c2 = shuffle_rows(c2, c_map, (m_a * num_topk, params.hidden_size))
c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half() c2 = c2.view(m_a, num_topk, params.hidden_size)
).sum(dim=1) if not apply_router_weight_on_input:
return out.to(dtype=out_dtype) c2 = c2 * topk_weights.view(m_a, num_topk, 1).to(out_dtype)
return c2.sum(dim=1).to(out_dtype)
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch
class CutlassMoEType(Enum):
    """
    Enum for the different types of cutlass moe operations
    that are currently supported in SGLang.
    """

    BlockscaledFP8 = auto()
    BlockscaledFP4 = auto()


@dataclass
class CutlassMoEParams:
    """
    Parameters for the cutlass moe operation.

    Pre-allocates every stride / offset / problem-size / pointer tensor that
    the CUTLASS grouped GEMMs consume, so callers pass one object instead of
    roughly ten individual tensors.
    """

    # Type as defined above.
    cutlass_moe_type: CutlassMoEType

    # Strides for activations, weights and output in logical number of
    # elements. The activations & output stride is the number of elements to
    # the next row; the weights stride is the number of elements to the next
    # row per expert. cutlass_fp4_group_mm accepts activation and weight
    # strides as a single tensor, so they are stored combined.
    # ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
    # ab_strides_2:  [e] dtype: int64 [Gemm 2: Activation / Weight strides]
    # c_strides_13:  [e] dtype: int64 [Gemm 1: Output strides]
    # c_strides_2:   [e] dtype: int64 [Gemm 2: Output strides]
    ab_strides_13: torch.Tensor
    ab_strides_2: torch.Tensor
    c_strides_13: torch.Tensor
    c_strides_2: torch.Tensor

    # m: total number of tokens.
    # NOTE(review): `m` is declared but never assigned in __init__, so reading
    # params.m raises AttributeError; kept only for annotation compatibility.
    m: int
    # n: intermediate size per partition.
    intermediate_size_per_partition: int
    # k: hidden size per expert.
    hidden_size: int
    # e: number of experts.
    num_experts: int
    # Device on which all metadata tensors are allocated.
    device: torch.device

    # Pointer containers ([e], dtype int64) used by the kernels to compute
    # per-expert offsets into activations, weights, outputs and scales.
    a_ptrs: torch.Tensor
    b_ptrs: torch.Tensor
    out_ptrs: torch.Tensor
    a_scales_ptrs: torch.Tensor
    b_scales_ptrs: torch.Tensor

    # Offsets marking at which token index each expert begins its computation.
    # The number of tokens computed with expert E is
    # expert_offsets[E + 1] - expert_offsets[E]. Shape [e+1], dtype int32.
    expert_offsets: torch.Tensor

    # Logical problem sizes: (m, 2n, k) for the first GEMM and (m, n, k) for
    # the second. Shape [e, 3], dtype int32.
    problem_sizes1: torch.Tensor
    problem_sizes2: torch.Tensor

    # Similar to expert_offsets, but for blockscales; only allocated for the
    # FP4 blockscaled group GEMM (None on the FP8 path).
    blockscale_offsets: Optional[torch.Tensor] = None

    def __init__(
        self,
        cutlass_moe_type: CutlassMoEType,
        device: torch.device,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_size: int,
    ):
        self.cutlass_moe_type = cutlass_moe_type
        self.device = device
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size_per_partition
        self.hidden_size = hidden_size
        # Short aliases matching the (n, k, e) GEMM naming used elsewhere.
        self.n = self.intermediate_size_per_partition
        self.k = self.hidden_size
        self.e = self.num_experts

        # GEMM 1 consumes [*, k] activations and produces [*, 2n] outputs;
        # GEMM 2 consumes [*, n] activations and produces [*, k] outputs.
        self.ab_strides_13 = torch.full(
            (self.e,), self.k, dtype=torch.int64, device=self.device
        )
        self.ab_strides_2 = torch.full(
            (self.e,), self.n, dtype=torch.int64, device=self.device
        )
        self.c_strides_13 = torch.full(
            (self.e,), 2 * self.n, dtype=torch.int64, device=self.device
        )
        self.c_strides_2 = torch.full(
            (self.e,), self.k, dtype=torch.int64, device=self.device
        )

        # Filled in by prepare_moe_input before each forward pass.
        self.expert_offsets = torch.empty(
            (self.e + 1,), dtype=torch.int32, device=self.device
        )
        self.problem_sizes1 = torch.empty(
            (self.e, 3), dtype=torch.int32, device=self.device
        )
        self.problem_sizes2 = torch.empty(
            (self.e, 3), dtype=torch.int32, device=self.device
        )

        # Blockscale offsets are only meaningful for the FP4 blockscaled path.
        if self.cutlass_moe_type == CutlassMoEType.BlockscaledFP4:
            self.blockscale_offsets = torch.empty(
                (self.e + 1,), dtype=torch.int32, device=self.device
            )
        else:
            self.blockscale_offsets = None

        self.a_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
        self.b_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
        self.out_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
        self.a_scales_ptrs = torch.empty(
            (self.e,), dtype=torch.int64, device=self.device
        )
        self.b_scales_ptrs = torch.empty(
            (self.e,), dtype=torch.int64, device=self.device
        )

    def _blockscale_offsets_arg(self) -> Optional[torch.Tensor]:
        """Trimmed blockscale offsets for the kernels, or None on the FP8 path.

        Bug fix: the original sliced ``self.blockscale_offsets[:-1]``
        unconditionally, which raised ``TypeError`` (``'NoneType' object is
        not subscriptable``) whenever cutlass_moe_type was BlockscaledFP8.
        """
        if self.blockscale_offsets is None:
            return None
        return self.blockscale_offsets[:-1]

    def to_gemm1_args(self) -> dict:
        """Keyword arguments for the first grouped GEMM (k -> 2n)."""
        return {
            "ab_strides": self.ab_strides_13,
            "c_strides": self.c_strides_13,
            "problem_sizes": self.problem_sizes1,
            "expert_offsets": self.expert_offsets[:-1],
            "blockscale_offsets": self._blockscale_offsets_arg(),
        }

    def to_gemm2_args(self) -> dict:
        """Keyword arguments for the second grouped GEMM (n -> k)."""
        return {
            "ab_strides": self.ab_strides_2,
            "c_strides": self.c_strides_2,
            "problem_sizes": self.problem_sizes2,
            "expert_offsets": self.expert_offsets[:-1],
            "blockscale_offsets": self._blockscale_offsets_arg(),
        }
...@@ -982,9 +982,9 @@ class Fp8MoEMethod: ...@@ -982,9 +982,9 @@ class Fp8MoEMethod:
and self.block_quant and self.block_quant
and is_sm100_supported() and is_sm100_supported()
): ):
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
return cutlass_fused_experts( return cutlass_fused_experts_fp8(
x, x,
layer.w13_weight.transpose(1, 2), layer.w13_weight.transpose(1, 2),
layer.w2_weight.transpose(1, 2), layer.w2_weight.transpose(1, 2),
......
...@@ -6,7 +6,7 @@ import triton # Added import ...@@ -6,7 +6,7 @@ import triton # Added import
import triton.testing # Added import import triton.testing # Added import
from transformers import AutoConfig from transformers import AutoConfig
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
...@@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False): ...@@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda") problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
# --- Lambdas for Benchmarking --- # --- Lambdas for Benchmarking ---
cutlass_lambda = lambda: cutlass_fused_experts( cutlass_lambda = lambda: cutlass_fused_experts_fp8(
x, x,
w1.transpose(1, 2), # Transposed w1.transpose(1, 2), # Transposed
w2.transpose(1, 2), # Transposed w2.transpose(1, 2), # Transposed
...@@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False): ...@@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
print("Running correctness check...") print("Running correctness check...")
with torch.no_grad(): with torch.no_grad():
# Run CUTLASS version (requires transposed weights) # Run CUTLASS version (requires transposed weights)
y_cutlass = cutlass_fused_experts( y_cutlass = cutlass_fused_experts_fp8(
x, x,
w1.transpose(1, 2), # Transposed w1.transpose(1, 2), # Transposed
w2.transpose(1, 2), # Transposed w2.transpose(1, 2), # Transposed
......
...@@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant ...@@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import select_experts
if torch.cuda.get_device_capability() < (10, 0): if torch.cuda.get_device_capability() < (10, 0):
...@@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph( ...@@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph(
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device (e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
) )
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device) c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device=a.device,
num_experts=e,
intermediate_size_per_partition=n, # n
hidden_size=k,
) # k
cutlass_output = cutlass_moe_fp4( cutlass_output = cutlass_moe_fp4(
a=a, a=a,
a1_gscale=a1_gs, a1_gscale=a1_gs,
...@@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph( ...@@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph(
w2_fp4=w2_q, w2_fp4=w2_q,
w2_blockscale=w2_blockscale, w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs), w2_alphas=(1 / w2_gs),
ab_strides_13=ab_strides_13,
ab_strides_2=ab_strides_2,
c_strides_13=c_strides_13,
c_strides_2=c_strides_2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
m=m, params=params,
n=n, apply_router_weight_on_input=False,
k=k,
e=e,
device=a.device,
) )
# Reference check: # Reference check:
......
from typing import Optional from typing import Any, Dict, Optional
import torch import torch
...@@ -184,13 +184,9 @@ def cutlass_fp4_group_mm( ...@@ -184,13 +184,9 @@ def cutlass_fp4_group_mm(
a_blockscale, a_blockscale,
b_blockscale, b_blockscale,
alphas, alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
out_dtype, out_dtype,
device, device,
params: Dict[str, Any],
): ):
""" """
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
...@@ -220,10 +216,10 @@ def cutlass_fp4_group_mm( ...@@ -220,10 +216,10 @@ def cutlass_fp4_group_mm(
a_blockscale, a_blockscale,
b_blockscale, b_blockscale,
alphas, alphas,
ab_strides, params["ab_strides"],
c_strides, params["c_strides"],
problem_sizes, params["problem_sizes"],
expert_offsets, params["expert_offsets"],
blockscale_offsets, params["blockscale_offsets"],
) )
return c.to(dtype=out_dtype) return c.to(dtype=out_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment