Unverified Commit 82f39dc1 authored by Shu Wang's avatar Shu Wang Committed by GitHub
Browse files

Add mm_fp4 trtllm backend (#12406)

parent 627bac64
......@@ -66,6 +66,7 @@ SGLang supports various environment variables that can be used to configure its
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
| `SGLANG_ENABLE_FLASHINFER_GEMM` | Use flashinfer kernels when running blockwise fp8 GEMM on Blackwell GPUs | `false` |
| `SGLANG_FLASHINFER_FP4_GEMM_BACKEND` | Select the backend for `mm_fp4` on Blackwell GPUs (empty string lets flashinfer choose) | `` |
| `SGLANG_SUPPORT_CUTLASS_BLOCK_FP8` | Use Cutlass kernels when running blockwise fp8 GEMM on Hopper or Blackwell GPUs | `false` |
| `SGLANG_CUTLASS_MOE` (deprecated) | Use Cutlass FP8 MoE kernel on Blackwell GPUs (deprecated, use --moe-runner-backend=cutlass) | `false` |
......
......@@ -198,6 +198,8 @@ class Envs:
# Flashinfer
SGLANG_IS_FLASHINFER_AVAILABLE = EnvBool(True)
SGLANG_ENABLE_FLASHINFER_GEMM = EnvBool(False)
    # Empty string by default, which lets flashinfer pick the backend
SGLANG_FLASHINFER_FP4_GEMM_BACKEND = EnvStr("")
SGLANG_FLASHINFER_WORKSPACE_SIZE = EnvInt(384 * 1024 * 1024)
# Triton
......
......@@ -11,6 +11,7 @@ from sglang.srt.distributed import get_tp_group
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.environ import envs
from sglang.srt.layers.dp_attention import (
get_dp_global_num_tokens,
get_local_dp_buffer,
......@@ -94,14 +95,12 @@ logger = logging.getLogger(__name__)
CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
"SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
)
USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
"SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM", "true"
)
# TODO make it true by default when the DeepEP PR is merged
CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
"SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
)
FLASHINFER_FP4_GEMM_BACKEND = envs.SGLANG_FLASHINFER_FP4_GEMM_BACKEND.get()
# Supported activation schemes for the current configuration
ACTIVATION_SCHEMES = ["static"]
......@@ -1006,7 +1005,26 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
layer.input_scale_inv = Parameter(
(1 / input_scale_2).to(torch.float32), requires_grad=False
)
if FLASHINFER_FP4_GEMM_BACKEND == "trtllm":
# FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
# FlashInfer provides nvfp4_quantize to quantize + shuffle the
# layout but we use our own quantization so we have to call
# shuffles ourselves.
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
weight = layer.weight
scale = layer.weight_scale
epilogue_tile_m = 128
weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
scale = (
shuffle_matrix_sf_a(scale.view(torch.uint8), epilogue_tile_m)
.reshape(scale.shape)
.view(torch.float8_e4m3fn)
)
layer.weight_scale_interleaved = Parameter(scale, requires_grad=False)
layer.weight = Parameter(weight, requires_grad=False)
return
# Pad and blockwise interleave weight_scale
scales = layer.weight_scale
scale_ndim = scales.ndim
......@@ -1056,6 +1074,11 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
if enable_flashinfer_fp4_gemm:
w = layer.weight.T
w_scale_interleaved = layer.weight_scale_interleaved.T
# TODO(shuw@nvidia.com)
# Remove the default after flashinfer bumped to 0.5.1
backend = (
FLASHINFER_FP4_GEMM_BACKEND if FLASHINFER_FP4_GEMM_BACKEND else "cutlass"
)
out = fp4_gemm(
x_fp4,
w,
......@@ -1063,7 +1086,7 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
w_scale_interleaved,
layer.alpha,
output_dtype,
**(dict(backend="cutlass") if USE_CUTLASS_BACKEND_FOR_FP4_GEMM else dict()),
**(dict(backend=backend)),
)
if bias is not None:
out = out + bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment