"tests/vscode:/vscode.git/clone" did not exist on "8b8291830e6ca1f5882700e214f114d5442a04db"
Unverified commit b5245064, authored by Xiaoyu Zhang, committed by GitHub

[code style] restructure fused_moe to avoid a very long single file (#9878)

parent 9d9fa9a5
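
This commit splits the previously monolithic fused_moe module into smaller files. The diff below updates the package-level imports (the module defining `_config`, presumably fused_moe_triton/__init__.py), and is followed by the contents of the two new modules, fused_moe_triton_config.py and moe_align_block_size.py.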
 from contextlib import contextmanager
 from typing import Any, Dict, Optional

-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_experts,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_file_name,
-    moe_align_block_size,
     try_get_optimal_moe_config,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
+    moe_align_block_size,
+)

 _config: Optional[Dict[str, Any]] = None
......
New file: fused_moe_triton_config.py (module sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config)

from __future__ import annotations

import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple

import torch
import triton

from sglang.srt.utils import get_device_name, is_hip

logger = logging.getLogger(__name__)

_is_hip = is_hip()
def get_config_file_name(
    E: int, N: int, dtype: Optional[str], block_shape: Optional[List[int]] = None
) -> str:
    device_name = get_device_name().replace(" ", "_")
    dtype_selector = "" if not dtype else f",dtype={dtype}"
    block_shape_selector = (
        "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
    )
    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json"
@functools.lru_cache
def get_moe_configs(
    E: int,
    N: int,
    dtype: Optional[str],
    block_n: Optional[int] = 0,
    block_k: Optional[int] = 0,
) -> Optional[Dict[int, Any]]:
    """
    Return optimized configurations for the fused MoE kernel.

    The return value will be a dictionary that maps an irregular grid of
    batch sizes to configurations of the fused_moe kernel. To evaluate the
    kernel on a given batch size bs, the closest batch size in the grid should
    be picked and the associated configuration chosen to invoke the kernel.
    """
    # Supported Triton versions, sorted from newest to oldest
    supported_triton_versions = ["3.3.1", "3.2.0", "3.1.0"]

    # First, look up whether an optimized configuration is available in the
    # configs directory
    json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k])

    # We found that using the fused_moe_kernel config from Triton 3.1.0 with
    # Triton 3.2.0 results in negative performance gains, so we also include
    # the Triton version as a key when looking up the fused_moe_kernel config,
    # to achieve the best performance.
    triton_version = triton.__version__
    version_dir = f"triton_{triton_version.replace('.', '_')}"
    config_file_path = os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "configs",
        version_dir,
        json_file_name,
    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
            # Please note that even when a config file is found, performance might
            # still be suboptimal, because the tuning environment might differ
            # from your current environment. For example, updating the Triton
            # version might cause all old configs to become suboptimal. To achieve
            # the best performance, consider re-tuning the Triton fused MoE kernel
            # in your environment. For the tuning method, refer to:
            # https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
            logger.info(f"Using MoE kernel config from {config_file_path}.")
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}

    # Search other Triton versions that support the same config
    for try_triton_version in supported_triton_versions:
        if try_triton_version == triton_version:
            continue
        try_config_file_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            "configs",
            f"triton_{try_triton_version.replace('.', '_')}",
            json_file_name,
        )
        if os.path.exists(try_config_file_path):
            with open(try_config_file_path) as f:
                logger.warning(
                    f"Config file not found at {config_file_path}. Falling back to "
                    f"Triton version {try_triton_version} and using the MoE kernel "
                    f"config from {try_config_file_path}. Performance might be sub-optimal!",
                )
                # If a configuration has been found, return it
                return {int(key): val for key, val in json.load(f).items()}

    # If no optimized configuration is available, use the default configuration
    logger.warning(
        (
            "Using default MoE kernel config. Performance might be sub-optimal! "
            "Config file not found at %s, you can create them with "
            "https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton"
        ),
        config_file_path,
    )
    return None
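
# Illustrative (assumed) shape of a tuned config file: top-level keys are batch
# sizes M as strings, values are Triton kernel launch parameters. The exact
# values below are hypothetical; see get_default_config for the parameter names
# the kernel actually consumes.
# {
#     "1":  {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 64,
#            "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 3},
#     "64": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
#            "GROUP_SIZE_M": 8,  "num_warps": 8, "num_stages": 4},
# }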
def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
    block_shape: Optional[List[int]] = None,
) -> Dict[str, int]:
    if dtype == "fp8_w8a8":
        if block_shape is None:
            config = {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 32,
                "num_warps": 8,
                "num_stages": 2 if _is_hip else 4,
            }
            if M <= E:
                config = {
                    "BLOCK_SIZE_M": 64,
                    "BLOCK_SIZE_N": 128,
                    "BLOCK_SIZE_K": 128,
                    "GROUP_SIZE_M": 1,
                    "num_warps": 4,
                    "num_stages": 2 if _is_hip else 4,
                }
        else:
            # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
            config = {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": block_shape[0],
                "BLOCK_SIZE_K": block_shape[1],
                "GROUP_SIZE_M": 32,
                "num_warps": 4,
                "num_stages": 2 if _is_hip else 3,
            }
    else:
        config = {
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 8,
        }
        # A heuristic: fused marlin works faster with this config for small M
        if M <= E or (is_marlin and M <= 32):
            config = {
                "BLOCK_SIZE_M": 16,
                "BLOCK_SIZE_N": 32,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 1,
            }
    return config
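
# Worked example with hypothetical inputs: get_default_config(M=1, E=8,
# N=14336, K=4096, topk=2, dtype=None, is_marlin=False) takes the non-fp8
# branch and, since M <= E, returns the small-batch config
# {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1}.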
def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
    block_shape: Optional[List[int]] = None,
):
    from sglang.srt.layers.moe.fused_moe_triton import get_config

    override_config = get_config()
    if override_config:
        config = override_config
    else:
        # First, try to load the optimal config from a file
        E, _, N = w2_shape
        block_n = block_shape[0] if block_shape else 0
        block_k = block_shape[1] if block_shape else 0
        configs = get_moe_configs(E, N, dtype, block_n, block_k)
        if configs:
            # If an optimal configuration map has been found, look up the
            # optimal config
            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
        else:
            # Otherwise, use the default config
            config = get_default_config(
                M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
            )
    return config
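
# Nearest-batch-size lookup, illustrated with hypothetical keys: if the tuned
# config map has keys {1, 8, 64, 128} and the runtime batch size is M=100,
# min(configs.keys(), key=lambda x: abs(x - 100)) picks 128, and that entry's
# kernel parameters are used to launch the kernel.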
def get_config_dtype_str(
    dtype: torch.dtype,
    use_int8_w8a16: Optional[bool] = False,
    use_int4_w4a16: Optional[bool] = False,
    use_fp8_w8a8: Optional[bool] = False,
    use_int8_w8a8: Optional[bool] = False,
):
    if use_fp8_w8a8:
        return "fp8_w8a8"
    elif use_int8_w8a8:
        return "int8_w8a8"
    elif use_int4_w4a16:
        return "int4_w4a16"
    elif use_int8_w8a16:
        return "int8_w8a16"
    elif dtype == torch.float:
        # Avoid cases where the kernel fails when a float32 MoE
        # uses fp16/bfloat16 configs
        return "float32"
    return None
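
# Illustrative mapping (assumed typical call): with dtype=torch.bfloat16 and
# use_fp8_w8a8=True, this returns "fp8_w8a8", which selects the fp8 branch in
# get_default_config and the ",dtype=fp8_w8a8" selector in the config file name.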
New file: moe_align_block_size.py (module sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size)

from __future__ import annotations

from typing import Tuple

import torch
import triton

from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda or _is_hip:
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Aligns the token distribution across experts to be compatible with the
    block size used for matrix multiplication.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] representing the
      top-k expert indices for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: A tensor containing the token indices sorted
      by their allocated expert.
    - expert_ids: A tensor indicating the assigned expert index for each block.
    - num_tokens_post_padded: The total number of tokens after padding,
      ensuring divisibility by block_size.

    This function pads the number of tokens that each expert needs to process
    so that it is divisible by block_size. Padding ensures that during block
    matrix multiplication, the dimensions align correctly.

    Example:
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
      with each expert needing to process 3 tokens.
    - As block_size is 4, we pad 1 token for each expert.
    - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
    - Then append padding tokens [12, 12, 12, 12] for each block.
    - After sorting by expert index, we obtain token_ids
      [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
      Tokens with id 12 are non-existent (padding) and are ignored in
      the subsequent matrix multiplication.
    - The padding ensures that the total number of tokens is now divisible
      by block_size for proper block matrix operations.
    """
    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=topk_ids.device)

    # In EP, expert_ids for filtered experts are -1. We have num_experts + 1 ids in total.
    cumsum_buffer = torch.empty(
        (num_experts + 2,), dtype=torch.int32, device=topk_ids.device
    )

    # Threshold based on benchmark results
    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
    if not fuse_sorted_ids_padding:
        sorted_ids.fill_(topk_ids.numel())

    sgl_moe_align_block_size(
        topk_ids,
        num_experts + 1,
        block_size,
        sorted_ids,
        expert_ids,
        num_tokens_post_pad,
        cumsum_buffer,
        fuse_sorted_ids_padding,
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
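
# Minimal usage sketch (illustrative; mirrors the docstring example and assumes
# a CUDA or ROCm device with sgl_kernel installed):
#
# import torch
#
# topk_ids = torch.tensor(
#     [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
#     dtype=torch.int32,
#     device="cuda",
# )
# sorted_ids, expert_ids, num_post_pad = moe_align_block_size(
#     topk_ids, block_size=4, num_experts=4
# )
# # sorted_ids orders the 12 flattened token slots by expert, padded with the
# # value topk_ids.numel() (= 12) so each expert's count is a multiple of 4.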