Unverified Commit 0df6765c authored by Pavani Majety's avatar Pavani Majety Committed by GitHub
Browse files

[CUTLASS-FP4-MOE] Introduce CutlassMoEParams class for easy initialization of...


[CUTLASS-FP4-MOE]  Introduce CutlassMoEParams class for easy initialization of Cutlass Grouped GEMMs Metadata (#6887)
Signed-off-by: default avatarPavani Majety <pmajety@nvidia.com>
parent 35b65cf0
......@@ -8,6 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
......@@ -18,11 +19,12 @@ if _is_cuda:
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
scaled_fp4_experts_quant,
shuffle_rows,
silu_and_mul,
)
def cutlass_fused_experts(
def cutlass_fused_experts_fp8(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
......@@ -223,17 +225,10 @@ def cutlass_moe_fp4(
w2_fp4: torch.Tensor,
w2_blockscale: torch.Tensor,
w2_alphas: torch.Tensor,
ab_strides_13: torch.Tensor,
ab_strides_2: torch.Tensor,
c_strides_13: torch.Tensor,
c_strides_2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
m: int,
n: int,
k: int,
e: int,
device: torch.device,
params: CutlassMoEParams,
apply_router_weight_on_input: bool = False,
):
"""
MoE implementation for FP4 Inputs
......@@ -291,77 +286,70 @@ def cutlass_moe_fp4(
e_w1, nx2_w1, half_k_w1 = w1_fp4.shape
e_w2, k_w2, half_n_w2 = w2_fp4.shape
assert e_w1 == e_w2 and e_w1 == e, (
assert e_w1 == e_w2 and e_w1 == params.num_experts, (
"Number of experts must match",
" between weights.",
)
assert (
k_a // 2 == half_k_w1 and k == k_w2
k_a // 2 == half_k_w1 and params.hidden_size == k_w2
), "Hidden size mismatch between a, w1 and w2"
assert nx2_w1 == n * 2 and half_n_w2 == n // 2, "mismatch in " "expected `n`"
assert m == m_a, "input shape mismatch"
assert (
nx2_w1 == params.intermediate_size_per_partition * 2
and half_n_w2 == params.intermediate_size_per_partition // 2
), ("mismatch in " "expected `n`")
assert 2 * half_k_w1 == k_w2, "Hidden size mismatch w2 and w1"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (
topk_weights.shape[0] == m and topk_ids.shape[0] == m
), "topk must be provided for each row of a"
out_dtype = a.dtype
num_topk = topk_ids.shape[1]
expert_offsets = torch.empty((e + 1), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,2n,k))
problem_sizes1 = torch.empty((e, 3), dtype=torch.int32, device=device)
# Problem size: (num_experts, (m,n,k))
problem_sizes2 = torch.empty((e, 3), dtype=torch.int32, device=device)
device = a.device
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
# problem shapes should have [m, n, k]
# Note that problem sizes are based on logical number of elements.
blockscale_offsets = torch.empty(e + 1, dtype=torch.int32, device=device)
prepare_moe_input(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
params.expert_offsets,
params.problem_sizes1,
params.problem_sizes2,
a_map,
c_map,
e,
n,
k,
blockscale_offsets,
params.num_experts,
params.intermediate_size_per_partition,
params.hidden_size,
params.blockscale_offsets,
)
rep_a_fp4, rep_a_blockscale = scaled_fp4_experts_quant(
a, a1_gscale, expert_offsets, blockscale_offsets, num_topk, expert_map=a_map
a,
a1_gscale,
params.expert_offsets,
params.blockscale_offsets,
num_topk,
expert_map=a_map,
)
c1 = cutlass_fp4_group_mm(
rep_a_fp4,
w1_fp4,
rep_a_blockscale,
w1_blockscale,
w1_alphas,
ab_strides_13,
c_strides_13,
problem_sizes1,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype,
device,
params.to_gemm1_args(),
)
del rep_a_fp4, rep_a_blockscale
# hidden size dimension is split to one half-sized tensor.
intermediate = torch.empty(
(m * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
(m_a * num_topk, w1_fp4.shape[1] // 2), device=device, dtype=out_dtype
)
silu_and_mul(c1, intermediate)
int_fp4, int_blockscale = scaled_fp4_experts_quant(
intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk
intermediate,
a2_gscale,
params.expert_offsets,
params.blockscale_offsets,
num_topk,
)
c2 = cutlass_fp4_group_mm(
int_fp4,
......@@ -369,16 +357,13 @@ def cutlass_moe_fp4(
int_blockscale,
w2_blockscale,
w2_alphas,
ab_strides_2,
c_strides_2,
problem_sizes2,
expert_offsets[:-1],
blockscale_offsets[:-1],
out_dtype,
device,
params.to_gemm2_args(),
)
del int_fp4, int_blockscale
out = (
c2[c_map].view(m, num_topk, k) * topk_weights.view(m, num_topk, 1).half()
).sum(dim=1)
return out.to(dtype=out_dtype)
c2 = shuffle_rows(c2, c_map, (m_a * num_topk, params.hidden_size))
c2 = c2.view(m_a, num_topk, params.hidden_size)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m_a, num_topk, 1).to(out_dtype)
return c2.sum(dim=1).to(out_dtype)
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch
class CutlassMoEType(Enum):
    """
    Enum for the different types of cutlass moe operations
    that are currently supported in SGLang.
    """

    BlockscaledFP8 = auto()
    BlockscaledFP4 = auto()


@dataclass
class CutlassMoEParams:
    """
    Parameters for the cutlass moe operation.

    Holds the per-expert stride tensors, pointer containers, expert/blockscale
    offsets, and problem-size buffers needed by the Cutlass grouped-GEMM MoE
    kernels, so callers can pass one object instead of a dozen tensors.
    """

    # Type as defined above
    cutlass_moe_type: CutlassMoEType

    # Strides for activations, weights and output in logical number of elements.
    # The activations & output stride is the number of elements to the next row.
    # The weights stride is the number of elements to the next row per expert.
    # For example, if the weight is [e, n, k], then the b_stride is a tensor of
    # shape [e] with each element being k. Similarly for activations, if the
    # shape is [m, k], then the a_stride has shape [e] with each value k.
    # Similarly for output, if the output is [m, n], then the c_stride is a
    # tensor of shape [e] with each element being n (the row length of the
    # output).
    # Note: cutlass_fp4_group_mm is designed to accept the strides of
    # activations and weights to be the same, so it is passed in as a single
    # tensor.
    # ab_strides_13: [e] dtype: int64 [Gemm 1: Activation / Weight strides]
    # ab_strides_2: [e] dtype: int64 [Gemm 2: Activation / Weight strides]
    # c_strides_13: [e] dtype: int64 [Gemm 1: Output Strides]
    # c_strides_2: [e] dtype: int64 [Gemm 2: Output Strides]
    ab_strides_13: torch.Tensor
    ab_strides_2: torch.Tensor
    c_strides_13: torch.Tensor
    c_strides_2: torch.Tensor

    # m: Total number of tokens
    # n: intermediate size per partition
    # k: hidden size per expert
    # e: Number of experts
    # device: Device to run computation on and store tensors
    # NOTE(review): the custom __init__ below never assigns self.m, so reading
    # params.m raises AttributeError — presumably callers set it per batch;
    # confirm before relying on it.
    m: int
    intermediate_size_per_partition: int
    hidden_size: int
    num_experts: int
    device: torch.device

    # Pointers container for calculating offsets of the input activations for each expert
    # a_ptrs: [e] dtype: int64
    a_ptrs: torch.Tensor
    # Pointers container for calculating offsets of the input weights for each expert
    # b_ptrs: [e] dtype: int64
    b_ptrs: torch.Tensor
    # Pointers container for calculating offsets of the output activations for each expert
    # out_ptrs: [e] dtype: int64
    out_ptrs: torch.Tensor
    # Pointers container for calculating offsets of the input scales for each expert
    # a_scales_ptrs: [e] dtype: int64
    # b_scales_ptrs: [e] dtype: int64
    a_scales_ptrs: torch.Tensor
    b_scales_ptrs: torch.Tensor

    # Offsets that mark at which token index each expert begins its computation
    # The number of tokens computed with expert E is expert_offsets[E + 1] - expert_offsets[E]
    # expert_offsets: [e+1] dtype: int32
    expert_offsets: torch.Tensor

    # Problem size: (num_experts, (m,2n,k)) for first GEMM
    # problem_sizes1: [e, 3] dtype: int32
    # Problem size: (num_experts, (m,n,k)) for second GEMM
    # problem_sizes2: [e, 3] dtype: int32
    problem_sizes1: torch.Tensor
    problem_sizes2: torch.Tensor

    # Similar to expert_offsets, but for blockscales for FP4 blockscaled Group GEMM.
    # Only allocated for CutlassMoEType.BlockscaledFP4; stays None otherwise.
    blockscale_offsets: Optional[torch.Tensor] = None

    def __init__(
        self,
        cutlass_moe_type: CutlassMoEType,
        device: torch.device,
        num_experts: int,
        intermediate_size_per_partition: int,
        hidden_size: int,
    ):
        """Allocate all metadata tensors on `device` for `num_experts` experts.

        Args:
            cutlass_moe_type: Which grouped-GEMM flavor this metadata serves.
            device: Device on which every tensor below is allocated.
            num_experts: Number of experts (e).
            intermediate_size_per_partition: Per-partition intermediate size (n).
            hidden_size: Hidden size per expert (k).
        """
        self.cutlass_moe_type = cutlass_moe_type
        self.device = device
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size_per_partition
        self.hidden_size = hidden_size
        # Short aliases used for the allocations below.
        self.n = self.intermediate_size_per_partition
        self.k = self.hidden_size
        self.e = self.num_experts

        # GEMM1 reads rows of length k and writes rows of length 2n
        # (gate+up projections fused); GEMM2 reads n and writes k.
        self.ab_strides_13 = torch.full(
            (self.e,), self.k, dtype=torch.int64, device=self.device
        )
        self.ab_strides_2 = torch.full(
            (self.e,), self.n, dtype=torch.int64, device=self.device
        )
        self.c_strides_13 = torch.full(
            (self.e,), 2 * self.n, dtype=torch.int64, device=self.device
        )
        self.c_strides_2 = torch.full(
            (self.e,), self.k, dtype=torch.int64, device=self.device
        )

        # Filled in later (e.g. by prepare_moe_input); allocated once here.
        self.expert_offsets = torch.empty(
            (self.e + 1,), dtype=torch.int32, device=self.device
        )
        self.problem_sizes1 = torch.empty(
            (self.e, 3), dtype=torch.int32, device=self.device
        )
        self.problem_sizes2 = torch.empty(
            (self.e, 3), dtype=torch.int32, device=self.device
        )

        # Blockscale offsets only exist for the FP4 blockscaled path.
        if self.cutlass_moe_type == CutlassMoEType.BlockscaledFP4:
            self.blockscale_offsets = torch.empty(
                (self.e + 1,), dtype=torch.int32, device=self.device
            )
        else:
            self.blockscale_offsets = None

        self.a_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
        self.b_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
        self.out_ptrs = torch.empty((self.e,), dtype=torch.int64, device=self.device)
        self.a_scales_ptrs = torch.empty(
            (self.e,), dtype=torch.int64, device=self.device
        )
        self.b_scales_ptrs = torch.empty(
            (self.e,), dtype=torch.int64, device=self.device
        )

    def _gemm_args(
        self,
        ab_strides: torch.Tensor,
        c_strides: torch.Tensor,
        problem_sizes: torch.Tensor,
    ) -> dict:
        """Assemble the kwargs dict shared by both GEMM stages.

        Guards the blockscale_offsets slice: it is None for the
        BlockscaledFP8 type, and `None[:-1]` would raise TypeError.
        """
        return {
            "ab_strides": ab_strides,
            "c_strides": c_strides,
            "problem_sizes": problem_sizes,
            "expert_offsets": self.expert_offsets[:-1],
            "blockscale_offsets": (
                self.blockscale_offsets[:-1]
                if self.blockscale_offsets is not None
                else None
            ),
        }

    def to_gemm1_args(self) -> dict:
        """Return the metadata kwargs for the first grouped GEMM (m, 2n, k)."""
        return self._gemm_args(
            self.ab_strides_13, self.c_strides_13, self.problem_sizes1
        )

    def to_gemm2_args(self) -> dict:
        """Return the metadata kwargs for the second grouped GEMM (m, n, k)."""
        return self._gemm_args(
            self.ab_strides_2, self.c_strides_2, self.problem_sizes2
        )
......@@ -982,9 +982,9 @@ class Fp8MoEMethod:
and self.block_quant
and is_sm100_supported()
):
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
return cutlass_fused_experts(
return cutlass_fused_experts_fp8(
x,
layer.w13_weight.transpose(1, 2),
layer.w2_weight.transpose(1, 2),
......
......@@ -6,7 +6,7 @@ import triton # Added import
import triton.testing # Added import
from transformers import AutoConfig
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
......@@ -125,7 +125,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
# --- Lambdas for Benchmarking ---
cutlass_lambda = lambda: cutlass_fused_experts(
cutlass_lambda = lambda: cutlass_fused_experts_fp8(
x,
w1.transpose(1, 2), # Transposed
w2.transpose(1, 2), # Transposed
......@@ -193,7 +193,7 @@ def run_test(tp_size, batch_size, model_config, check=False):
print("Running correctness check...")
with torch.no_grad():
# Run CUTLASS version (requires transposed weights)
y_cutlass = cutlass_fused_experts(
y_cutlass = cutlass_fused_experts_fp8(
x,
w1.transpose(1, 2), # Transposed
w2.transpose(1, 2), # Transposed
......
......@@ -5,6 +5,7 @@ from sgl_kernel import scaled_fp4_quant
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.topk import select_experts
if torch.cuda.get_device_capability() < (10, 0):
......@@ -179,6 +180,13 @@ def test_cutlass_fp4_moe_no_graph(
(e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device
)
c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
params = CutlassMoEParams(
CutlassMoEType.BlockscaledFP4,
device=a.device,
num_experts=e,
intermediate_size_per_partition=n, # n
hidden_size=k,
) # k
cutlass_output = cutlass_moe_fp4(
a=a,
a1_gscale=a1_gs,
......@@ -189,17 +197,10 @@ def test_cutlass_fp4_moe_no_graph(
w2_fp4=w2_q,
w2_blockscale=w2_blockscale,
w2_alphas=(1 / w2_gs),
ab_strides_13=ab_strides_13,
ab_strides_2=ab_strides_2,
c_strides_13=c_strides_13,
c_strides_2=c_strides_2,
topk_weights=topk_weights,
topk_ids=topk_ids,
m=m,
n=n,
k=k,
e=e,
device=a.device,
params=params,
apply_router_weight_on_input=False,
)
# Reference check:
......
from typing import Optional
from typing import Any, Dict, Optional
import torch
......@@ -184,13 +184,9 @@ def cutlass_fp4_group_mm(
a_blockscale,
b_blockscale,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
out_dtype,
device,
params: Dict[str, Any],
):
"""
An FP4 Blockscaled Group Gemm that takes in a_tensors, b_tensors and runs
......@@ -220,10 +216,10 @@ def cutlass_fp4_group_mm(
a_blockscale,
b_blockscale,
alphas,
ab_strides,
c_strides,
problem_sizes,
expert_offsets,
blockscale_offsets,
params["ab_strides"],
params["c_strides"],
params["problem_sizes"],
params["expert_offsets"],
params["blockscale_offsets"],
)
return c.to(dtype=out_dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment