"docs/vscode:/vscode.git/clone" did not exist on "c891330f7966fae4aec27962e243a1432d0dc6b5"
Unverified commit d2b8c412 authored by Yongfei Xu, committed by GitHub

Opt fused triton moe: add tma for down proj kernel (#10567)


Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
parent bf8f7a94
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5,
"USE_TMA": false
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"USE_TMA": false
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"USE_TMA": false
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5,
"USE_TMA": false
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4,
"USE_TMA": false
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
}
}
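
The second tuning table above mirrors the first but adds a per-batch-size USE_TMA flag; it is the separate "_down" configuration that the down-projection GEMM reads (see the config-lookup changes later in this commit). As a rough illustration of how such a table is consumed, the entry whose key is nearest to the current token count M is selected, the same nearest-key lookup that try_get_optimal_moe_config performs in the diff below. The helper name here is hypothetical and only sketches that lookup:

import json

def pick_down_config(json_text: str, M: int) -> dict:
    # Parse the tuning table and pick the entry whose batch-size key is closest to M.
    configs = {int(k): v for k, v in json.loads(json_text).items()}
    best = min(configs.keys(), key=lambda bs: abs(bs - M))
    # Return a copy so callers can pop USE_TMA without mutating the cached table.
    return dict(configs[best])

# Example: M = 300 resolves to the "256" entry above, which enables USE_TMA for the down proj.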
@@ -23,7 +23,11 @@ from sglang.srt.utils import (
 )

 from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
-from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
+from .fused_moe_triton_kernels import (
+    invoke_fused_moe_kernel,
+    moe_sum_reduce_triton,
+    support_tensor_descriptor,
+)
 from .moe_align_block_size import moe_align_block_size

 if TYPE_CHECKING:
@@ -78,6 +82,7 @@ def inplace_fused_experts(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> None:
     fused_experts_impl(
         hidden_states,
@@ -106,6 +111,7 @@ def inplace_fused_experts(
         routed_scaling_factor,
         gemm1_alpha,
         gemm1_limit,
+        filter_expert,
     )
@@ -134,6 +140,7 @@ def inplace_fused_experts_fake(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> None:
     pass
@@ -172,6 +179,7 @@ def outplace_fused_experts(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> torch.Tensor:
     return fused_experts_impl(
         hidden_states,
@@ -200,6 +208,7 @@ def outplace_fused_experts(
         routed_scaling_factor=routed_scaling_factor,
         gemm1_alpha=gemm1_alpha,
         gemm1_limit=gemm1_limit,
+        filter_expert=filter_expert,
     )
@@ -229,6 +238,7 @@ def outplace_fused_experts_fake(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)
@@ -263,6 +273,10 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
+    filter_expert = (
+        moe_runner_config.num_experts is None
+        or moe_runner_config.num_experts != moe_runner_config.num_local_experts
+    )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
         torch.ops.sglang.inplace_fused_experts(
@@ -290,6 +304,7 @@ def fused_experts(
             moe_runner_config.routed_scaling_factor,
             moe_runner_config.gemm1_alpha,
             moe_runner_config.gemm1_clamp_limit,
+            filter_expert,
         )
         return hidden_states
     else:
@@ -319,6 +334,7 @@ def fused_experts(
             routed_scaling_factor=moe_runner_config.routed_scaling_factor,
             gemm1_alpha=moe_runner_config.gemm1_alpha,
             gemm1_limit=moe_runner_config.gemm1_clamp_limit,
+            filter_expert=filter_expert,
         )
@@ -336,6 +352,11 @@ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
     return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)


+@functools.lru_cache()
+def _down_moe_use_tma():
+    return support_tensor_descriptor()
+
+
 def fused_experts_impl(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -363,6 +384,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
     gemm1_alpha: Optional[float] = None,
     gemm1_limit: Optional[float] = None,
+    filter_expert: bool = True,
 ):
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -402,25 +424,27 @@ def fused_experts_impl(
         topk_ids.shape[1],
         config_dtype,
         block_shape=block_shape,
+        return_down_config=True,
     )

-    config = get_config_func(M)
-
-    cache = torch.empty(
-        M * topk_ids.shape[1] * max(N, w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
-        (M, topk_ids.shape[1], N),
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
-        (M, topk_ids.shape[1], w2.shape[1]),
-    )
+    config, (down_config, max_block_m) = get_config_func(M)
+    down_moe_use_tma = (
+        _down_moe_use_tma()
+        and down_config is not None
+        and down_config.pop("USE_TMA", False)
+    )
+
+    topk = topk_ids.shape[1]
+    max_padded_tokens = (
+        min(M * topk, E + 1) * (max_block_m - 1) if down_moe_use_tma else 0
+    )
+    total_tokens = M * topk + max_padded_tokens
+    cache = torch.empty(
+        total_tokens * max(N, w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(
+        (M, topk, w2.shape[1]),
+    )

     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
@@ -428,7 +452,7 @@ def fused_experts_impl(
     if no_combine:
         assert not inplace
         out_hidden_states = torch.empty(
-            (num_tokens, topk_ids.shape[1], w2.shape[1]),
+            (num_tokens, topk, w2.shape[1]),
             device=hidden_states.device,
             dtype=hidden_states.dtype,
         )
@@ -453,12 +477,28 @@ def fused_experts_impl(
             # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
-            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-            intermediate_cache2 = intermediate_cache2[
-                : tokens_in_chunk * topk_ids.shape[1]
-            ]
+            config, (down_config, _) = get_config_func(tokens_in_chunk)
+            down_moe_use_tma = (
+                _down_moe_use_tma()
+                and down_config is not None
+                and down_config.pop("USE_TMA", False)
+            )
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
-            config = get_config_func(tokens_in_chunk)
+            padded_tokens = (
+                min(tokens_in_chunk * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1)
+                if down_moe_use_tma
+                else 0
+            )
+            total_tokens = tokens_in_chunk * topk + padded_tokens
+            intermediate_cache1 = cache[: total_tokens * N].view(
+                (total_tokens, N),
+            )
+            intermediate_cache2 = torch.empty(
+                (total_tokens, N // 2),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )

         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
@@ -490,6 +530,8 @@ def fused_experts_impl(
             use_int4_w4a16=use_int4_w4a16,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape,
+            c_sorted=down_moe_use_tma,
+            filter_expert=filter_expert,
         )

         if activation == "silu":
             if gemm1_alpha is not None:
@@ -536,7 +578,7 @@ def fused_experts_impl(
             num_tokens_post_padded,
             not apply_router_weight_on_input,
             1,
-            config,
+            down_config or config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
@@ -544,6 +586,9 @@ def fused_experts_impl(
             use_int4_w4a16=use_int4_w4a16,
             per_channel_quant=per_channel_quant,
             block_shape=block_shape,
+            a_use_tma=down_moe_use_tma,
+            b_use_tma=down_moe_use_tma,
+            filter_expert=filter_expert,
         )

     if routed_scaling_factor is None:
...
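
A quick worked example of the new cache sizing above (illustrative numbers, not taken from the PR): when the down-projection input is written in sorted layout for TMA, each expert segment produced by moe_align_block_size may be padded by up to BLOCK_SIZE_M - 1 rows, so the intermediate buffers are over-allocated by that worst case:

# Illustrative values only; the formula is the one used in fused_experts_impl above.
M, topk, E, max_block_m = 1024, 8, 256, 64                      # tokens, experts per token, experts, largest BLOCK_SIZE_M in the down config
max_padded_tokens = min(M * topk, E + 1) * (max_block_m - 1)    # min(8192, 257) * 63 = 16191 extra rows
total_tokens = M * topk + max_padded_tokens                     # 8192 + 16191 = 24383 rows
# intermediate_cache1/2 are sized with total_tokens rows so the padded, sorted layout fits;
# intermediate_cache3 keeps the unpadded M * topk rows because the final output is not sorted.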
@@ -21,6 +21,7 @@ def get_config_file_name(
     dtype: Optional[str],
     block_shape: Optional[int] = None,
     per_channel_quant: bool = False,
+    down_moe: bool = False,
 ) -> str:
     device_name = get_device_name().replace(" ", "_")
     dtype_selector = "" if not dtype else f",dtype={dtype}"
@@ -28,7 +29,8 @@ def get_config_file_name(
         "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
     )
     per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
-    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
+    down_moe_selector = "_down" if down_moe else ""
+    return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}{down_moe_selector}.json"


 @functools.lru_cache
@@ -39,6 +41,7 @@ def get_moe_configs(
     block_n: Optional[int] = 0,
     block_k: Optional[int] = 0,
     per_channel_quant: bool = False,
+    down_moe: bool = False,
 ) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
@@ -54,7 +57,12 @@ def get_moe_configs(
     # First look up if an optimized configuration is available in the configs
     # directory
     json_file_name = get_config_file_name(
-        E, N, dtype, [block_n, block_k], per_channel_quant
+        E,
+        N,
+        dtype,
+        [block_n, block_k],
+        per_channel_quant,
+        down_moe=down_moe,
     )

     # We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
@@ -177,9 +185,12 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
     block_shape: Optional[List[int]] = None,
+    return_down_config: bool = False,
 ):
     from sglang.srt.layers.moe.fused_moe_triton import get_config

+    down_config = None
+    max_block_m = None
     override_config = get_config()
     if override_config:
         config = override_config
@@ -188,7 +199,7 @@ def try_get_optimal_moe_config(
         E, _, N = w2_shape
         block_n = block_shape[0] if block_shape else 0
         block_k = block_shape[1] if block_shape else 0
-        configs = get_moe_configs(E, N, dtype, block_n, block_k)
+        configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=False)

         if configs:
             # If an optimal configuration map has been found, look up the
@@ -199,6 +210,21 @@ def try_get_optimal_moe_config(
             config = get_default_config(
                 M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
             )
+        if return_down_config:
+            down_configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=True)
+            if down_configs:
+                down_config = down_configs[
+                    min(down_configs.keys(), key=lambda x: abs(x - M))
+                ]
+                down_config = dict(**down_config)
+                max_block_m = max(
+                    [cfg["BLOCK_SIZE_M"] for cfg in down_configs.values()]
+                )
+    if return_down_config:
+        assert (
+            down_config is None or config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"]
+        )
+        return config, (down_config, max_block_m)
     return config
...
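
With the down_moe flag above, the down-projection kernel looks up a JSON file whose name carries a "_down" suffix; evaluating the changed f-string with hypothetical arguments (device name, shapes, and dtype are placeholders, not values from the PR) gives:

# Hypothetical call; actual E, N, dtype, and device depend on the model and GPU.
get_config_file_name(256, 128, "fp8_w8a8", [128, 128], down_moe=False)
# -> 'E=256,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json'
get_config_file_name(256, 128, "fp8_w8a8", [128, 128], down_moe=True)
# -> 'E=256,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128]_down.json'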
@@ -25,6 +25,13 @@ from sglang.srt.utils import (
     is_hip,
 )

+try:
+    from triton.tools.tensor_descriptor import TensorDescriptor
+
+    _support_tensor_descriptor = True
+except:
+    _support_tensor_descriptor = False
+
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -41,6 +48,10 @@ elif _is_hip:

 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0

+def support_tensor_descriptor():
+    return _support_tensor_descriptor
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
     use_int4_w4a16: tl.constexpr,
     use_int8_w8a16: tl.constexpr,
     even_Ks: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
     token_mask = offs_token < num_valid_tokens

     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
 def fused_moe_kernel(
     # Pointers to matrices
     a_ptr,
+    a_desc,
     b_ptr,
+    b_desc,
     bias_ptr,
     c_ptr,
     a_scale_ptr,
@@ -344,6 +358,8 @@ def fused_moe_kernel(
     use_int8_w8a16: tl.constexpr,
     per_channel_quant: tl.constexpr,
     even_Ks: tl.constexpr,
+    c_sorted: tl.constexpr,
+    filter_expert: tl.constexpr,
 ):
     """
     Implements the fused computation for a Mixture of Experts (MOE) using
@@ -399,9 +415,10 @@ def fused_moe_kernel(
     offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens

-    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
+    off_experts = off_experts_i32.to(tl.int64)

-    if off_experts == -1:
+    if filter_expert and off_experts == -1:
         # -----------------------------------------------------------
         # Write back zeros to the output when the expert is not
         # in the current expert parallel rank.
@@ -421,15 +438,23 @@ def fused_moe_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
+    if a_desc is not None:
+        assert use_fp8_w8a8 and group_n > 0 and group_k > 0
+        start_offs_m = pid_m * BLOCK_SIZE_M
+    else:
         a_ptrs = a_ptr + (
             offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
         )
+    if b_desc is not None:
+        start_offs_n = pid_n * BLOCK_SIZE_N
+    else:
         b_ptrs = (
             b_ptr
             + off_experts * stride_be
             + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
         )

     if bias_ptr is not None:
         bias = tl.load(
             bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
@@ -443,8 +468,14 @@ def fused_moe_kernel(
     if use_fp8_w8a8 or use_int8_w8a8:
         # block-wise
         if group_k > 0 and group_n > 0:
+            if a_desc is not None:
+                a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
+            else:
                 a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            if BLOCK_SIZE_N > group_n:
                 offs_bsn = offs_bn // group_n
+            else:
+                offs_bsn = pid_n * BLOCK_SIZE_N // group_n
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
@@ -469,37 +500,49 @@ def fused_moe_kernel(
     # `accumulator` will be converted back to fp16 after the loop.
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

-    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+    for k_start in range(0, K, BLOCK_SIZE_K):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        if even_Ks:
+        if a_desc is not None:
+            a = a_desc.load([start_offs_m, k_start])
+        elif even_Ks:
             a = tl.load(
                 a_ptrs,
                 mask=token_mask[:, None],
                 other=0.0,
             )
-            b = tl.load(b_ptrs)
         else:
             a = tl.load(
                 a_ptrs,
-                mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+                mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
                 other=0.0,
             )
-            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+        if b_desc is not None:
+            b = (
+                b_desc.load([off_experts_i32, start_offs_n, k_start])
+                .reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
+                .T
+            )
+        elif even_Ks:
+            b = tl.load(b_ptrs)
+        else:
+            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)

         # We accumulate along the K dimension.
         if use_int8_w8a16:
             accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
         elif use_fp8_w8a8 or use_int8_w8a8:
             if group_k > 0 and group_n > 0:
-                k_start = k * BLOCK_SIZE_K
                 offs_ks = k_start // group_k
                 a_scale = tl.load(
                     a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
                 )
                 b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
+                if BLOCK_SIZE_N > group_n:
                     accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
+                else:
+                    accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
             else:
                 if use_fp8_w8a8:
                     accumulator = tl.dot(a, b, acc=accumulator)
@@ -508,7 +551,9 @@ def fused_moe_kernel(
             else:
                 accumulator += tl.dot(a, b)
         # Advance the ptrs to the next K block.
+        if a_desc is None:
             a_ptrs += BLOCK_SIZE_K * stride_ak
+        if b_desc is None:
             b_ptrs += BLOCK_SIZE_K * stride_bk

     if use_int8_w8a16:
@@ -528,6 +573,11 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    if c_sorted:
+        c_ptrs = (
+            c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
+        )
+    else:
         c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
     per_channel_quant: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    a_use_tma: bool = False,
+    b_use_tma: bool = False,
+    c_sorted: bool = False,
+    filter_expert: bool = True,
 ) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
             use_int4_w4a16=use_int4_w4a16,
             use_int8_w8a16=use_int8_w8a16,
             even_Ks=even_Ks,
+            filter_expert=filter_expert,
             **config,
         )
     else:
+        if a_use_tma or b_use_tma:
+            # TMA descriptors require a global memory allocation
+            def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+                return torch.empty(size, device="cuda", dtype=torch.int8)
+
+            triton.set_allocator(alloc_fn)
+        if a_use_tma:
+            a_desc = TensorDescriptor(
+                A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
+            )
+        else:
+            a_desc = None
+        if b_use_tma:
+            b_desc = TensorDescriptor(
+                B,
+                B.shape,
+                B.stride(),
+                [1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
+            )
+        else:
+            b_desc = None
+
         fused_moe_kernel[grid](
             A,
+            a_desc,
             B,
+            b_desc,
             bias,
             C,
             A_scale,
@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
             B.stride(1),
             bias.stride(0) if bias is not None else 0,
             bias.stride(1) if bias is not None else 0,
-            C.stride(1),
-            C.stride(2),
+            C.stride(-2),
+            C.stride(-1),
             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
             A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
             B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
             use_int8_w8a16=use_int8_w8a16,
             per_channel_quant=per_channel_quant,
             even_Ks=even_Ks,
+            c_sorted=c_sorted,
+            filter_expert=filter_expert,
             **config,
         )
...
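
For readers new to the TensorDescriptor path added above, here is a minimal, self-contained sketch of the host-side pattern the down-proj path now relies on: register a scratch allocator, build a descriptor with a fixed block shape, and load whole tiles inside the kernel instead of doing per-element pointer arithmetic. It assumes a recent Triton that ships triton.tools.tensor_descriptor and a GPU with TMA support; the tile-copy kernel is illustrative and is not code from this PR.

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


@triton.jit
def tile_copy_kernel(in_desc, out_ptr, stride_m, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid = tl.program_id(0)
    # One TMA load fetches a whole (BLOCK_M, BLOCK_N) tile at the given element offsets.
    tile = in_desc.load([pid * BLOCK_M, 0])
    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    tl.store(out_ptr + offs_m[:, None] * stride_m + offs_n[None, :], tile)


def demo():
    M, N, BLOCK_M = 256, 128, 64
    x = torch.randn(M, N, device="cuda", dtype=torch.float16)
    y = torch.empty_like(x)
    # TMA descriptors need a global scratch buffer; Triton requests it through this hook.
    triton.set_allocator(
        lambda size, alignment, stream: torch.empty(size, device="cuda", dtype=torch.int8)
    )
    # Descriptor: base tensor, logical shape, strides, and the block shape each load returns.
    in_desc = TensorDescriptor(x, x.shape, x.stride(), [BLOCK_M, N])
    tile_copy_kernel[(M // BLOCK_M,)](in_desc, y, y.stride(0), BLOCK_M=BLOCK_M, BLOCK_N=N)
    assert torch.equal(x, y)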