Unverified Commit 1976356e authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[MoE Refactor] MXFP4 Cutlass Experts to MK (#34542)


Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
parent cbf8f702
...@@ -73,3 +73,29 @@ steps: ...@@ -73,3 +73,29 @@ steps:
num_devices: 2 num_devices: 2
commands: commands:
- pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt - pytest -s -v evals/gsm8k/test_gsm8k_correctness.py --config-list-file=evals/gsm8k/configs/moe-refactor-dp-ep/config-b200.txt
- label: GPQA Eval (GPT-OSS) (H100)
timeout_in_minutes: 120
device: h100
optional: true
num_devices: 2
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
- tests/evals/gpt_oss/
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-h100.txt
- label: GPQA Eval (GPT-OSS) (B200)
timeout_in_minutes: 120
device: b200
optional: true
num_devices: 2
source_file_dependencies:
- csrc/
- vllm/model_executor/layers/quantization
- tests/evals/gpt_oss/
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v evals/gpt_oss/test_gpqa_correctness.py --config-list-file=configs/models-b200.txt
...@@ -153,33 +153,6 @@ steps: ...@@ -153,33 +153,6 @@ steps:
- pytest -v -s transformers_utils - pytest -v -s transformers_utils
- pytest -v -s config - pytest -v -s config
- label: GPT-OSS Eval (H100)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
device: h100
optional: true
source_file_dependencies:
- tests/evals/gpt_oss
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/layers/quantization/mxfp4.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: GPT-OSS Eval (B200)
timeout_in_minutes: 60
working_dir: "/vllm-workspace/"
device: b200
optional: true
source_file_dependencies:
- tests/evals/gpt_oss
- vllm/model_executor/models/gpt_oss.py
- vllm/model_executor/layers/quantization/mxfp4.py
- vllm/v1/attention/backends/flashinfer.py
commands:
- uv pip install --system 'gpt-oss[eval]==0.0.5'
- pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py --model openai/gpt-oss-20b --metric 0.58
- label: Batch Invariance (H100) - label: Batch Invariance (H100)
timeout_in_minutes: 25 timeout_in_minutes: 25
device: h100 device: h100
......
# GPQA Evaluation using GPT-OSS
This directory contains GPQA evaluation tests using the GPT-OSS evaluation package and vLLM server.
## Usage
### Run tests with pytest (like buildkite)
```bash
# H200
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
--config-list-file=configs/models-h200.txt
# B200
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
--config-list-file=configs/models-b200.txt
```
## Configuration Format
Model configs in `configs/` directory use this YAML format:
```yaml
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568 # Minimum expected accuracy
reasoning_effort: "low" # Reasoning effort level (default: "low")
server_args: "--tensor-parallel-size 2" # Server arguments
startup_max_wait_seconds: 1800 # Max wait for server startup (default: 1800)
env: # Environment variables (optional)
SOME_VAR: "value"
```
The `server_args` field accepts any arguments that can be passed to `vllm serve`.
The `env` field accepts a dictionary of environment variables to set for the server process.
## Adding New Models
1. Create a new YAML config file in the `configs/` directory
2. Add the filename to the appropriate `models-*.txt` file
## Tiktoken Encoding Files
The tiktoken encoding files required by the vLLM server are automatically downloaded from OpenAI's public blob storage on first run:
- `cl100k_base.tiktoken`
- `o200k_base.tiktoken`
Files are cached in the `data/` directory. The `TIKTOKEN_ENCODINGS_BASE` environment variable is automatically set to point to this directory when running evaluations.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: "1"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: "1"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_MXFP4_USE_MARLIN: "1"
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
model_name: "openai/gpt-oss-20b"
metric_threshold: 0.568
reasoning_effort: "low"
server_args: "--tensor-parallel-size 2"
env:
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: "1"
# B200 model configurations for GPQA evaluation
# Tests different environment variable combinations
gpt-oss-20b-flashinfer-mxfp4-bf16.yaml
gpt-oss-20b-flashinfer-mxfp4-mxfp8-cutlass.yaml
gpt-oss-20b-sm100-fi-mxfp4-mxfp8-trtllm.yaml
\ No newline at end of file
# H100 model configurations for GPQA evaluation
# Tests different environment variable combinations
gpt-oss-20b-baseline.yaml
gpt-oss-20b-flashinfer-mxfp4-bf16.yaml
gpt-oss-20b-marlin.yaml
...@@ -4,13 +4,61 @@ ...@@ -4,13 +4,61 @@
Pytest configuration for GPT-OSS evaluation tests. Pytest configuration for GPT-OSS evaluation tests.
""" """
from pathlib import Path
def pytest_addoption(parser): def pytest_addoption(parser):
"""Add command line options for pytest.""" """Add custom command line options."""
parser.addoption("--model", action="store", help="Model name to evaluate")
parser.addoption(
"--metric", action="store", type=float, help="Expected metric threshold"
)
parser.addoption( parser.addoption(
"--server-args", action="store", default="", help="Additional server arguments" "--config-list-file",
required=True,
help="File containing list of config files to test",
) )
def pytest_generate_tests(metafunc):
"""Generate test parameters from config files."""
if "config_filename" in metafunc.fixturenames:
config_list_file = metafunc.config.getoption("--config-list-file")
# Handle both relative and absolute paths
config_list_path = Path(config_list_file)
if not config_list_path.is_absolute():
# If relative, try relative to test directory first
test_dir_path = Path(__file__).parent / config_list_file
if test_dir_path.exists():
config_list_path = test_dir_path
else:
# Try relative to current working directory
config_list_path = Path.cwd() / config_list_file
print(f"Looking for config list at: {config_list_path}")
config_files = []
if config_list_path.exists():
# Determine config directory (same directory as the list file)
config_dir = config_list_path.parent
with open(config_list_path) as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
config_path = config_dir / line
print(f"Checking config file: {config_path}")
if config_path.exists():
config_files.append(config_path)
print(f" Found: {config_path}")
else:
print(f" Missing: {config_path}")
else:
print(f"Config list file not found: {config_list_path}")
# Generate test parameters
if config_files:
metafunc.parametrize(
"config_filename",
config_files,
ids=[config_file.stem for config_file in config_files],
)
else:
print("No config files found, test will be skipped")
...@@ -5,22 +5,48 @@ GPQA evaluation using vLLM server and GPT-OSS evaluation package. ...@@ -5,22 +5,48 @@ GPQA evaluation using vLLM server and GPT-OSS evaluation package.
Usage: Usage:
pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \ pytest -s -v tests/evals/gpt_oss/test_gpqa_correctness.py \
--model openai/gpt-oss-20b \ --config-list-file=configs/models-h200.txt
--metric 0.58 \
--server-args "--tensor-parallel-size 2"
""" """
import os
import shlex
import subprocess import subprocess
import sys import sys
import urllib.request
from pathlib import Path
import regex as re import regex as re
import yaml
from tests.utils import RemoteOpenAIServer from tests.utils import RemoteOpenAIServer
TOL = 0.05 # Absolute tolerance for accuracy comparison TOL = 0.05 # Absolute tolerance for accuracy comparison
# Path to tiktoken encoding files
TIKTOKEN_DATA_DIR = Path(__file__).parent / "data"
def run_gpqa_eval(model_name: str, base_url: str) -> float: # Tiktoken encoding files to download
TIKTOKEN_FILES = {
"cl100k_base.tiktoken": "https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
"o200k_base.tiktoken": "https://openaipublic.blob.core.windows.net/encodings/o200k_base.tiktoken",
}
def ensure_tiktoken_files():
"""Download tiktoken encoding files if they don't exist."""
TIKTOKEN_DATA_DIR.mkdir(parents=True, exist_ok=True)
for filename, url in TIKTOKEN_FILES.items():
filepath = TIKTOKEN_DATA_DIR / filename
if not filepath.exists():
print(f"Downloading {filename} from {url}...")
urllib.request.urlretrieve(url, filepath)
print(f" Downloaded to {filepath}")
else:
print(f" {filename} already exists.")
def run_gpqa_eval(model_name: str, base_url: str, reasoning_effort: str) -> float:
"""Run GPQA evaluation using the gpt-oss evaluation package.""" """Run GPQA evaluation using the gpt-oss evaluation package."""
# Build the command to run the evaluation # Build the command to run the evaluation
...@@ -33,7 +59,7 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float: ...@@ -33,7 +59,7 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
"--model", "--model",
model_name, model_name,
"--reasoning-effort", "--reasoning-effort",
"low", reasoning_effort,
"--base-url", "--base-url",
base_url, base_url,
"--n-threads", "--n-threads",
...@@ -41,16 +67,29 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float: ...@@ -41,16 +67,29 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
] ]
try: try:
# Set up environment for the evaluation subprocess
# Inherit current environment and add required variables
eval_env = os.environ.copy()
eval_env["OPENAI_API_KEY"] = "dummy"
# Run the evaluation # Run the evaluation
result = subprocess.run( result = subprocess.run(
cmd, cmd,
text=True, text=True,
capture_output=True, capture_output=True,
timeout=1800, # 30 minute timeout timeout=1800, # 30 minute timeout
env={"OPENAI_API_KEY": "dummy"}, env=eval_env,
) )
print("Evaluation process output:\n", result.stdout) print("Evaluation process stdout:\n", result.stdout)
print("Evaluation process stderr:\n", result.stderr)
print(f"Evaluation process return code: {result.returncode}")
if result.returncode != 0:
raise RuntimeError(
f"Evaluation failed with exit code {result.returncode}:\n"
f"stdout: {result.stdout}\nstderr: {result.stderr}"
)
# Parse the output to extract the score # Parse the output to extract the score
match = re.search(r"'metric':\s*([\d.]+)", result.stdout) match = re.search(r"'metric':\s*([\d.]+)", result.stdout)
...@@ -64,47 +103,62 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float: ...@@ -64,47 +103,62 @@ def run_gpqa_eval(model_name: str, base_url: str) -> float:
except subprocess.TimeoutExpired as e: except subprocess.TimeoutExpired as e:
raise RuntimeError("Evaluation timed out") from e raise RuntimeError("Evaluation timed out") from e
except subprocess.CalledProcessError as e:
raise RuntimeError(
f"Evaluation failed with exit code {e.returncode}:\n"
f"stdout: {e.stdout}\nstderr: {e.stderr}"
) from e
def test_gpqa_correctness(request): def test_gpqa_correctness(config_filename):
"""Test GPQA correctness for GPT-OSS model.""" """Test GPQA correctness for a given model configuration."""
# Ensure tiktoken files are downloaded
ensure_tiktoken_files()
# Verify tiktoken files exist
for filename in TIKTOKEN_FILES:
filepath = TIKTOKEN_DATA_DIR / filename
assert filepath.exists(), f"Tiktoken file not found: {filepath}"
# Get command line arguments eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8"))
model_name = request.config.getoption("--model")
expected_metric = request.config.getoption("--metric")
server_args_str = request.config.getoption("--server-args")
# Parse server arguments # Parse server arguments from config (use shlex to handle quoted strings)
server_args = [] server_args_str = eval_config.get("server_args", "")
if server_args_str: server_args = shlex.split(server_args_str) if server_args_str else []
server_args = server_args_str.split()
# Add standard server arguments # Add standard server arguments
server_args.extend( server_args.extend(
[ [
"--trust-remote-code", "--trust-remote-code",
"--enforce-eager",
"--disable-uvicorn-access-log",
] ]
) )
print(f"Starting GPQA evaluation for model: {model_name}") # Build server environment with tiktoken path and any config-specified vars
print(f"Expected metric threshold: {expected_metric}") server_env = {"TIKTOKEN_ENCODINGS_BASE": str(TIKTOKEN_DATA_DIR)}
if eval_config.get("env"):
server_env.update(eval_config["env"])
reasoning_effort = eval_config.get("reasoning_effort", "low")
print(f"Starting GPQA evaluation for model: {eval_config['model_name']}")
print(f"Expected metric threshold: {eval_config['metric_threshold']}")
print(f"Reasoning effort: {reasoning_effort}")
print(f"Server args: {' '.join(server_args)}") print(f"Server args: {' '.join(server_args)}")
print(f"Server environment variables: {server_env}")
# Launch server and run evaluation # Launch server and run evaluation
with RemoteOpenAIServer( with RemoteOpenAIServer(
model_name, server_args, max_wait_seconds=1800 eval_config["model_name"],
server_args,
env_dict=server_env,
max_wait_seconds=eval_config.get("startup_max_wait_seconds", 1800),
) as remote_server: ) as remote_server:
base_url = remote_server.url_for("v1") base_url = remote_server.url_for("v1")
print(f"Server started at: {base_url}") print(f"Server started at: {base_url}")
measured_metric = run_gpqa_eval(model_name, base_url) measured_metric = run_gpqa_eval(
eval_config["model_name"], base_url, reasoning_effort
)
expected_metric = eval_config["metric_threshold"]
print(f"GPQA Results for {model_name}:") print(f"GPQA Results for {eval_config['model_name']}:")
print(f" Measured metric: {measured_metric:.4f}") print(f" Measured metric: {measured_metric:.4f}")
print(f" Expected metric: {expected_metric:.4f}") print(f" Expected metric: {expected_metric:.4f}")
print(f" Tolerance: {TOL:.4f}") print(f" Tolerance: {TOL:.4f}")
...@@ -115,4 +169,4 @@ def test_gpqa_correctness(request): ...@@ -115,4 +169,4 @@ def test_gpqa_correctness(request):
f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}" f"{expected_metric:.4f} - {TOL:.4f} = {expected_metric - TOL:.4f}"
) )
print(f"GPQA test passed for {model_name}") print(f"GPQA test passed for {eval_config['model_name']}")
...@@ -242,6 +242,10 @@ class FusedMoEQuantConfig: ...@@ -242,6 +242,10 @@ class FusedMoEQuantConfig:
def quant_dtype(self) -> torch.dtype | str | None: def quant_dtype(self) -> torch.dtype | str | None:
return self._a1.dtype return self._a1.dtype
@property
def weight_quant_dtype(self) -> torch.dtype | str | None:
return self._w1.dtype
@property @property
def is_quantized(self) -> bool: def is_quantized(self) -> bool:
return self.quant_dtype is not None return self.quant_dtype is not None
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import torch import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import ( from vllm.model_executor.layers.fused_moe.config import (
...@@ -18,6 +19,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -18,6 +19,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Dynamic128Sym, kFp8Dynamic128Sym,
kFp8Static128BlockSym, kFp8Static128BlockSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kMxfp4Static,
kMxfp8Dynamic,
kNvfp4Dynamic, kNvfp4Dynamic,
kNvfp4Static, kNvfp4Static,
) )
...@@ -64,10 +67,18 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -64,10 +67,18 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
): ):
super().__init__(moe_config, quant_config) super().__init__(moe_config, quant_config)
assert quant_config.quant_dtype in ("nvfp4", torch.float8_e4m3fn, None), (
"Only nvfp4, fp8, bfloat16 and" assert quant_config.weight_quant_dtype in (
"mxfp4",
"nvfp4",
torch.float8_e4m3fn,
None,
), (
"Only mxfp4, nvfp4, fp8, bfloat16 and"
" float16 quantization are currently supported." " float16 quantization are currently supported."
) )
self.device = moe_config.device
self.num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank self.ep_rank = moe_config.moe_parallel_config.ep_rank
self.ep_size = moe_config.moe_parallel_config.ep_size self.ep_size = moe_config.moe_parallel_config.ep_size
self.tp_rank = moe_config.moe_parallel_config.tp_rank self.tp_rank = moe_config.moe_parallel_config.tp_rank
...@@ -78,6 +89,28 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -78,6 +89,28 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# - pass per-block weight scales to the kernel # - pass per-block weight scales to the kernel
# - skip input activation quantization (kernel applies scaling) # - skip input activation quantization (kernel applies scaling)
self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized self.use_deepseek_fp8_block_scale = quant_config.is_block_quantized
self.max_capture_size = (
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)
if quant_config.weight_quant_dtype == "mxfp4":
# This value is used specifically for gpt-oss,
# Need to revisit this for other models
self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
if quant_config.quant_dtype == "mxfp8":
self.fake_input_scale = torch.ones(
self.num_experts,
device=self.device,
dtype=torch.float32,
)
@property @property
def expects_unquantized_inputs(self) -> bool: def expects_unquantized_inputs(self) -> bool:
...@@ -119,20 +152,33 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -119,20 +152,33 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
] ]
and p.has_device_capability(90) and p.has_device_capability(90)
) )
# fp8 block-scale on 9.0 # fp8 block-scale, wmxfp4a16 on 9.0
or ( or (
scheme == (kFp8Static128BlockSym, kFp8Dynamic128Sym) scheme
in [
(kMxfp4Static, None),
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
and p.is_device_capability(90) and p.is_device_capability(90)
) )
# nvfp4 on 10.0+ # nvfp4, wmxfp4amxfp8 on 10.0+
or ( or (
scheme == (kNvfp4Static, kNvfp4Dynamic) and p.has_device_capability(100) scheme
in [
(kMxfp4Static, kMxfp8Dynamic),
(kNvfp4Static, kNvfp4Dynamic),
]
and p.has_device_capability(100)
) )
) )
@staticmethod @staticmethod
def _supports_activation(activation: MoEActivation) -> bool: def _supports_activation(activation: MoEActivation) -> bool:
return activation in [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL] return activation in [
MoEActivation.SILU,
MoEActivation.RELU2_NO_MUL,
MoEActivation.SWIGLUOAI,
]
@staticmethod @staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool: def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
...@@ -216,12 +262,23 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -216,12 +262,23 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
activation_str_to_value_map = { activation_str_to_value_map = {
MoEActivation.SILU: ActivationType.Swiglu, # This is the default MoEActivation.SILU: ActivationType.Swiglu, # This is the default
MoEActivation.SWIGLUOAI: ActivationType.Swiglu, # gpt-oss alias
MoEActivation.RELU2_NO_MUL: ActivationType.Relu2, MoEActivation.RELU2_NO_MUL: ActivationType.Relu2,
} }
assert activation in activation_str_to_value_map, ( assert activation in activation_str_to_value_map, (
f"{activation=} missing from {activation_str_to_value_map.keys()=}" f"{activation=} missing from {activation_str_to_value_map.keys()=}"
) )
quant_scales = None
fc1_expert_weights = None
fc2_expert_weights = None
fc1_expert_biases = None
fc2_expert_biases = None
swiglu_alpha = None
swiglu_beta = None
swiglu_limit = None
use_mxfp8_act_scaling = False
use_w4_group_scaling = False
# Select quantization metadata based on FP8 format/path # Select quantization metadata based on FP8 format/path
if ( if (
self.quant_dtype == torch.float8_e4m3fn self.quant_dtype == torch.float8_e4m3fn
...@@ -256,6 +313,43 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -256,6 +313,43 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
# FlashInfer API requires weight to be long for nvfp4 # FlashInfer API requires weight to be long for nvfp4
fc1_expert_weights = w1.view(torch.long) fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long) fc2_expert_weights = w2.view(torch.long)
elif self.weight_quant_dtype == "mxfp4":
assert self.w1_scale is not None and self.w2_scale is not None
assert w1.is_contiguous() and w2.is_contiguous()
assert self.gemm1_alpha is not None
assert self.gemm1_beta is not None
assert self.gemm1_clamp_limit is not None
assert topk_ids.is_contiguous()
fc1_expert_biases = self.w1_bias
fc2_expert_biases = self.w2_bias
swiglu_alpha = self.gemm1_alpha
swiglu_beta = self.gemm1_beta
swiglu_limit = self.gemm1_clamp_limit
if self.quant_dtype == "mxfp8":
assert self.fake_input_scale is not None
fc1_expert_weights = w1.view(torch.long)
fc2_expert_weights = w2.view(torch.long)
quant_scales = [
self.w1_scale.view(torch.int32),
self.fake_input_scale,
self.w2_scale.view(torch.int32),
self.fake_input_scale,
]
use_mxfp8_act_scaling = True
else:
assert hidden_states.dtype == torch.bfloat16
fc1_expert_weights = w1
fc2_expert_weights = w2
quant_scales = [
self.w1_scale,
self.w2_scale,
]
a1q_scale = None
use_w4_group_scaling = True
elif self.use_deepseek_fp8_block_scale: elif self.use_deepseek_fp8_block_scale:
# FP8 block-scale path: provide block-scale weights, omit a1q_scale # FP8 block-scale path: provide block-scale weights, omit a1q_scale
quant_scales = [ quant_scales = [
...@@ -277,6 +371,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -277,6 +371,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
token_final_scales=topk_weights, token_final_scales=topk_weights,
fc1_expert_weights=fc1_expert_weights, fc1_expert_weights=fc1_expert_weights,
fc2_expert_weights=fc2_expert_weights, fc2_expert_weights=fc2_expert_weights,
fc1_expert_biases=fc1_expert_biases,
fc2_expert_biases=fc2_expert_biases,
swiglu_alpha=swiglu_alpha,
swiglu_beta=swiglu_beta,
swiglu_limit=swiglu_limit,
output=output,
output_dtype=self.out_dtype, output_dtype=self.out_dtype,
quant_scales=quant_scales, quant_scales=quant_scales,
input_sf=a1q_scale, input_sf=a1q_scale,
...@@ -284,10 +384,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -284,10 +384,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
ep_size=self.ep_size, ep_size=self.ep_size,
ep_rank=self.ep_rank, ep_rank=self.ep_rank,
output=output,
activation_type=activation_str_to_value_map[activation], activation_type=activation_str_to_value_map[activation],
# Informs FlashInfer to use the block-scale decoding path when True # Informs FlashInfer to use the block-scale decoding path when True
use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale, use_deepseek_fp8_block_scale=self.use_deepseek_fp8_block_scale,
use_mxfp8_act_scaling=use_mxfp8_act_scaling,
use_w4_group_scaling=use_w4_group_scaling,
tune_max_num_tokens=max(self.max_capture_size, 1),
) )
def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None: def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
......
...@@ -564,9 +564,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -564,9 +564,13 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
# #
@property @property
def quant_dtype(self) -> torch.dtype | None: def quant_dtype(self) -> torch.dtype | str | None:
return self.quant_config.quant_dtype return self.quant_config.quant_dtype
@property
def weight_quant_dtype(self) -> torch.dtype | str | None:
return self.quant_config.weight_quant_dtype
@property @property
def block_shape(self) -> list[int] | None: def block_shape(self) -> list[int] | None:
return self.quant_config.block_shape return self.quant_config.block_shape
......
...@@ -25,15 +25,20 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -25,15 +25,20 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
self, self,
moe_config: FusedMoEConfig, moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
gemm1_alpha,
gemm1_beta,
gemm1_clamp_limit,
max_capture_size, max_capture_size,
): ):
super().__init__(moe_config, quant_config) super().__init__(moe_config, quant_config)
self.gemm1_alpha = gemm1_alpha self.device = torch.cuda.current_device()
self.gemm1_beta = gemm1_beta self.num_experts = moe_config.num_local_experts
self.gemm1_clamp_limit = gemm1_clamp_limit self.gemm1_alpha = torch.tensor(
[1.702] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_beta = torch.tensor(
[1.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.gemm1_clamp_limit = torch.tensor(
[7.0] * self.num_experts, dtype=torch.float32, device=self.device
)
self.max_capture_size = max_capture_size self.max_capture_size = max_capture_size
@staticmethod @staticmethod
......
...@@ -195,11 +195,12 @@ def _mxfp8_e4m3_quantize( ...@@ -195,11 +195,12 @@ def _mxfp8_e4m3_quantize(
A_scale: torch.Tensor | None, A_scale: torch.Tensor | None,
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: list[int] | None = None, block_shape: list[int] | None = None,
is_sf_swizzled_layout: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None assert A_scale is None
assert not per_act_token_quant assert not per_act_token_quant
assert block_shape is None assert block_shape is None
return mxfp8_e4m3_quantize(A) return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
def _mxfp6_e3m2_quantize( def _mxfp6_e3m2_quantize(
...@@ -275,7 +276,13 @@ def moe_kernel_quantize_input( ...@@ -275,7 +276,13 @@ def moe_kernel_quantize_input(
elif quant_dtype == "mxfp8": elif quant_dtype == "mxfp8":
# TODO: `quant_dtype == "mxfp8"` is ambiguous, # TODO: `quant_dtype == "mxfp8"` is ambiguous,
# should be fp8_e4m3. OCP MX also defines `fp8_e5m2`. # should be fp8_e4m3. OCP MX also defines `fp8_e5m2`.
return _mxfp8_e4m3_quantize(A, A_scale, per_act_token_quant, block_shape) return _mxfp8_e4m3_quantize(
A,
A_scale,
per_act_token_quant,
block_shape,
is_sf_swizzled_layout=is_fp4_scale_swizzled,
)
elif quant_dtype == "mxfp6_e3m2": elif quant_dtype == "mxfp6_e3m2":
return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape) return _mxfp6_e3m2_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp6_e2m3": elif quant_dtype == "mxfp6_e2m3":
......
...@@ -256,6 +256,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -256,6 +256,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
"Please check your environment and try again." "Please check your environment and try again."
) )
self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
# Initialized in process_weights_after_loading for CUTLASS/SM90 backends
self.moe_mk: mk.FusedMoEModularKernel | None = None self.moe_mk: mk.FusedMoEModularKernel | None = None
def create_weights( def create_weights(
...@@ -648,19 +649,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -648,19 +649,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
): ):
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_beta = Parameter(
torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
layer.gemm1_clamp_limit = Parameter(
torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
)
sf_block_size = 32 # mxfp4 block size sf_block_size = 32 # mxfp4 block size
# Common shape assertions # Common shape assertions
...@@ -772,6 +760,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -772,6 +760,30 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer.w2_weight_scale = torch.nn.Parameter( layer.w2_weight_scale = torch.nn.Parameter(
w2_scales_interleaved, requires_grad=False w2_scales_interleaved, requires_grad=False
) )
# theses two kernels go through the `flashinfer_cutlass_fused_moe` path
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts,
)
self.moe_quant_config = self.get_fused_moe_quant_config(layer)
assert self.moe_quant_config is not None
prepare_finalize = maybe_make_prepare_finalize(
moe=self.moe,
quant_config=self.moe_quant_config,
routing_tables=layer._maybe_init_expert_routing_tables(),
allow_new_interface=True,
)
assert prepare_finalize is not None
self.moe_mk = mk.FusedMoEModularKernel(
prepare_finalize,
FlashInferExperts(
moe_config=self.moe,
quant_config=self.moe_quant_config,
),
shared_experts=None,
)
elif self.mxfp4_backend == Mxfp4Backend.TRITON: elif self.mxfp4_backend == Mxfp4Backend.TRITON:
from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
...@@ -847,7 +859,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -847,7 +859,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
) )
elif self.mxfp4_backend in [Mxfp4Backend.SM100_FI_MXFP4_BF16]: elif self.mxfp4_backend in [
Mxfp4Backend.SM100_FI_MXFP4_BF16,
Mxfp4Backend.SM90_FI_MXFP4_BF16,
]:
return mxfp4_w4a16_moe_quant_config( return mxfp4_w4a16_moe_quant_config(
w1_bias=layer.w13_bias, w1_bias=layer.w13_bias,
w2_bias=layer.w2_bias, w2_bias=layer.w2_bias,
...@@ -897,9 +912,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -897,9 +912,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
): ):
# B200 code-path # B200 code-path
kwargs = { kwargs = {
"gemm1_alpha": layer.gemm1_alpha,
"gemm1_beta": layer.gemm1_beta,
"gemm1_clamp_limit": layer.gemm1_clamp_limit,
# TODO(bnell): part of quant_config # TODO(bnell): part of quant_config
"max_capture_size": self.max_capture_size, "max_capture_size": self.max_capture_size,
} }
...@@ -935,20 +947,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -935,20 +947,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
if layer.enable_eplb: if layer.enable_eplb:
raise NotImplementedError("EPLB is not supported for mxfp4") raise NotImplementedError("EPLB is not supported for mxfp4")
if self.mxfp4_backend == Mxfp4Backend.MARLIN:
assert self.moe_mk is not None
return self.moe_mk(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=layer.activation,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)
assert _can_support_mxfp4( assert _can_support_mxfp4(
layer.use_grouped_topk, layer.use_grouped_topk,
layer.topk_group, layer.topk_group,
...@@ -967,69 +965,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -967,69 +965,23 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
assert ( assert (
self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16 or self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or self.mxfp4_backend == Mxfp4Backend.MARLIN
) )
from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
# Backend-specific preparation
if self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS:
from flashinfer import mxfp8_quantize
x_quant, x_scale = mxfp8_quantize(x, True, 32)
fake_input_scale = torch.ones(self.num_experts, device=x.device)
quant_scales = [
layer.w13_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
layer.w2_weight_scale.contiguous().view(torch.int32),
fake_input_scale,
]
fi_input = x_quant
extra_kwargs = dict(
use_mxfp8_act_scaling=True,
input_sf=x_scale,
fc1_expert_weights=layer.w13_weight.contiguous().view(torch.long),
fc2_expert_weights=layer.w2_weight.contiguous().view(torch.long),
)
elif self.mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16:
assert x.dtype == torch.bfloat16
quant_scales = [
layer.w13_weight_scale,
layer.w2_weight_scale,
]
fi_input = x
extra_kwargs = dict(
use_w4_group_scaling=True,
fc1_expert_weights=layer.w13_weight,
fc2_expert_weights=layer.w2_weight,
)
output = torch.empty_like(x, dtype=torch.bfloat16) assert self.moe_mk is not None
return self.moe_mk(
flashinfer_cutlass_fused_moe( hidden_states=x,
input=fi_input, w1=layer.w13_weight,
token_selected_experts=topk_ids.to(torch.int).contiguous(), w2=layer.w2_weight,
token_final_scales=topk_weights, topk_weights=topk_weights,
output_dtype=torch.bfloat16, topk_ids=topk_ids,
output=output, activation=layer.activation,
quant_scales=quant_scales, global_num_experts=layer.global_num_experts,
fc1_expert_biases=layer.w13_bias, apply_router_weight_on_input=layer.apply_router_weight_on_input,
fc2_expert_biases=layer.w2_bias, expert_map=layer.expert_map,
swiglu_alpha=layer.gemm1_alpha, shared_experts_input=shared_experts_input,
swiglu_beta=layer.gemm1_beta,
swiglu_limit=layer.gemm1_clamp_limit,
tp_size=self.moe.tp_size,
tp_rank=self.moe.tp_rank,
ep_size=self.moe.ep_size,
ep_rank=self.moe.ep_rank,
tune_max_num_tokens=max(self.max_capture_size, 1),
**extra_kwargs,
) )
return output
def apply_monolithic( def apply_monolithic(
self, self,
layer: FusedMoE, layer: FusedMoE,
......
...@@ -19,6 +19,7 @@ if TYPE_CHECKING: ...@@ -19,6 +19,7 @@ if TYPE_CHECKING:
FP8_DTYPE = current_platform.fp8_dtype() FP8_DTYPE = current_platform.fp8_dtype()
FP4_DTYPE = torch.uint8 FP4_DTYPE = torch.uint8
MXFP_SCALE_DTYPE = torch.uint8
def get_fp8_min_max() -> tuple[float, float]: def get_fp8_min_max() -> tuple[float, float]:
...@@ -151,6 +152,18 @@ kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True ...@@ -151,6 +152,18 @@ kFp8Static128BlockSym = QuantKey(FP8_DTYPE, kStatic128BlockScale, symmetric=True
kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64)) kDynamic64Scale = ScaleDesc(torch.float32, False, GroupShape(1, 64))
kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True) kFp8Dynamic64Sym = QuantKey(FP8_DTYPE, kDynamic64Scale, symmetric=True)
# TODO (zyongye): Convert all the torch.dtype to scale_dtype
# Changing that requires changing torch compile fused AR+Quant Quant key
# to avoid assertion error
kMxfp4DynamicGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, False, GroupShape(1, 32))
kMxfp4Dynamic = QuantKey(FP4_DTYPE, scale=kMxfp4DynamicGroupScale, symmetric=True)
kMxfp8DynamicGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, False, GroupShape(1, 32))
kMxfp8Dynamic = QuantKey(FP8_DTYPE, scale=kMxfp8DynamicGroupScale, symmetric=True)
kMxfp4StaticGroupScale = ScaleDesc(MXFP_SCALE_DTYPE, True, GroupShape(1, 32))
kMxfp4Static = QuantKey(FP4_DTYPE, scale=kMxfp4StaticGroupScale, symmetric=True)
# Normalize the group_shape to the full extent for any dims that are -1 # Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape): def _normalize_quant_group_shape(x: torch.Tensor, group_shape: GroupShape):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment