"tools/vscode:/vscode.git/clone" did not exist on "e74b3d3dd297c395ff1c74a1c0d1da4c49cebd1b"
Unverified commit d2b8c412 authored by Yongfei Xu, committed by GitHub

Opt fused triton moe: add tma for down proj kernel (#10567)


Co-authored-by: ybyang <10629930+whybeyoung@users.noreply.github.com>
parent bf8f7a94
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
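
The table above is the standard per-batch-size tuning table for the fused MoE kernel. The second table that follows appears to be the new down-projection variant introduced by this commit: it carries the same block/warp/stage parameters plus a per-batch USE_TMA flag, and is stored under the "_down" config-file suffix added further down in this diff.
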
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5,
"USE_TMA": false
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"USE_TMA": false
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"USE_TMA": false
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5,
"USE_TMA": false
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4,
"USE_TMA": false
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": false
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3,
"USE_TMA": true
}
}
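
Both tables are keyed by candidate batch sizes M; at runtime the entry whose key is numerically closest to the actual M is chosen, and for the down-projection table the USE_TMA flag is popped off before the remaining dict is splatted into the kernel launch (see try_get_optimal_moe_config and fused_experts_impl below). A minimal sketch of that selection, assuming the JSON above has been loaded into a plain dict (names here are illustrative, not from the commit):

def pick_moe_config(configs: dict, M: int):
    # nearest-batch-size lookup, mirroring the min(..., key=abs(x - M)) rule used below
    best_key = min(configs, key=lambda k: abs(int(k) - M))
    cfg = dict(configs[best_key])           # copy so pop() does not mutate the cached table
    use_tma = cfg.pop("USE_TMA", False)     # only present in the down-projection table
    return cfg, use_tma

# e.g. with the down-projection table above and M = 300, the "256" entry is picked
# and use_tma is True, so the down GEMM can take the TMA path.
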
......@@ -23,7 +23,11 @@ from sglang.srt.utils import (
)
from .fused_moe_triton_config import get_config_dtype_str, try_get_optimal_moe_config
from .fused_moe_triton_kernels import invoke_fused_moe_kernel, moe_sum_reduce_triton
from .fused_moe_triton_kernels import (
invoke_fused_moe_kernel,
moe_sum_reduce_triton,
support_tensor_descriptor,
)
from .moe_align_block_size import moe_align_block_size
if TYPE_CHECKING:
......@@ -78,6 +82,7 @@ def inplace_fused_experts(
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> None:
fused_experts_impl(
hidden_states,
......@@ -106,6 +111,7 @@ def inplace_fused_experts(
routed_scaling_factor,
gemm1_alpha,
gemm1_limit,
filter_expert,
)
......@@ -134,6 +140,7 @@ def inplace_fused_experts_fake(
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> None:
pass
......@@ -172,6 +179,7 @@ def outplace_fused_experts(
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
......@@ -200,6 +208,7 @@ def outplace_fused_experts(
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_limit=gemm1_limit,
filter_expert=filter_expert,
)
......@@ -229,6 +238,7 @@ def outplace_fused_experts_fake(
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -263,6 +273,10 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
):
topk_weights, topk_ids, _ = topk_output
filter_expert = (
moe_runner_config.num_experts is None
or moe_runner_config.num_experts != moe_runner_config.num_local_experts
)
if moe_runner_config.inplace:
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
......@@ -290,6 +304,7 @@ def fused_experts(
moe_runner_config.routed_scaling_factor,
moe_runner_config.gemm1_alpha,
moe_runner_config.gemm1_clamp_limit,
filter_expert,
)
return hidden_states
else:
......@@ -319,6 +334,7 @@ def fused_experts(
routed_scaling_factor=moe_runner_config.routed_scaling_factor,
gemm1_alpha=moe_runner_config.gemm1_alpha,
gemm1_limit=moe_runner_config.gemm1_clamp_limit,
filter_expert=filter_expert,
)
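
The filter_expert flag threaded through both call paths above only matters under expert parallelism: when some experts live on other ranks, moe_align_block_size marks their blocks with expert id -1 and the kernels must write zeros for them; when every expert is local the flag is False and, being a tl.constexpr in the kernels, the -1 check can be specialized out. A tiny illustration of the host-side rule (numbers are made up):

# EP across ranks: 256 experts globally, 32 held locally -> remote experts possible
num_experts, num_local_experts = 256, 32
filter_expert = num_experts is None or num_experts != num_local_experts   # True

# no expert parallelism: every routed expert is local -> skip the -1 branch
num_experts = num_local_experts = 32
filter_expert = num_experts is None or num_experts != num_local_experts   # False
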
......@@ -336,6 +352,11 @@ def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
@functools.lru_cache()
def _down_moe_use_tma():
return support_tensor_descriptor()
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -363,6 +384,7 @@ def fused_experts_impl(
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
):
padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
......@@ -402,25 +424,27 @@ def fused_experts_impl(
topk_ids.shape[1],
config_dtype,
block_shape=block_shape,
return_down_config=True,
)
config = get_config_func(M)
cache = torch.empty(
M * topk_ids.shape[1] * max(N, w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
config, (down_config, max_block_m) = get_config_func(M)
down_moe_use_tma = (
_down_moe_use_tma()
and down_config is not None
and down_config.pop("USE_TMA", False)
)
intermediate_cache1 = cache[: M * topk_ids.shape[1] * N].view(
(M, topk_ids.shape[1], N),
topk = topk_ids.shape[1]
max_padded_tokens = (
min(M * topk, E + 1) * (max_block_m - 1) if down_moe_use_tma else 0
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
total_tokens = M * topk + max_padded_tokens
cache = torch.empty(
total_tokens * max(N, w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = cache[: M * topk_ids.shape[1] * w2.shape[1]].view(
(M, topk_ids.shape[1], w2.shape[1]),
intermediate_cache3 = cache[: M * topk * w2.shape[1]].view(
(M, topk, w2.shape[1]),
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
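
The extra max_padded_tokens term exists because, on the TMA path, the gate/up output is written at the sorted slot positions produced by moe_align_block_size, and sorting pads each expert's token segment up to a multiple of BLOCK_SIZE_M; the worst case is (max_block_m - 1) padding rows per segment, with the segment count bounded by both E + 1 and the number of routed tokens. A worked example with assumed sizes:

# assumed: 512 tokens in the chunk, top-8 routing, 256 experts,
# and max BLOCK_SIZE_M over the down-projection table above = 64
M, topk, E, max_block_m = 512, 8, 256, 64
max_padded_tokens = min(M * topk, E + 1) * (max_block_m - 1)   # 257 * 63 = 16_191
total_tokens = M * topk + max_padded_tokens                    # 4_096 + 16_191 = 20_287
# the shared cache then holds total_tokens * max(N, w2.shape[1]) elements of hidden_states.dtype
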
......@@ -428,7 +452,7 @@ def fused_experts_impl(
if no_combine:
assert not inplace
out_hidden_states = torch.empty(
(num_tokens, topk_ids.shape[1], w2.shape[1]),
(num_tokens, topk, w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
......@@ -453,12 +477,28 @@ def fused_experts_impl(
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[
: tokens_in_chunk * topk_ids.shape[1]
]
config, (down_config, _) = get_config_func(tokens_in_chunk)
down_moe_use_tma = (
_down_moe_use_tma()
and down_config is not None
and down_config.pop("USE_TMA", False)
)
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
config = get_config_func(tokens_in_chunk)
padded_tokens = (
min(tokens_in_chunk * topk, E + 1) * (config["BLOCK_SIZE_M"] - 1)
if down_moe_use_tma
else 0
)
total_tokens = tokens_in_chunk * topk + padded_tokens
intermediate_cache1 = cache[: total_tokens * N].view(
(total_tokens, N),
)
intermediate_cache2 = torch.empty(
(total_tokens, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
......@@ -490,6 +530,8 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
c_sorted=down_moe_use_tma,
filter_expert=filter_expert,
)
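
c_sorted=down_moe_use_tma makes this first (gate/up) GEMM store each output row at its sorted, padded slot index instead of scattering rows back to the original token*top_k position, so the down GEMM can pull a dense [BLOCK_SIZE_M, K] activation tile with one TMA load rather than a masked gather; padding slots are never written (they stay masked by token_mask) and their results are masked again on the final store. A rough sketch of the two output layouts:

# c_sorted = False: out_row = sorted_token_ids[slot]  (scatter back to token*top_k order)
# c_sorted = True : out_row = slot                    (contiguous inside each expert's padded block)
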
if activation == "silu":
if gemm1_alpha is not None:
......@@ -536,7 +578,7 @@ def fused_experts_impl(
num_tokens_post_padded,
not apply_router_weight_on_input,
1,
config,
down_config or config,
compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
......@@ -544,6 +586,9 @@ def fused_experts_impl(
use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant,
block_shape=block_shape,
a_use_tma=down_moe_use_tma,
b_use_tma=down_moe_use_tma,
filter_expert=filter_expert,
)
if routed_scaling_factor is None:
......
......@@ -21,6 +21,7 @@ def get_config_file_name(
dtype: Optional[str],
block_shape: Optional[int] = None,
per_channel_quant: bool = False,
down_moe: bool = False,
) -> str:
device_name = get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
......@@ -28,7 +29,8 @@ def get_config_file_name(
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
)
per_channel_quant_selector = ",per_channel_quant=True" if per_channel_quant else ""
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}.json"
down_moe_selector = "_down" if down_moe else ""
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}{per_channel_quant_selector}{down_moe_selector}.json"
@functools.lru_cache
......@@ -39,6 +41,7 @@ def get_moe_configs(
block_n: Optional[int] = 0,
block_k: Optional[int] = 0,
per_channel_quant: bool = False,
down_moe: bool = False,
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -54,7 +57,12 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(
E, N, dtype, [block_n, block_k], per_channel_quant
E,
N,
dtype,
[block_n, block_k],
per_channel_quant,
down_moe=down_moe,
)
# We found that using the fused_moe_kernel config from Triton 3.1.0 with Triton 3.2.0 results in negative performance gains,
......@@ -177,9 +185,12 @@ def try_get_optimal_moe_config(
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
return_down_config: bool = False,
):
from sglang.srt.layers.moe.fused_moe_triton import get_config
down_config = None
max_block_m = None
override_config = get_config()
if override_config:
config = override_config
......@@ -188,7 +199,7 @@ def try_get_optimal_moe_config(
E, _, N = w2_shape
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k)
configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=False)
if configs:
# If an optimal configuration map has been found, look up the
......@@ -199,6 +210,21 @@ def try_get_optimal_moe_config(
config = get_default_config(
M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape
)
if return_down_config:
down_configs = get_moe_configs(E, N, dtype, block_n, block_k, down_moe=True)
if down_configs:
down_config = down_configs[
min(down_configs.keys(), key=lambda x: abs(x - M))
]
down_config = dict(**down_config)
max_block_m = max(
[cfg["BLOCK_SIZE_M"] for cfg in down_configs.values()]
)
if return_down_config:
assert (
down_config is None or config["BLOCK_SIZE_M"] == down_config["BLOCK_SIZE_M"]
)
return config, (down_config, max_block_m)
return config
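
Concretely, with the down-projection table shown at the top of this diff and M = 300, the nearest key is 256, so that entry becomes down_config (USE_TMA is popped later by the caller) and max_block_m is 64, the largest BLOCK_SIZE_M appearing anywhere in the table; the assertion only checks that the gate/up and down configs agree on BLOCK_SIZE_M, since both GEMMs consume the same moe_align_block_size output.
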
......
......@@ -25,6 +25,13 @@ from sglang.srt.utils import (
is_hip,
)
try:
from triton.tools.tensor_descriptor import TensorDescriptor
_support_tensor_descriptor = True
except:
_support_tensor_descriptor = False
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
......@@ -41,6 +48,10 @@ elif _is_hip:
padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
def support_tensor_descriptor():
return _support_tensor_descriptor
@triton.jit
def write_zeros_to_output(
c_ptr,
......@@ -108,6 +119,7 @@ def fused_moe_kernel_gptq_awq(
use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr,
even_Ks: tl.constexpr,
filter_expert: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
......@@ -161,7 +173,7 @@ def fused_moe_kernel_gptq_awq(
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if off_experts == -1:
if filter_expert and off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
......@@ -296,7 +308,9 @@ def fused_moe_kernel_gptq_awq(
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
a_desc,
b_ptr,
b_desc,
bias_ptr,
c_ptr,
a_scale_ptr,
......@@ -344,6 +358,8 @@ def fused_moe_kernel(
use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr,
even_Ks: tl.constexpr,
c_sorted: tl.constexpr,
filter_expert: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
......@@ -399,9 +415,10 @@ def fused_moe_kernel(
offs_token = offs_token.to(tl.int64)
token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
off_experts_i32 = tl.load(expert_ids_ptr + pid_m)
off_experts = off_experts_i32.to(tl.int64)
if off_experts == -1:
if filter_expert and off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
......@@ -421,15 +438,23 @@ def fused_moe_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if a_desc is not None:
assert use_fp8_w8a8 and group_n > 0 and group_k > 0
start_offs_m = pid_m * BLOCK_SIZE_M
else:
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
if b_desc is not None:
start_offs_n = pid_n * BLOCK_SIZE_N
else:
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if bias_ptr is not None:
bias = tl.load(
bias_ptr + off_experts * stride_bias_e + offs_bn[None, :] * stride_bias_n
......@@ -443,8 +468,14 @@ def fused_moe_kernel(
if use_fp8_w8a8 or use_int8_w8a8:
# block-wise
if group_k > 0 and group_n > 0:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
offs_bsn = offs_bn // group_n
if a_desc is not None:
a_scale_ptrs = a_scale_ptr + offs_token_id * stride_asm
else:
a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
if BLOCK_SIZE_N > group_n:
offs_bsn = offs_bn // group_n
else:
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
b_scale_ptrs = (
b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
)
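
Two things change in the block-quantized scale handling: on the TMA path the activations (and therefore their per-row scales) are laid out in sorted order, so a_scale is indexed by the sorted slot offs_token_id instead of offs_token // top_k; and when BLOCK_SIZE_N does not exceed group_n the whole N tile falls inside one quantization group, so a single scalar b_scale index suffices. For example, with block_shape = [128, 128] (group_n = 128) and the BLOCK_SIZE_N = 128 entries in the down-projection table above, offs_bsn reduces to the scalar pid_n * 128 // 128 = pid_n, and the accumulation later multiplies by (a_scale[:, None] * b_scale) instead of broadcasting a per-column vector.
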
......@@ -469,37 +500,49 @@ def fused_moe_kernel(
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
for k_start in range(0, K, BLOCK_SIZE_K):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
if even_Ks:
if a_desc is not None:
a = a_desc.load([start_offs_m, k_start])
elif even_Ks:
a = tl.load(
a_ptrs,
mask=token_mask[:, None],
other=0.0,
)
b = tl.load(b_ptrs)
else:
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
mask=token_mask[:, None] & (offs_k[None, :] < K - k_start),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
if b_desc is not None:
b = (
b_desc.load([off_experts_i32, start_offs_n, k_start])
.reshape(BLOCK_SIZE_N, BLOCK_SIZE_K)
.T
)
elif even_Ks:
b = tl.load(b_ptrs)
else:
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_start, other=0.0)
# We accumulate along the K dimension.
if use_int8_w8a16:
accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
elif use_fp8_w8a8 or use_int8_w8a8:
if group_k > 0 and group_n > 0:
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_scale = tl.load(
a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0
)
b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk)
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
if BLOCK_SIZE_N > group_n:
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
else:
accumulator += tl.dot(a, b) * (a_scale[:, None] * b_scale)
else:
if use_fp8_w8a8:
accumulator = tl.dot(a, b, acc=accumulator)
......@@ -508,8 +551,10 @@ def fused_moe_kernel(
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if a_desc is None:
a_ptrs += BLOCK_SIZE_K * stride_ak
if b_desc is None:
b_ptrs += BLOCK_SIZE_K * stride_bk
if use_int8_w8a16:
accumulator *= b_scale
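
On the descriptor path the kernel stops doing pointer arithmetic altogether: it only tracks block offsets (start_offs_m, start_offs_n, k_start), the TMA engine performs the address math and out-of-bounds masking, and the usual a_ptrs/b_ptrs advance at the end of the K loop is skipped. The B tile is requested as [1, BLOCK_SIZE_N, BLOCK_SIZE_K] because the weights are stored as [E, N, K]; reshaping and transposing it yields the [BLOCK_SIZE_K, BLOCK_SIZE_N] operand that tl.dot expects, matching what the pointer path builds from offs_k[:, None] and offs_bn[None, :].
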
......@@ -528,7 +573,12 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
if c_sorted:
c_ptrs = (
c_ptr + stride_cm * offs_token_id[:, None] + stride_cn * offs_cn[None, :]
)
else:
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
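
Note that c_sorted is only set for the gate/up GEMM; the down GEMM keeps the default c_sorted=False, so its result is scattered back to the original token*top_k positions and moe_sum_reduce_triton plus the routed-scaling epilogue run unchanged.
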
......@@ -557,6 +607,10 @@ def invoke_fused_moe_kernel(
per_channel_quant: bool,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
a_use_tma: bool = False,
b_use_tma: bool = False,
c_sorted: bool = False,
filter_expert: bool = True,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......@@ -662,14 +716,38 @@ def invoke_fused_moe_kernel(
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16,
even_Ks=even_Ks,
filter_expert=filter_expert,
**config,
)
else:
if a_use_tma or b_use_tma:
# TMA descriptors require a global memory allocation
def alloc_fn(size: int, alignment: int, stream: Optional[int]):
return torch.empty(size, device="cuda", dtype=torch.int8)
triton.set_allocator(alloc_fn)
if a_use_tma:
a_desc = TensorDescriptor(
A, A.shape, A.stride(), [config["BLOCK_SIZE_M"], config["BLOCK_SIZE_K"]]
)
else:
a_desc = None
if b_use_tma:
b_desc = TensorDescriptor(
B,
B.shape,
B.stride(),
[1, config["BLOCK_SIZE_N"], config["BLOCK_SIZE_K"]],
)
else:
b_desc = None
fused_moe_kernel[grid](
A,
a_desc,
B,
b_desc,
bias,
C,
A_scale,
......@@ -689,8 +767,8 @@ def invoke_fused_moe_kernel(
B.stride(1),
bias.stride(0) if bias is not None else 0,
bias.stride(1) if bias is not None else 0,
C.stride(1),
C.stride(2),
C.stride(-2),
C.stride(-1),
A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
......@@ -706,6 +784,8 @@ def invoke_fused_moe_kernel(
use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant,
even_Ks=even_Ks,
c_sorted=c_sorted,
filter_expert=filter_expert,
**config,
)
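
For reference, here is a minimal standalone sketch of the host-side tensor-descriptor workflow the launch path above relies on, assuming a Triton build that ships triton.tools.tensor_descriptor (all names, shapes, and block sizes below are illustrative, not part of this commit):

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor


def _alloc(size: int, alignment: int, stream):
    # scratch global memory Triton may request when lowering TMA descriptors,
    # mirroring the alloc_fn installed above
    return torch.empty(size, device="cuda", dtype=torch.int8)


triton.set_allocator(_alloc)


@triton.jit
def tiled_copy_kernel(src_desc, dst_desc, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # one dense tile per program; the TMA engine handles address math and bounds
    tile = src_desc.load([pid_m * BLOCK_M, pid_n * BLOCK_N])
    dst_desc.store([pid_m * BLOCK_M, pid_n * BLOCK_N], tile)


x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
y = torch.empty_like(x)
BLOCK_M, BLOCK_N = 64, 128
# same constructor shape as above: (base, shape, strides, block_shape)
src = TensorDescriptor(x, x.shape, x.stride(), [BLOCK_M, BLOCK_N])
dst = TensorDescriptor(y, y.shape, y.stride(), [BLOCK_M, BLOCK_N])
grid = (x.shape[0] // BLOCK_M, x.shape[1] // BLOCK_N)
tiled_copy_kernel[grid](src, dst, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
assert torch.equal(x, y)
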
......