Unverified commit 6fc93575 authored by Elfie Guo, committed by GitHub

[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)

parent 839fb31e
"""Cutlass MoE kernel."""
import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
import sgl_kernel
from sgl_kernel import (
fp8_blockwise_scaled_grouped_mm,
prepare_moe_input,
silu_and_mul,
)
def cutlass_fused_experts(
a: torch.Tensor,
w1_q: torch.Tensor,
w2_q: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
a1_strides: torch.Tensor,
c1_strides: torch.Tensor,
a2_strides: torch.Tensor,
c2_strides: torch.Tensor,
workspace: torch.Tensor,
a_ptrs: torch.Tensor,
b_ptrs: torch.Tensor,
out_ptrs: torch.Tensor,
a_scales_ptrs: torch.Tensor,
b_scales_ptrs: torch.Tensor,
expert_offsets: torch.Tensor,
problem_sizes1: torch.Tensor,
problem_sizes2: torch.Tensor,
use_fp8_blockscale: bool = True,
) -> torch.Tensor:
"""Performs Fused MoE computation using CUTLASS-like kernels with FP8 weights and activations.
This function implements a Mixture of Experts (MoE) layer with a SwiGLU/SiLU
activation, leveraging custom kernels likely derived from CUTLASS principles
for grouped matrix multiplication (`fp8_blockwise_scaled_grouped_mm`) and
data preparation (`prepare_moe_input`, `silu_and_mul`).
It handles per-token routing, quantizes input activations to FP8 with
per-token scales, performs the expert computations using FP8 GEMMs with
pre-quantized FP8 weights (per-block scales), applies the SiLU activation,
and combines the results weighted by the router scores.
Args:
a (torch.Tensor): Input activations. Shape: `(m, k)`, where `m` is the total
number of tokens and `k` is the hidden size. Expected dtype: `torch.half`
or `torch.bfloat16`.
w1_q (torch.Tensor): Pre-quantized FP8 weight tensor for the first GEMM
(up-projection part of SwiGLU). Expected shape: `(E, k, n*2)`, where
`E` is the number of experts, `k` is the hidden size, and `n*2` is the
intermediate size (`I`). Expected dtype: `torch.float8_e4m3fn`.
Note: This shape implies weights are stored as (num_experts, hidden_size, intermediate_size).
w2_q (torch.Tensor): Pre-quantized FP8 weight tensor for the second GEMM
(down-projection). Expected shape: `(E, n, k)`, where `n` is half the
intermediate size (`I // 2`). Expected dtype: `torch.float8_e4m3fn`.
Note: This shape implies weights are stored as (num_experts, intermediate_size // 2, hidden_size).
w1_scale (torch.Tensor): Scales corresponding to `w1_q` (per-block scales).
Shape: `(E, num_blocks_n, num_blocks_k)`. Dtype: `torch.float32`.
w2_scale (torch.Tensor): Scales corresponding to `w2_q` (per-block scales).
Shape: `(E, num_blocks_k, num_blocks_n)`. Dtype: `torch.float32`.
topk_weights (torch.Tensor): Router weights for the selected top-k experts
for each token. Shape: `(m, topk)`. Dtype should ideally match `a`.
topk_ids (torch.Tensor): Indices of the selected top-k experts for each token.
Shape: `(m, topk)`. Dtype: `torch.int32`.
a1_strides (torch.Tensor): Stride information for the first GEMM's 'a' input.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
as it's passed as both a_stride and b_stride in the first call.
c1_strides (torch.Tensor): Stride information for the first GEMM's 'c' output.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
a2_strides (torch.Tensor): Stride information for the second GEMM's 'a' input.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
Note: Its exact usage within `fp8_blockwise_scaled_grouped_mm` needs clarification
as it's passed as both a_stride and b_stride in the second call.
c2_strides (torch.Tensor): Stride information for the second GEMM's 'c' output.
Passed directly to the underlying kernel. Expected shape `(E,)`, dtype `torch.int64`.
workspace (torch.Tensor): Reusable workspace for the underlying kernel.
a_ptrs (torch.Tensor): Pointers container for calculating offsets of the input activations for each expert.
b_ptrs (torch.Tensor): Pointers container for calculating offsets of the input weights for each expert.
out_ptrs (torch.Tensor): Pointers container for calculating offsets of the output activations for each expert.
a_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
b_scales_ptrs (torch.Tensor): Pointers container for calculating offsets of the input scales for each expert.
use_fp8_blockscale (bool, optional): Flag indicating usage of FP8 with
block scaling. Currently, only `True` is supported. Defaults to `True`.
Returns:
torch.Tensor: The computed MoE layer output. Shape: `(m, k)`, dtype matches `a`.
Raises:
AssertionError: If input shapes, dtypes, or flags are inconsistent or unsupported.
NotImplementedError: If CUDA is not available or `sgl_kernel` is not properly installed.
"""
assert use_fp8_blockscale, "Only support fp8 blockscale for now"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert w1_q.dtype == torch.float8_e4m3fn
assert w2_q.dtype == torch.float8_e4m3fn
assert a.shape[1] == w1_q.shape[1], "Hidden size mismatch w1"
assert w1_q.shape[2] == w2_q.shape[1] * 2, "Hidden size mismatch w2"
assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
assert w1_q.shape[0] == w2_q.shape[0], "Weights expert number mismatch"
assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
assert a.dtype in [torch.half, torch.bfloat16], "Invalid output dtype"
if is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
)
out_dtype = a.dtype
num_experts = w1_q.size(0)
m = a.size(0)
k = w1_q.size(1)
n = w2_q.size(1)
topk = topk_ids.size(1)
a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
device = a_q.device
a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
prepare_moe_input(
topk_ids,
expert_offsets,
problem_sizes1,
problem_sizes2,
a_map,
c_map,
num_experts,
n,
k,
)
rep_a_q = a_q.view(dtype=torch.uint8)[a_map].view(dtype=a_q.dtype)
rep_a1_scales = a1_scale[a_map]
c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)
a_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
w_sf_layout = torch.empty((num_experts, 5), device=device, dtype=torch.int)
fp8_blockwise_scaled_grouped_mm(
c1,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
rep_a_q,
w1_q,
rep_a1_scales,
w1_scale,
a1_strides,
a1_strides,
c1_strides,
a_sf_layout,
w_sf_layout,
problem_sizes1,
expert_offsets[:-1],
workspace,
)
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
silu_and_mul(c1, intermediate)
intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
fp8_blockwise_scaled_grouped_mm(
c2,
a_ptrs,
b_ptrs,
out_ptrs,
a_scales_ptrs,
b_scales_ptrs,
intemediate_q,
w2_q,
a2_scale,
w2_scale,
a2_strides,
a2_strides,
c2_strides,
a_sf_layout,
w_sf_layout,
problem_sizes2,
expert_offsets[:-1],
workspace,
)
return (
c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)
).sum(dim=1)
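For orientation, here is a minimal sketch of the auxiliary buffers this wrapper expects, mirroring the allocations made in `Fp8MoEMethod.create_weights` and in the benchmark further below. The concrete sizes (`E`, `H`, `I`) and the 90000-byte workspace are illustrative assumptions, not values mandated by the kernel.

# Sketch only: per-expert metadata buffers for cutlass_fused_experts, assuming
# E experts, hidden size H, and a gate+up intermediate width I (so n = I // 2).
import torch

E, H, I = 8, 7168, 4096  # hypothetical sizes for illustration
device = torch.device("cuda")

ab_strides1 = torch.full((E,), H, dtype=torch.int64, device=device)
c_strides1 = torch.full((E,), I, dtype=torch.int64, device=device)
ab_strides2 = torch.full((E,), I // 2, dtype=torch.int64, device=device)
c_strides2 = torch.full((E,), H, dtype=torch.int64, device=device)

workspace = torch.empty(90000, dtype=torch.uint8, device=device)  # size is an assumption
a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs = (
    torch.empty(E, dtype=torch.int64, device=device) for _ in range(5)
)

expert_offsets = torch.empty(E + 1, dtype=torch.int32, device=device)
problem_sizes1 = torch.empty(E, 3, dtype=torch.int32, device=device)
problem_sizes2 = torch.empty(E, 3, dtype=torch.int32, device=device)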
@@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
+    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -470,6 +471,7 @@ class Fp8MoEMethod:
     def __init__(self, quant_config):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
 
     def create_weights(
         self,
@@ -568,6 +570,63 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
+            if (
+                get_bool_env_var("CUTLASS_MOE")
+                and self.cutlass_fp8_supported
+                and is_sm100_supported()
+            ):
+                self.ab_strides1 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides1 = torch.full(
+                    (num_experts,),
+                    2 * intermediate_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.ab_strides2 = torch.full(
+                    (num_experts,),
+                    intermediate_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides2 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.workspace = torch.empty(
+                    90000, device=w13_weight.device, dtype=torch.uint8
+                )
+                self.a_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.out_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.a_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.expert_offsets = torch.empty(
+                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes1 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes2 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
         else:
             # Allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.
@@ -913,6 +972,37 @@ class Fp8MoEMethod:
         if ret is not None:
             return ret
 
+        if (
+            get_bool_env_var("CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and is_sm100_supported()
+        ):
+            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+
+            return cutlass_fused_experts(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv.transpose(1, 2),
+                layer.w2_weight_scale_inv.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                self.workspace,
+                self.a_ptr,
+                self.b_ptr,
+                self.out_ptr,
+                self.a_scales_ptr,
+                self.b_scales_ptr,
+                self.expert_offsets,
+                self.problem_sizes1,
+                self.problem_sizes2,
+                use_fp8_blockscale=True,
+            )
+
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
...
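As a usage note, this path is opt-in: besides the hardware checks, it only runs when the `CUTLASS_MOE` environment variable is truthy (read via `get_bool_env_var` in the diff above). A minimal sketch of enabling it from Python before the model is built:

# Sketch: opt into the CUTLASS blockwise-FP8 MoE path; the variable must be set
# before Fp8MoEMethod.create_weights runs (e.g. before launching the server).
import os

os.environ["CUTLASS_MOE"] = "1"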
@@ -80,6 +80,12 @@ def cutlass_fp8_supported():
     return False
 
 
+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
...
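One caveat with `is_sm100_supported` above: `torch.version.cuda >= "12.8"` is a lexicographic string comparison, which happens to hold for current releases but would misorder versions such as "12.10". A hedged alternative using numeric parsing (assuming the `packaging` package is available) could look like this:

# Sketch only: numeric CUDA version comparison instead of string comparison.
import torch
from packaging import version  # assumption: packaging is installed


def is_sm100_supported_strict(device=None) -> bool:
    return (
        torch.cuda.get_device_capability(device)[0] == 10
        and version.parse(torch.version.cuda) >= version.parse("12.8")
    )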
import argparse
import time

import torch
import triton  # Added import
import triton.testing  # Added import
from transformers import AutoConfig

from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts


def get_model_config(tp_size: int):
    config = AutoConfig.from_pretrained(
        "deepseek-ai/deepseek-R1", trust_remote_code=True
    )
    E = config.n_routed_experts
    topk = config.num_experts_per_tok
    intermediate_size = config.moe_intermediate_size
    shard_intermediate_size = 2 * intermediate_size // tp_size
    return {
        "num_experts": E,
        "topk": topk,
        "hidden_size": config.hidden_size,
        "shard_intermediate_size": shard_intermediate_size,
        "dtype": config.torch_dtype,
        "block_shape": config.quantization_config["weight_block_size"],
    }


def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
    """Converts tensor to FP8 E4M3, scaling values to fit the range."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Calculate max absolute value safely
    max_val = torch.max(torch.abs(tensor))
    # Avoid division by zero if tensor is all zeros
    if max_val == 0:
        scale_factor = 1.0
    else:
        # Scale factor to bring the max value to finfo.max
        scale_factor = finfo.max / max_val
    # Apply scaling
    scaled_tensor = tensor * scale_factor
    # Clamp and convert
    fp8_tensor = scaled_tensor.clamp(min=finfo.min, max=finfo.max).to(
        dtype=torch.float8_e4m3fn
    )
    return fp8_tensor
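Note that `to_fp8` deliberately drops the scale factor, so the quantized values cannot be mapped back to the original range; that is acceptable here because the per-block weight scales are set to 1 below. A variant that also returns a per-tensor dequantization scale might look like the following (illustrative only, not part of the commit):

# Sketch: quantize to FP8 E4M3 and keep the per-tensor dequantization scale.
from typing import Tuple

import torch


def to_fp8_with_scale(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    finfo = torch.finfo(torch.float8_e4m3fn)
    amax = torch.abs(tensor).max().clamp(min=1e-12)  # avoid division by zero
    scale = amax / finfo.max  # dequantization scale: fp8.float() * scale ~ original
    fp8 = (tensor / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return fp8, scale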
def run_test(tp_size, batch_size, model_config, check=False):
    print(f"\n--- Batch Size: {batch_size} ---")
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(42)  # For reproducible random numbers

    E = model_config["num_experts"]
    topk = model_config["topk"]
    H = model_config["hidden_size"]
    I = model_config["shard_intermediate_size"]
    block_shape = model_config["block_shape"]  # Tuple (BLOCK_N, BLOCK_K)
    dtype = model_config["dtype"]  # e.g., torch.bfloat16

    print(
        f"Config: E={E}, topk={topk}, H={H}, I_shard={I}, dtype={dtype}, block_shape={block_shape}"
    )

    # --- Input Data ---
    # Use bf16/fp16 for input activation based on model config
    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001

    # --- Weights (Generate in higher precision, then convert to FP8) ---
    # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
    w1_hp = (
        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
    )
    w2_hp = (
        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
        + 0.00001
    )
    w1 = to_fp8(w1_hp)
    w2 = to_fp8(w2_hp)

    # --- Scales for FP8 Weights ---
    block_n, block_k = block_shape
    # Calculate number of blocks needed
    w1_blocks_dim1 = (I + block_n - 1) // block_n
    w1_blocks_dim2 = (H + block_k - 1) // block_k
    w2_blocks_dim1 = (H + block_n - 1) // block_n
    w2_blocks_dim2 = (I // 2 + block_k - 1) // block_k

    # Scales are typically float32 or float16/bfloat16
    scale_dtype = torch.float32  # Or dtype if scales match model dtype
    w1_scale = torch.full(
        (E, w1_blocks_dim1, w1_blocks_dim2), 1, device="cuda", dtype=scale_dtype
    )  # Avoid zero scales
    w2_scale = torch.full(
        (E, w2_blocks_dim1, w2_blocks_dim2), 1, device="cuda", dtype=scale_dtype
    )  # Avoid zero scales

    # --- Routing Information ---
    topk_weights = torch.softmax(
        torch.rand(batch_size, topk, device="cuda", dtype=dtype), dim=-1
    )
    topk_ids = torch.randint(0, E, (batch_size, topk), dtype=torch.int32, device="cuda")

    a1_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
    c1_strides = torch.full((E,), I, dtype=torch.int64, device="cuda")
    a2_strides = torch.full((E,), I // 2, dtype=torch.int64, device="cuda")
    c2_strides = torch.full((E,), H, dtype=torch.int64, device="cuda")
    workspace = torch.empty(
        (7182 * 1024), device="cuda", dtype=torch.uint8
    )  # Allocate sufficient workspace

    # Pointer arrays (often filled by the kernel or a prep step, but needed as args)
    a_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    b_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    out_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    a_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    b_scales_ptrs = torch.empty((E,), dtype=torch.int64, device="cuda")
    expert_offsets = torch.empty((E + 1,), dtype=torch.int32, device="cuda")
    problem_sizes1 = torch.empty((E, 3), dtype=torch.int32, device="cuda")
    problem_sizes2 = torch.empty((E, 3), dtype=torch.int32, device="cuda")

    # --- Lambdas for Benchmarking ---
    cutlass_lambda = lambda: cutlass_fused_experts(
        x,
        w1.transpose(1, 2),  # Transposed
        w2.transpose(1, 2),  # Transposed
        w1_scale.transpose(1, 2),
        w2_scale.transpose(1, 2),
        topk_weights,
        topk_ids,
        a1_strides,
        c1_strides,
        a2_strides,
        c2_strides,
        workspace,
        a_ptrs,
        b_ptrs,
        out_ptrs,
        a_scales_ptrs,
        b_scales_ptrs,
        expert_offsets,
        problem_sizes1,
        problem_sizes2,
    )

    # Note: Triton expects non-transposed weights
    triton_lambda = lambda: fused_experts(
        x,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,  # Use False for benchmarking to avoid side effects if run multiple times
        activation="silu",  # Assuming SiLU activation common in MoEs
        use_fp8_w8a8=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        block_shape=block_shape,
    )

    # --- Warmup ---
    print("Warming up...")
    for _ in range(10):
        _ = cutlass_lambda()
        _ = triton_lambda()
    torch.cuda.synchronize()

    # --- Benchmarking ---
    quantiles = [0.5, 0.2, 0.8]
    print(f"Benchmarking Cutlass fused_experts...")
    cutlass_ms, cutlass_min, cutlass_max = triton.testing.do_bench_cudagraph(
        cutlass_lambda, rep=1000, quantiles=quantiles
    )
    print(f"Benchmarking Triton fused_experts...")
    triton_ms, triton_min, triton_max = triton.testing.do_bench_cudagraph(
        triton_lambda, rep=1000, quantiles=quantiles
    )

    print(
        f"Cutlass fused_experts time: {cutlass_ms:.3f} ms (median) [{cutlass_min:.3f} - {cutlass_max:.3f}]"
    )
    print(
        f"Triton fused_experts time: {triton_ms:.3f} ms (median) [{triton_min:.3f} - {triton_max:.3f}]"
    )

    # --- Correctness Check ---
    if check:
        print("Running correctness check...")
        with torch.no_grad():
            # Run CUTLASS version (requires transposed weights)
            y_cutlass = cutlass_fused_experts(
                x,
                w1.transpose(1, 2),  # Transposed
                w2.transpose(1, 2),  # Transposed
                w1_scale.transpose(1, 2),
                w2_scale.transpose(1, 2),
                topk_weights,
                topk_ids,
                a1_strides,
                c1_strides,
                a2_strides,
                c2_strides,
                workspace,
                a_ptrs,
                b_ptrs,
                out_ptrs,
                a_scales_ptrs,
                b_scales_ptrs,
                expert_offsets,
                problem_sizes1,
                problem_sizes2,
            )

            # Run Triton version (requires original shape weights, use inplace=False)
            y_triton = fused_experts(
                x,
                w1,  # Original shape
                w2,  # Original shape
                topk_weights,
                topk_ids,
                inplace=False,  # Important: Use False to get output tensor
                activation="silu",
                use_fp8_w8a8=True,
                w1_scale=w1_scale,
                w2_scale=w2_scale,
                block_shape=block_shape,
            )

            # Ensure outputs are same dtype for comparison
            y_cutlass = y_cutlass.to(dtype)
            y_triton = y_triton.to(dtype)

            abs_error = torch.abs(y_cutlass - y_triton)
            rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)

            max_abs_err = abs_error.max().item()
            max_rel_err = rel_error.max().item()

            print("y_cutlass:", y_cutlass[:, :10])
            print("y_triton:", y_triton[:, :10])
            print(f"Max absolute error: {max_abs_err:.6f}")
            print(f"Max relative error: {max_rel_err:.6f}")

            # Tolerance might need adjustment based on FP8 specifics and kernel differences
            # FP8 comparisons often require higher tolerance than FP16/BF16
            assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
            print("Correctness check passed.")


def main(tp_size=8, batch_sizes=[1, 4, 8, 16, 32, 64, 128, 256, 512], check=False):
    model_config = get_model_config(tp_size)
    print("Model Config:", model_config)

    for batch_size in batch_sizes:
        run_test(tp_size, batch_size, model_config, check)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tp-size", type=int, default=8, help="Tensor Parallel size")
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 4, 8, 16, 32, 64, 128, 256, 512],  # Adjusted default
        help="List of batch sizes to test",
    )
    parser.add_argument("--check", action="store_true", help="Enable check mode")
    args = parser.parse_args()

    print(f"Running benchmarks with TP size: {args.tp_size}")
    print(f"Testing batch sizes: {args.batch_sizes}")

    main(tp_size=args.tp_size, batch_sizes=args.batch_sizes, check=args.check)
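Besides the CLI, the benchmark's helpers can be driven programmatically, for example for a quick correctness-only pass on a few batch sizes (a sketch reusing the functions defined above):

# Sketch: short sweep with the correctness check enabled, using this script's own helpers.
cfg = get_model_config(tp_size=8)  # pulls E / topk / hidden_size / block_shape from the HF config
for bs in (1, 64, 256):
    run_test(tp_size=8, batch_size=bs, model_config=cfg, check=True)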
@@ -207,6 +207,7 @@ set(SOURCES
   "csrc/moe/moe_fused_gate.cu"
   "csrc/moe/moe_topk_softmax_kernels.cu"
   "csrc/moe/fp8_blockwise_moe_kernel.cu"
+  "csrc/moe/prepare_moe_input.cu"
   "csrc/speculative/eagle_utils.cu"
   "csrc/speculative/speculative_sampling.cu"
   "csrc/speculative/packbit.cu"
...
@@ -151,11 +151,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "(Tensor[])");
   m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
   m.def(
-      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
+      "fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
+      "a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
       "stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
-      "expert_offsets) -> ()");
+      "expert_offsets, Tensor workspace) -> ()");
   m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
+  m.def(
+      "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
+      "input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
+  m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
 
   /*
    * From csrc/speculative
    */
...
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <cutlass/arch/arch.h>
 #include <torch/all.h>
@@ -49,23 +51,16 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
   using ElementC = OutType;
   using ElementD = ElementC;
   using ElementAccumulator = float;
 
-  // Layout definitions
   using LayoutA = cutlass::layout::RowMajor;
   using LayoutB = cutlass::layout::ColumnMajor;
   using LayoutC = LayoutD;
 
-  // Alignment constraints
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
 
-  // Architecture definitions
   using ArchTag = cutlass::arch::Sm100;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
 
-  // For fp8 block scale.
-  // using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<ScaleGranularityM, ScaleGranularityN,
-  //     ScaleGranularityK, cute::UMMA::Major::K, cute::UMMA::Major::K>;
-  // using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
-  // using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
       OperatorClass,
@@ -124,9 +119,8 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
   cutlass::KernelHardwareInfo hw_info;
   hw_info.device_id = 0;
-  hw_info.sm_count = 1;
-  // Currently, we are only able to do broadcast on either all or none a_scales
-  // and on either all or none b_scales
+  // sm_count is the number of SMs on the current device; since we only support SM100 (Blackwell), we set it to 148.
+  hw_info.sm_count = 148;
 
   typename GemmKernel::EpilogueArguments epilogue_args{
       {},
       nullptr,
@@ -134,9 +128,7 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
       static_cast<ElementD**>(out_ptrs.data_ptr()),
       static_cast<StrideC*>(stride_c.data_ptr())};
 
-  // Initialize problem_sizes_as_shapes correctly
   UnderlyingProblemShape* problem_sizes_as_shapes = static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
 
-  // Use prob_shape in the GEMM arguments
   typename GemmKernel::Arguments args{
       cutlass::gemm::GemmUniversalMode::kGrouped,
       {num_experts, problem_sizes_as_shapes, nullptr},
@@ -144,21 +136,27 @@ void launch_sm100_fp8_blockwise_scaled_group_mm(
       epilogue_args,
       hw_info};
 
+  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
+
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
 
-  // Run the GEMM
-  auto status = gemm_op.initialize(args, workspace.data_ptr());
+  auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
 
-  status = gemm_op.run();
+  status = gemm_op.run(stream);
   TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
 }
 template <typename OutType>
 void sm100_fp8_blockwise_group_mm_dispatch_shape(
     torch::Tensor& output,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& scales_a,
@@ -169,11 +167,23 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
   // Check the first matrix size to decide on the configuration
   // Assuming all matrices in the group have similar size characteristics
   // bool use_small_config = a[0].size(0) <= 128;
-  struct MMALargeConfig {
+  struct MmaConfig1 {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _32, _128>;
+    using ClusterShape = Shape<_1, _1, _1>;  // Layout type for SFB matrix operand
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+    using ScaleConfig =
+        cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
+  struct MmaConfig2 {
     using ElementA = cutlass::float_e4m3_t;
     using MmaTileShape = Shape<_128, _128, _128>;
     using ClusterShape = Shape<_1, _1, _1>;  // Layout type for SFB matrix operand
@@ -184,35 +194,28 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
-  struct MMASmallConfig {
+  struct MmaConfig3 {
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_128, _16, _128>;
+    using MmaTileShape = Shape<_64, _128, _128>;
    using ClusterShape = Shape<_1, _1, _1>;  // Layout type for SFB matrix operand
     using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
     using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
     using ScaleConfig =
-        cutlass::detail::Sm100BlockwiseScaleConfig<128, 1, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
+        cutlass::detail::Sm100BlockwiseScaleConfig<1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
-  torch::Tensor a_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor b_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor out_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int);
-  torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int);
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
-  torch::Tensor workspace = torch::empty(100, options_int);
 
   torch::Tensor output_t = output.t();
   torch::Tensor a_t = a.t();
   torch::Tensor b_t = b.transpose(1, 2);
   torch::Tensor scales_a_t = scales_a.t();
   torch::Tensor scales_b_t = scales_b.transpose(1, 2);
 
-  if (a.size(0) <= 512) {
-    run_get_group_gemm_starts<MMASmallConfig::LayoutSFA, MMASmallConfig::LayoutSFB, MMASmallConfig::ScaleConfig>(
+  if (a.size(0) <= 512 && a.size(1) >= 2048) {
+    run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
         expert_offsets,
         a_ptrs,
         b_ptrs,
@@ -229,7 +232,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
         problem_sizes,
         problem_sizes_transpose,
         true);
-    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMASmallConfig, cutlass::layout::ColumnMajor>(
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::ColumnMajor>(
         out_ptrs,
         a_ptrs,
         b_ptrs,
@@ -244,8 +247,39 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
         expert_offsets,
         workspace);
     output = output_t.t();
+  } else if (a.size(0) > 512 && a.size(1) >= 2048) {
+    run_get_group_gemm_starts<MmaConfig2::LayoutSFA, MmaConfig2::LayoutSFB, MmaConfig2::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig2, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
   } else {
-    run_get_group_gemm_starts<MMALargeConfig::LayoutSFA, MMALargeConfig::LayoutSFB, MMALargeConfig::ScaleConfig>(
+    run_get_group_gemm_starts<MmaConfig3::LayoutSFA, MmaConfig3::LayoutSFB, MmaConfig3::ScaleConfig>(
        expert_offsets,
         a_ptrs,
         b_ptrs,
@@ -261,7 +295,7 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
         layout_sfb,
         problem_sizes,
         problem_sizes_transpose);
-    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MMALargeConfig, cutlass::layout::RowMajor>(
+    launch_sm100_fp8_blockwise_scaled_group_mm<OutType, MmaConfig3, cutlass::layout::RowMajor>(
         out_ptrs,
         a_ptrs,
         b_ptrs,
@@ -312,6 +346,11 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
  */
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& scales_a,
@@ -322,7 +361,8 @@ void fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets) {
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace) {
   TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
   TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
   TORCH_CHECK(
@@ -342,6 +382,29 @@ void fp8_blockwise_scaled_grouped_mm(
   TORCH_CHECK(layout_sfb.scalar_type() == torch::kInt32, "layout_sfb must be int32");
   TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32, "expert_offsets must be int32");
 
+  TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
+  TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
+  TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
+  TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
+  TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
+  TORCH_CHECK(stride_a.dim() == 1, "stride_a must be 1D tensor");
+  TORCH_CHECK(stride_b.dim() == 1, "stride_b must be 1D tensor");
+  TORCH_CHECK(stride_c.dim() == 1, "stride_c must be 1D tensor");
+  TORCH_CHECK(layout_sfa.dim() == 2, "layout_sfa must be 2D tensor");
+  TORCH_CHECK(layout_sfb.dim() == 2, "layout_sfb must be 2D tensor");
+  TORCH_CHECK(a_ptrs.dim() == 1, "a_ptrs must be 1D tensor");
+  TORCH_CHECK(b_ptrs.dim() == 1, "b_ptrs must be 1D tensor");
+  TORCH_CHECK(out_ptrs.dim() == 1, "out_ptrs must be 1D tensor");
+  TORCH_CHECK(a_scales_ptrs.dim() == 1, "a_scales_ptrs must be 1D tensor");
+  TORCH_CHECK(b_scales_ptrs.dim() == 1, "b_scales_ptrs must be 1D tensor");
+  TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
+  TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
+  TORCH_CHECK(
+      problem_sizes.size(0) == expert_offsets.size(0), "Number of experts in problem_sizes must match expert_offsets");
+  TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, "problem_sizes must be int32");
+  TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
+  TORCH_CHECK(workspace.dim() == 1, "workspace must be 1D tensor");
+
   bool can_implement = false;
   auto sm_version = getSMVersion();
@@ -351,6 +414,11 @@ void fp8_blockwise_scaled_grouped_mm(
     if (output.scalar_type() == torch::kBFloat16) {
       sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
           output,
+          a_ptrs,
+          b_ptrs,
+          out_ptrs,
+          a_scales_ptrs,
+          b_scales_ptrs,
           a,
           b,
           scales_a,
@@ -361,10 +429,16 @@ void fp8_blockwise_scaled_grouped_mm(
           layout_sfa,
           layout_sfb,
           problem_sizes,
-          expert_offsets);
+          expert_offsets,
+          workspace);
     } else {
       sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::half_t>(
           output,
+          a_ptrs,
+          b_ptrs,
+          out_ptrs,
+          a_scales_ptrs,
+          b_scales_ptrs,
           a,
           b,
           scales_a,
@@ -375,7 +449,8 @@ void fp8_blockwise_scaled_grouped_mm(
           layout_sfa,
           layout_sfb,
           problem_sizes,
-          expert_offsets);
+          expert_offsets,
+          workspace);
     }
     can_implement = true;
   }
...
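To summarize the dispatch logic in the hunk above: the kernel now selects one of three MMA configurations from the shape of the flattened activation `a` (rows, hidden). An illustrative Python restatement of that branch structure (names mirror the structs in the diff; this is not executable kernel code):

# Illustrative summary of the SM100 grouped-GEMM config dispatch.
def pick_mma_config(rows: int, hidden: int) -> str:
    if rows <= 512 and hidden >= 2048:
        return "MmaConfig1: tile 128x32x128, column-major output"
    elif rows > 512 and hidden >= 2048:
        return "MmaConfig2: tile 128x128x128, row-major output"
    else:
        return "MmaConfig3: tile 64x128x128, row-major output"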
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>

#include <iostream>

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int topk_length,
    const int n,
    const int k) {
  int expert_id = blockIdx.x;

  // Each block counts how many top-k slots were routed to its expert.
  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    // First grouped GEMM is (tokens_e, 2n, k); second is (tokens_e, k, n).
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = 2 * n;
    problem_sizes1[expert_id * 3 + 2] = k;
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = k;
    problem_sizes2[expert_id * 3 + 2] = n;
  }
}

__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int num_experts) {
  // Exclusive prefix sum over the per-expert token counts (single thread).
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}

__global__ void compute_arg_sorts(
    const int* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int topk_length,
    const int topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    if (topk_ids[i] == expert_id) {
      // input_permutation gathers token rows into expert-contiguous order;
      // output_permutation scatters the results back to (token, topk) order.
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}

void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());

  compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      n,
      k);
  compute_expert_offsets<<<1, 1, 0, stream>>>(
      static_cast<const int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(expert_offsets.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      num_experts);
  compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      topk_ids.size(1));
}

void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32);
  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
  return;
}
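For intuition about what these kernels produce, here is a rough CPU/PyTorch reference of the same bookkeeping (per-expert counts, offsets, grouped-GEMM problem sizes, and the two permutations). It is an illustrative re-implementation for reading purposes, not part of the commit, and it ignores the non-deterministic intra-expert ordering of the atomic version:

# Illustrative reference of prepare_moe_input's outputs (CPU, not the CUDA kernels).
import torch


def prepare_moe_input_reference(topk_ids: torch.Tensor, num_experts: int, n: int, k: int):
    flat = topk_ids.reshape(-1).to(torch.int64)  # (m * topk,)
    counts = torch.bincount(flat, minlength=num_experts)

    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
    expert_offsets[1:] = torch.cumsum(counts, dim=0)

    # Per-expert GEMM shapes: first GEMM (tokens_e, 2n, k), second GEMM (tokens_e, k, n).
    problem_sizes1 = torch.stack(
        [counts, torch.full_like(counts, 2 * n), torch.full_like(counts, k)], dim=1
    ).to(torch.int32)
    problem_sizes2 = torch.stack(
        [counts, torch.full_like(counts, k), torch.full_like(counts, n)], dim=1
    ).to(torch.int32)

    # Sort top-k slots by expert id: input_permutation gathers token rows into
    # expert-contiguous order, output_permutation scatters results back.
    order = torch.argsort(flat, stable=True)
    topk = topk_ids.size(1)
    input_permutation = (order // topk).to(torch.int32)
    output_permutation = torch.empty_like(order, dtype=torch.int32)
    output_permutation[order] = torch.arange(order.numel(), dtype=torch.int32)
    return expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation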
@@ -211,6 +211,11 @@ std::vector<at::Tensor> moe_fused_gate(
 void fp8_blockwise_scaled_grouped_mm(
     torch::Tensor& output,
+    torch::Tensor& a_ptrs,
+    torch::Tensor& b_ptrs,
+    torch::Tensor& out_ptrs,
+    torch::Tensor& a_scales_ptrs,
+    torch::Tensor& b_scales_ptrs,
     const torch::Tensor& a,
     const torch::Tensor& b,
     const torch::Tensor& scales_a,
@@ -221,7 +226,19 @@ void fp8_blockwise_scaled_grouped_mm(
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
     const torch::Tensor& problem_sizes,
-    const torch::Tensor& expert_offsets);
+    const torch::Tensor& expert_offsets,
+    const torch::Tensor& workspace);
+
+void prepare_moe_input(
+    const torch::Tensor& topk_ids,
+    torch::Tensor& expert_offsets,
+    torch::Tensor& problem_sizes1,
+    torch::Tensor& problem_sizes2,
+    torch::Tensor& input_permutation,
+    torch::Tensor& output_permutation,
+    const int64_t num_experts,
+    const int64_t n,
+    const int64_t k);
 
 /*
  * From csrc/speculative
...
@@ -47,6 +47,7 @@ from sgl_kernel.moe import (
     fp8_blockwise_scaled_grouped_mm,
     moe_align_block_size,
     moe_fused_gate,
+    prepare_moe_input,
     topk_softmax,
 )
 from sgl_kernel.sampling import (
...
@@ -64,6 +64,11 @@ def moe_fused_gate(
 def fp8_blockwise_scaled_grouped_mm(
     output,
+    a_ptrs,
+    b_ptrs,
+    out_ptrs,
+    a_scales_ptrs,
+    b_scales_ptrs,
     a,
     b,
     scales_a,
@@ -75,9 +80,15 @@ def fp8_blockwise_scaled_grouped_mm(
     layout_sfb,
     problem_sizes,
     expert_offsets,
+    workspace,
 ):
     torch.ops.sgl_kernel.fp8_blockwise_scaled_grouped_mm.default(
         output,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
         a,
         b,
         scales_a,
@@ -89,4 +100,29 @@ def fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         problem_sizes,
         expert_offsets,
+        workspace,
     )
+
+
+def prepare_moe_input(
+    topk_ids,
+    expert_offsets,
+    problem_sizes1,
+    problem_sizes2,
+    input_permutation,
+    output_permutation,
+    num_experts,
+    n,
+    k,
+):
+    torch.ops.sgl_kernel.prepare_moe_input.default(
+        topk_ids,
+        expert_offsets,
+        problem_sizes1,
+        problem_sizes2,
+        input_permutation,
+        output_permutation,
+        num_experts,
+        n,
+        k,
+    )
@@ -131,9 +131,20 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     c_strides = torch.full(
         (num_experts,), c_out.stride(0), device=device, dtype=torch.int64
     )
+    workspace = torch.empty((1024 * 1024 * 1024), device=device, dtype=torch.uint8)
+    a_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    b_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    out_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    a_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
+    b_scales_ptrs = torch.empty((num_experts,), device=device, dtype=torch.int64)
     fp8_blockwise_scaled_grouped_mm(
         c_out,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
         a_stack,
         b_stack,
         a_scale_stack,
@@ -145,6 +156,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
         layout_sfb,
         problem_sizes,
         expert_offsets[:-1],
+        workspace,
     )
 
     for g in range(num_experts):
...