Commit f233de81 authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

[SYNC] Code sync.

parent 1893a1e0
[submodule "3rdparty/composable_kernel"] [submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel path = 3rdparty/composable_kernel
url = ../composable_kernel url = ../composable_kernel
branch = rel-5.7.1 branch = main
[submodule "3rdparty/moe_c"] [submodule "3rdparty/moe_c"]
path = 3rdparty/moe_c path = 3rdparty/moe_c
url = ../Moe url = ../Moe
branch = W8A8 branch = master
...@@ -1199,3 +1199,72 @@ gfx938,f8_w8a8_block,torch.float16,12288,1536,3072,64,8,0,0,asm,13001+23000,1087 ...@@ -1199,3 +1199,72 @@ gfx938,f8_w8a8_block,torch.float16,12288,1536,3072,64,8,0,0,asm,13001+23000,1087
gfx938,f8_w8a8_block,torch.float16,16384,1536,3072,64,8,0,0,asm,13001+23000,14314.7352 gfx938,f8_w8a8_block,torch.float16,16384,1536,3072,64,8,0,0,asm,13001+23000,14314.7352
gfx938,f8_w8a8_block,torch.float16,24576,1536,3072,64,8,0,0,asm,13001+23000,21336.6809 gfx938,f8_w8a8_block,torch.float16,24576,1536,3072,64,8,0,0,asm,13001+23000,21336.6809
gfx938,f8_w8a8_block,torch.float16,32768,1536,3072,64,8,0,0,asm,13001+23000,28266.1463 gfx938,f8_w8a8_block,torch.float16,32768,1536,3072,64,8,0,0,asm,13001+23000,28266.1463
gfx938,f8_w8a8_block,torch.float16,1,256,4096,256,8,0,0,asm,10007+20200,66.0562
gfx938,f8_w8a8_block,torch.float16,2,256,4096,256,8,0,0,asm,10001+20000,83.8161
gfx938,f8_w8a8_block,torch.float16,4,256,4096,256,8,0,0,asm,10002+20000,112.3382
gfx938,f8_w8a8_block,torch.float16,6,256,4096,256,8,0,0,asm,10002+20000,141.2728
gfx938,f8_w8a8_block,torch.float16,8,256,4096,256,8,0,0,asm,10007+20000,165.0033
gfx938,f8_w8a8_block,torch.float16,10,256,4096,256,8,0,0,asm,10002+20000,186.1823
gfx938,f8_w8a8_block,torch.float16,12,256,4096,256,8,0,0,asm,10002+20000,205.2475
gfx938,f8_w8a8_block,torch.float16,14,256,4096,256,8,0,0,asm,10002+20000,226.2159
gfx938,f8_w8a8_block,torch.float16,16,256,4096,256,8,0,0,asm,10002+20000,237.2221
gfx938,f8_w8a8_block,torch.float16,20,256,4096,256,8,0,0,asm,10002+20000,264.3938
gfx938,f8_w8a8_block,torch.float16,24,256,4096,256,8,0,0,asm,10002+20000,293.3708
gfx938,f8_w8a8_block,torch.float16,28,256,4096,256,8,0,0,asm,10002+20000,343.0409
gfx938,f8_w8a8_block,torch.float16,32,256,4096,256,8,0,0,asm,10002+20000,359.6472
gfx938,f8_w8a8_block,torch.float16,36,256,4096,256,8,0,0,asm,10002+20000,367.9137
gfx938,f8_w8a8_block,torch.float16,40,256,4096,256,8,0,0,asm,10001+20000,378.7264
gfx938,f8_w8a8_block,torch.float16,44,256,4096,256,8,0,0,asm,10002+20000,389.2864
gfx938,f8_w8a8_block,torch.float16,48,256,4096,256,8,0,0,asm,10002+20000,398.3053
gfx938,f8_w8a8_block,torch.float16,56,256,4096,256,8,0,0,asm,10002+20000,414.7348
gfx938,f8_w8a8_block,torch.float16,64,256,4096,256,8,0,0,asm,10002+20000,430.7348
gfx938,f8_w8a8_block,torch.float16,80,256,4096,256,8,0,0,asm,10002+20000,454.9452
gfx938,f8_w8a8_block,torch.float16,96,256,4096,256,8,0,0,asm,10002+20000,473.8084
gfx938,f8_w8a8_block,torch.float16,112,256,4096,256,8,0,0,asm,10002+20000,489.219
gfx938,f8_w8a8_block,torch.float16,128,256,4096,256,8,0,0,asm,10002+20000,494.3979
gfx938,f8_w8a8_block,torch.float16,160,256,4096,256,8,0,0,asm,10002+20000,499.5009
gfx938,f8_w8a8_block,torch.float16,192,256,4096,256,8,0,0,asm,10002+20000,511.6526
gfx938,f8_w8a8_block,torch.float16,224,256,4096,256,8,0,0,asm,10002+20000,519.9809
gfx938,f8_w8a8_block,torch.float16,256,256,4096,256,8,0,0,asm,10002+20000,520.7221
gfx938,f8_w8a8_block,torch.float16,320,256,4096,256,8,0,0,asm,10002+20000,535.021
gfx938,f8_w8a8_block,torch.float16,384,256,4096,256,8,0,0,asm,10002+20000,565.7914
gfx938,f8_w8a8_block,torch.float16,448,256,4096,256,8,0,0,asm,10002+20000,598.9198
gfx938,f8_w8a8_block,torch.float16,512,256,4096,256,8,0,0,asm,11007+21000,610.5915
gfx938,f8_w8a8_block,torch.float16,576,256,4096,256,8,0,0,asm,11010+21000,639.2314
gfx938,f8_w8a8_block,torch.float16,640,256,4096,256,8,0,0,asm,11009+21000,631.0208
gfx938,f8_w8a8_block,torch.float16,704,256,4096,256,8,0,0,asm,11006+21000,640.8651
gfx938,f8_w8a8_block,torch.float16,768,256,4096,256,8,0,0,asm,11007+21000,659.7198
gfx938,f8_w8a8_block,torch.float16,832,256,4096,256,8,0,0,asm,11009+21200,658.3555
gfx938,f8_w8a8_block,torch.float16,896,256,4096,256,8,0,0,asm,11010+21000,689.7156
gfx938,f8_w8a8_block,torch.float16,960,256,4096,256,8,0,0,asm,11010+21200,722.3639
gfx938,f8_w8a8_block,torch.float16,1024,256,4096,256,8,0,0,asm,11010+21000,751.2649
gfx938,f8_w8a8_block,torch.float16,1152,256,4096,256,8,0,0,asm,11008+21000,870.549
gfx938,f8_w8a8_block,torch.float16,1280,256,4096,256,8,0,0,asm,12002+22000,867.3911
gfx938,f8_w8a8_block,torch.float16,1408,256,4096,256,8,0,0,asm,12003+22000,875.1133
gfx938,f8_w8a8_block,torch.float16,1536,256,4096,256,8,0,0,asm,12003+22000,902.4816
gfx938,f8_w8a8_block,torch.float16,1664,256,4096,256,8,0,0,asm,12004+22000,926.5489
gfx938,f8_w8a8_block,torch.float16,1792,256,4096,256,8,0,0,asm,12003+22000,942.0857
gfx938,f8_w8a8_block,torch.float16,1920,256,4096,256,8,0,0,asm,12005+22000,1018.5236
gfx938,f8_w8a8_block,torch.float16,2048,256,4096,256,8,0,0,asm,12003+22000,1094.3972
gfx938,f8_w8a8_block,torch.float16,2304,256,4096,256,8,0,0,asm,12005+22000,1257.5887
gfx938,f8_w8a8_block,torch.float16,2560,256,4096,256,8,0,0,asm,11010+21200,1374.7673
gfx938,f8_w8a8_block,torch.float16,2816,256,4096,256,8,0,0,asm,13001+23000,1400.5189
gfx938,f8_w8a8_block,torch.float16,3072,256,4096,256,8,0,0,asm,12005+22000,1439.1798
gfx938,f8_w8a8_block,torch.float16,3328,256,4096,256,8,0,0,asm,12005+22000,1456.0893
gfx938,f8_w8a8_block,torch.float16,3584,256,4096,256,8,0,0,asm,12005+22000,1487.5504
gfx938,f8_w8a8_block,torch.float16,3840,256,4096,256,8,0,0,asm,12005+22000,1595.1459
gfx938,f8_w8a8_block,torch.float16,4096,256,4096,256,8,0,0,asm,12006+22000,1756.2657
gfx938,f8_w8a8_block,torch.float16,4608,256,4096,256,8,0,0,asm,12005+22000,2012.6698
gfx938,f8_w8a8_block,torch.float16,5120,256,4096,256,8,0,0,asm,12005+22000,2134.1435
gfx938,f8_w8a8_block,torch.float16,5632,256,4096,256,8,0,0,asm,12005+22000,2246.8753
gfx938,f8_w8a8_block,torch.float16,6144,256,4096,256,8,0,0,asm,12005+22000,2440.6189
gfx938,f8_w8a8_block,torch.float16,6656,256,4096,256,8,0,0,asm,13001+23000,2562.9843
gfx938,f8_w8a8_block,torch.float16,7168,256,4096,256,8,0,0,asm,13001+23001,2768.8287
gfx938,f8_w8a8_block,torch.float16,7680,256,4096,256,8,0,0,asm,13001+23000,2792.2731
gfx938,f8_w8a8_block,torch.float16,8192,256,4096,256,8,0,0,asm,13001+23000,3082.3107
gfx938,f8_w8a8_block,torch.float16,10240,256,4096,256,8,0,0,asm,13001+23000,3801.2909
gfx938,f8_w8a8_block,torch.float16,12288,256,4096,256,8,0,0,asm,13001+23000,4383.3187
gfx938,f8_w8a8_block,torch.float16,14336,256,4096,256,8,0,0,asm,13001+23000,5030.8137
gfx938,f8_w8a8_block,torch.float16,16384,256,4096,256,8,0,0,asm,13001+23000,5608.4473
gfx938,f8_w8a8_block,torch.float16,17408,256,4096,256,8,0,0,asm,13001+23000,6038.0465
gfx938,f8_w8a8_block,torch.float16,24576,256,4096,256,8,0,0,asm,13001+23000,8143.5178
...@@ -15,7 +15,8 @@ from aiter import silu_and_mul,gelu_and_mul ...@@ -15,7 +15,8 @@ from aiter import silu_and_mul,gelu_and_mul
from aiter.ops.triton.fused_moe import ( from aiter.ops.triton.fused_moe import (
triton_moe_sum, triton_moe_sum,
triton_silu_and_mul, triton_silu_and_mul,
triton_gelu_and_mul triton_gelu_and_mul,
triton_relu2,
) )
from aiter.jit.core import AITER_ROOT_DIR from aiter.jit.core import AITER_ROOT_DIR
...@@ -754,8 +755,11 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -754,8 +755,11 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
use_shuffle) use_shuffle)
# #
else: else:
# For gated activations (silu/gelu): w1 has 2*inter_dim cols, so inter_dim = N/2
# For non-gated activations (relu2): w1 has inter_dim cols, so inter_dim = N
asm_inter_dim = N/2 if activation in ("silu", "gelu") else N
if solution_id is None: if solution_id is None:
solution_id = get_moe_asm_solution(arch, tokens_in_chunk, N/2, w1.size(2), E, top_k_num, MoeQuantType.NO_QUANT, use_shuffle) solution_id = get_moe_asm_solution(arch, tokens_in_chunk, asm_inter_dim, w1.size(2), E, top_k_num, MoeQuantType.NO_QUANT, use_shuffle)
config = decode_sol_w8a8_c(solution_id) config = decode_sol_w8a8_c(solution_id)
if persist_cu == cu_num: if persist_cu == cu_num:
calculate_persist_groups(persist_cu, config, MoeQuantType.NO_QUANT) calculate_persist_groups(persist_cu, config, MoeQuantType.NO_QUANT)
...@@ -767,7 +771,7 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -767,7 +771,7 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
moe_sorting_ck(curr_topk_ids, curr_topk_weights, global_num_experts, model_dim, out_hidden_states[begin_chunk_idx:end_chunk_idx], config["BLOCK_SIZE_M"], expert_map) moe_sorting_ck(curr_topk_ids, curr_topk_weights, global_num_experts, model_dim, out_hidden_states[begin_chunk_idx:end_chunk_idx], config["BLOCK_SIZE_M"], expert_map)
) )
if print_log(): if print_log():
print(f"Asm Moe Size: chunk:{chunk}, arch:{arch}, quant:{MoeQuantType.NO_QUANT}, tokens:{tokens_in_chunk}, inter_dim:{int(N/2)}, model_dim:{w1.size(2)}, expert:{E}, topk:{top_k_num}") print(f"Asm Moe Size: chunk:{chunk}, arch:{arch}, quant:{MoeQuantType.NO_QUANT}, tokens:{tokens_in_chunk}, inter_dim:{int(asm_inter_dim)}, model_dim:{w1.size(2)}, expert:{E}, topk:{top_k_num}")
print(f"solution:{solution_id}, shuffle:{use_shuffle}, persist:{persist_cu}") print(f"solution:{solution_id}, shuffle:{use_shuffle}, persist:{persist_cu}")
if solution_id== "default": if solution_id== "default":
print(f">>> Warning: No matching config pattern found, using default asm solution.") print(f">>> Warning: No matching config pattern found, using default asm solution.")
...@@ -797,6 +801,8 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor, ...@@ -797,6 +801,8 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
elif activation == "gelu": elif activation == "gelu":
triton_gelu_and_mul(d_silu,d_w1_out) triton_gelu_and_mul(d_silu,d_w1_out)
# gelu_and_mul(d_silu,d_w1_out) # gelu_and_mul(d_silu,d_w1_out)
elif activation == "relu2":
triton_relu2(d_silu,d_w1_out)
else: else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}") raise ValueError(f"Unsupported FusedMoe activation: {activation}")
......
...@@ -23,6 +23,7 @@ class MoeQuantType: ...@@ -23,6 +23,7 @@ class MoeQuantType:
W16A16 = "w16a16" W16A16 = "w16a16"
W4A16 = "w4a16" W4A16 = "w4a16"
W8A8 = "w8a8" W8A8 = "w8a8"
FP8_W8A8 = "fp8_w8a8"
W4A8 = "w4a8" W4A8 = "w4a8"
...@@ -53,9 +54,9 @@ def _try_get_moe_c_config( ...@@ -53,9 +54,9 @@ def _try_get_moe_c_config(
block_size: int, block_size: int,
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
try: try:
if quant_type == MoeQuantType.W4A16:
from .fused_moe_c import get_moe_configs_marlin from .fused_moe_c import get_moe_configs_marlin
if quant_type == MoeQuantType.W4A16:
configs = get_moe_configs_marlin( configs = get_moe_configs_marlin(
E=e, E=e,
N=n, N=n,
...@@ -64,8 +65,6 @@ def _try_get_moe_c_config( ...@@ -64,8 +65,6 @@ def _try_get_moe_c_config(
use_moe_wna16_cuda=True, use_moe_wna16_cuda=True,
) )
elif quant_type == MoeQuantType.W8A8: elif quant_type == MoeQuantType.W8A8:
from .fused_moe_c import get_moe_configs_marlin
configs = get_moe_configs_marlin( configs = get_moe_configs_marlin(
E=e, E=e,
N=n, N=n,
...@@ -73,9 +72,15 @@ def _try_get_moe_c_config( ...@@ -73,9 +72,15 @@ def _try_get_moe_c_config(
is_bottom=False, is_bottom=False,
use_moe_wna16_cuda=True, use_moe_wna16_cuda=True,
) )
elif quant_type == MoeQuantType.FP8_W8A8:
configs = get_moe_configs_marlin(
E=e,
N=n,
dtype="fp8_w8a8",
is_bottom=False,
use_moe_wna16_cuda=True,
)
elif quant_type == MoeQuantType.W4A8: elif quant_type == MoeQuantType.W4A8:
from .fused_moe_c import get_moe_configs_marlin
configs = get_moe_configs_marlin( configs = get_moe_configs_marlin(
E=e, E=e,
N=n, N=n,
...@@ -148,6 +153,22 @@ def _try_get_asm_config( ...@@ -148,6 +153,22 @@ def _try_get_asm_config(
return None return None
return decode_sol_0(solution) return decode_sol_0(solution)
if quant_type == MoeQuantType.FP8_W8A8:
from .fused_moe_asm_wna16 import decode_sol_0
solution = get_moe_asm_solution(
arch=arch,
token=m,
inter_dim=n,
model_dim=k,
expert=e,
topk=top_k,
quant_type=AsmMoeQuantType.F8_W8A8,
)
if solution == "default":
return None
return decode_sol_0(solution)
if quant_type == MoeQuantType.W16A16: if quant_type == MoeQuantType.W16A16:
from .fused_moe_asm_wna16 import decode_sol_0 from .fused_moe_asm_wna16 import decode_sol_0
...@@ -186,6 +207,7 @@ def _try_get_triton_config( ...@@ -186,6 +207,7 @@ def _try_get_triton_config(
dtype_name = { dtype_name = {
MoeQuantType.W4A16: "int4_w4a16", MoeQuantType.W4A16: "int4_w4a16",
MoeQuantType.W8A8: "int8_w8a8", MoeQuantType.W8A8: "int8_w8a8",
MoeQuantType.FP8_W8A8: "fp8_w8a8",
}.get(quant_type) }.get(quant_type)
if dtype_name is None: if dtype_name is None:
return None return None
...@@ -216,7 +238,7 @@ def _try_get_ck_config( ...@@ -216,7 +238,7 @@ def _try_get_ck_config(
block_shape: Optional[List[int]], block_shape: Optional[List[int]],
) -> Optional[Dict[str, Any]]: ) -> Optional[Dict[str, Any]]:
try: try:
if quant_type != MoeQuantType.W8A8: if quant_type not in (MoeQuantType.W8A8, MoeQuantType.FP8_W8A8):
return None return None
from .fused_moe_ck import get_moe_ck_solution_id, MoeQuantType as CkMoeQuantType from .fused_moe_ck import get_moe_ck_solution_id, MoeQuantType as CkMoeQuantType
...@@ -245,29 +267,43 @@ def _try_get_ck_config( ...@@ -245,29 +267,43 @@ def _try_get_ck_config(
def get_aiter_moe_config( def get_aiter_moe_config(
M: int, # Number of tokens (input sequence length) M: int, # Number of tokens (input sequence length)
E: int, # Number of experts E: int, # Number of experts
N1: int, # GEMM1 output dimension, typically equal to (moe_intermediate_size / TP * 2) N1: int, # GEMM1 output dimension: gated = (intermediate_size * 2), non-gated = intermediate_size
N2: int, # GEMM2 output dimension, typically equal to hidden_size N2: int, # GEMM2 output dimension, typically equal to hidden_size
K: int, # GEMM1 input dimension, typically equal to hidden_size; for GEMM2, K typically equal to (moe_intermediate_size / TP) K: int, # GEMM1 input dimension, typically equal to hidden_size; for GEMM2, K typically equal to (moe_intermediate_size / TP)
top_k: int, top_k: int,
block_size: int, block_size: int,
dtype: torch.dtype, dtype: torch.dtype,
quant_type: str, quant_type: str,
activation: str = "silu", # "silu"/"gelu"/"relu2"/...
gated: Optional[bool] = None, # True=GLU-gated (N1=2*inter), False=non-gated (N1=inter); None=auto from activation
) -> Tuple[bool, AiterMoeConfig]: ) -> Tuple[bool, AiterMoeConfig]:
"""Get the best backend config for a MOE problem. """Get the best backend config for a MOE problem.
Currently supported quant types: Currently supported quant types:
- ``MoeQuantType.W16A16`` (non-quantized) - ``MoeQuantType.W16A16`` (non-quantized)
- ``MoeQuantType.W4A16`` - ``MoeQuantType.W4A16``
- ``MoeQuantType.W8A8`` - ``MoeQuantType.W8A8`` (int8)
- ``MoeQuantType.FP8_W8A8`` (fp8)
- ``MoeQuantType.W4A8`` - ``MoeQuantType.W4A8``
Backend priority: Backend priority:
- ``w16a16``: asm > triton - ``w16a16``: asm > triton
- ``w4a16``: moe_c > asm > triton - ``w4a16``: moe_c > asm > triton
- ``w8a8``: asm > moe_c > triton > ck - ``w8a8``: asm > moe_c > triton > ck
- ``fp8_w8a8``: asm > moe_c > triton > ck
- ``w4a8``: moe_c - ``w4a8``: moe_c
For non-gated MOE (e.g. Nemotron with ReLU² activation), pass
``gated=False`` (or let it auto-detect from ``activation="relu2"``)
and set ``N1 = intermediate_size`` (not ``2 * intermediate_size``).
""" """
n = N1 / 2 # Determine gating: explicit > auto-detect from activation
if gated is None:
gated = activation in ("silu", "gelu")
# For gated (GLU): N1 = 2 * intermediate_size, n = N1 // 2
# For non-gated: N1 = intermediate_size, n = N1
n = N1 // 2 if gated else N1
block_shape = [0, block_size] if block_size else None block_shape = [0, block_size] if block_size else None
if quant_type == MoeQuantType.W4A16: if quant_type == MoeQuantType.W4A16:
...@@ -282,7 +318,7 @@ def get_aiter_moe_config( ...@@ -282,7 +318,7 @@ def get_aiter_moe_config(
] ]
else: else:
raise ValueError(f"Unsupported dtype: {dtype}") raise ValueError(f"Unsupported dtype: {dtype}")
elif quant_type == MoeQuantType.W8A8: elif quant_type in (MoeQuantType.W8A8, MoeQuantType.FP8_W8A8):
if block_size == 0: # Channel wise choose MOE_C if block_size == 0: # Channel wise choose MOE_C
candidates = [ candidates = [
(MoeSolutionType.MOE_C, lambda: _try_get_moe_c_config(quant_type, M, E, n, block_size)), (MoeSolutionType.MOE_C, lambda: _try_get_moe_c_config(quant_type, M, E, n, block_size)),
...@@ -348,6 +384,7 @@ def aiter_moe( ...@@ -348,6 +384,7 @@ def aiter_moe(
use_int4_w4a16 = moe_config.quant_type == MoeQuantType.W4A16 use_int4_w4a16 = moe_config.quant_type == MoeQuantType.W4A16
use_int8_w8a8 = moe_config.quant_type == MoeQuantType.W8A8 use_int8_w8a8 = moe_config.quant_type == MoeQuantType.W8A8
use_fp8_w8a8 = moe_config.quant_type == MoeQuantType.FP8_W8A8
use_int8_w4a8 = moe_config.quant_type == MoeQuantType.W4A8 use_int8_w4a8 = moe_config.quant_type == MoeQuantType.W4A8
if moe_config.solution_type == MoeSolutionType.MOE_C: if moe_config.solution_type == MoeSolutionType.MOE_C:
...@@ -362,6 +399,7 @@ def aiter_moe( ...@@ -362,6 +399,7 @@ def aiter_moe(
inplace=inplace, inplace=inplace,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w4a8=use_int8_w4a8, use_int8_w4a8=use_int8_w4a8,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -391,6 +429,7 @@ def aiter_moe( ...@@ -391,6 +429,7 @@ def aiter_moe(
inplace=inplace, inplace=inplace,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
...@@ -409,7 +448,7 @@ def aiter_moe( ...@@ -409,7 +448,7 @@ def aiter_moe(
from .ops.triton.fused_moe import fused_experts_impl from .ops.triton.fused_moe import fused_experts_impl
# W8A8 channel-wise (block_shape=None) requires per_channel_quant=True # W8A8 channel-wise (block_shape=None) requires per_channel_quant=True
per_channel_quant = use_int8_w8a8 and block_shape is None per_channel_quant = (use_int8_w8a8 or use_fp8_w8a8) and block_shape is None
return fused_experts_impl( return fused_experts_impl(
hidden_states, hidden_states,
...@@ -421,6 +460,7 @@ def aiter_moe( ...@@ -421,6 +460,7 @@ def aiter_moe(
inplace=inplace, inplace=inplace,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
activation=activation, activation=activation,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
...@@ -448,6 +488,7 @@ def aiter_moe( ...@@ -448,6 +488,7 @@ def aiter_moe(
odtype=hidden_states.dtype, odtype=hidden_states.dtype,
inplace=inplace, inplace=inplace,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map, expert_map=expert_map,
......
{
"config": {
"(8, 192, 128, False, True, True, 128)": {
"BLOCK_M": 32,
"BLOCK_N": 64,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 1
},
"(16, 192, 128, False, True, False, -1)": {
"BLOCK_M": 32,
"BLOCK_N": 64,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 1
}
},
"path": {}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"65536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"65536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
...@@ -17,15 +17,21 @@ Memory-efficient attention for prefill. ...@@ -17,15 +17,21 @@ Memory-efficient attention for prefill.
It supports page size = 1 and prefill with KV cache (i.e. extend). It supports page size = 1 and prefill with KV cache (i.e. extend).
""" """
from typing import Optional
import functools import functools
import json import json
from typing import Any, Optional
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
import os import os
from triton.knobs import cache as cache_knob import types
try:
from triton.knobs import cache as cache_knob
except ImportError:
# Triton builds without `triton.knobs` (e.g. 3.2.x in some images): disable saved-kernel path.
cache_knob = types.SimpleNamespace(dir="__triton_knobs_unavailable__")
from aiter.ops.triton.prefill_attention import context_attention_fwd from aiter.ops.triton.prefill_attention import context_attention_fwd
from aiter.ops.triton.activation import _tanh from aiter.ops.triton.activation import _tanh
...@@ -76,6 +82,10 @@ def _fwd_kernel( ...@@ -76,6 +82,10 @@ def _fwd_kernel(
cur_seq = tl.program_id(0) cur_seq = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
cur_block_m = tl.program_id(2) cur_block_m = tl.program_id(2)
tl.assume(Q_Extend.to(tl.int64) >= 0)
tl.assume(K_Extend.to(tl.int64) >= 0)
tl.assume(V_Extend.to(tl.int64) >= 0)
cur_kv_head = cur_head // kv_group_num cur_kv_head = cur_head // kv_group_num
cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq) cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
...@@ -292,6 +302,345 @@ def _fwd_kernel( ...@@ -292,6 +302,345 @@ def _fwd_kernel(
) )
@triton.jit
def _fwd_kernel_v2(
    Q_Extend,
    K_Extend,
    V_Extend,
    O_Extend,
    K_Buffer,
    V_Buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    mask_ptr,
    mask_indptr,
    sink_ptr,
    window_kv_offset_ptr,
    sm_scale,
    k_scale,
    v_scale,
    kv_group_num,
    stride_qbs,
    stride_qh,
    stride_kbs,
    stride_kh,
    stride_vbs,
    stride_vh,
    stride_obs,
    stride_oh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    SLIDING_WINDOW_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    xai_temperature_len: tl.constexpr,
    Lq: tl.constexpr,
    Lv: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_CUSTOM_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
    STORE_TRANSPOSE: tl.constexpr,
    HAS_SINK: tl.constexpr,
):
    """Extend-attention forward kernel with K/V scales, sliding window and sinks.

    One program handles one (sequence, query head, BLOCK_M-row query tile)
    triple, taken from program ids 0, 1 and 2 respectively. It runs two passes
    that share a running-max / running-sum softmax state (``e_max``, ``deno``,
    ``acc``):

    1. Prefix pass: keys/values are gathered from ``K_Buffer`` / ``V_Buffer``
       through the row indices stored in ``kv_indices``. Logits are scaled by
       ``sm_scale * k_scale`` and the P @ V product by ``v_scale``.
    2. Extend pass: keys/values are read from ``K_Extend`` / ``V_Extend``.
       Logits are scaled by ``sm_scale`` only.

    The result ``acc / deno`` is written to ``O_Extend``.

    Arguments (as used by the code below):
        Q_Extend, K_Extend, V_Extend, O_Extend: base pointers; rows are
            addressed with ``stride_*bs``, heads with ``stride_*h`` and the
            last dimension with unit stride.
        K_Buffer, V_Buffer: base pointers addressed with ``stride_buf_*``.
        qo_indptr, kv_indptr: consecutive entries ``[seq]`` / ``[seq + 1]``
            delimit a sequence's extend rows and its ``kv_indices`` range.
        mask_ptr, mask_indptr: custom mask and its per-sequence start offset;
            each mask row spans ``cur_seq_len + window_kv_offset`` entries.
            Only read when ``USE_CUSTOM_MASK``.
        sink_ptr: per-head value added to the softmax denominator when
            ``HAS_SINK``.
        window_kv_offset_ptr: per-sequence offset, read only when both
            ``USE_CUSTOM_MASK`` and ``SLIDING_WINDOW_SIZE > 0``.
        kv_group_num: number of query heads mapped onto one KV head.
        Lq, Lv: valid head sizes; ``BLOCK_DMODEL`` / ``BLOCK_DV`` lanes beyond
            them are masked off. ``BLOCK_DPE`` extra Q/K lanes start at offset
            ``BLOCK_DMODEL``.
        logit_cap: when > 0, logits become ``logit_cap * tanh(qk / logit_cap)``.
        xai_temperature_len: when > 0, rows whose query index exceeds it are
            multiplied by ``log2(index) / log2(xai_temperature_len)``.
    """
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)
    cur_block_m = tl.program_id(2)
    # Compiler hints that the base pointer values are non-negative
    # (presumably to help address arithmetic; confirm against Triton docs).
    tl.assume(Q_Extend.to(tl.int64) >= 0)
    tl.assume(K_Extend.to(tl.int64) >= 0)
    tl.assume(V_Extend.to(tl.int64) >= 0)
    cur_kv_head = cur_head // kv_group_num
    # Per-sequence ranges: extend rows from qo_indptr, prefix entries from kv_indptr.
    cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
    cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
    cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
    cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
    cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
    if USE_CUSTOM_MASK:
        cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
    # Extra leading columns in each custom-mask row; stays 0 unless both the
    # custom mask and the sliding window are enabled.
    window_kv_offset = 0
    if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
        window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    offs_m = tl.arange(0, BLOCK_M)
    # Bounds masks: query rows past the extend length and head lanes past Lq / Lv.
    mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
    mask_d = offs_d < Lq
    mask_dv = offs_dv < Lv
    if xai_temperature_len > 0:
        # Per-row logit multiplier: log2(query index) / log2(xai_temperature_len)
        # for query indices above xai_temperature_len, 1.0 otherwise.
        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
        xai_temperature_reg = tl.where(
            offs_qidx > xai_temperature_len,
            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
            1.0,
        )
    # Load the [BLOCK_M, BLOCK_DMODEL] query tile once; it is reused by both passes.
    offs_q = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_qbs
        + cur_head * stride_qh
        + offs_d[None, :]
    )
    q = tl.load(
        Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
    )
    if BLOCK_DPE > 0:
        # Additional query lanes located right after the first BLOCK_DMODEL lanes.
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        offs_qpe = (
            (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
            * stride_qbs
            + cur_head * stride_qh
            + offs_dpe[None, :]
        )
        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
    offs_n = tl.arange(0, BLOCK_N)
    # Softmax state carried across both passes: output accumulator, running
    # denominator and running row maximum (starting at -inf).
    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
    deno = tl.zeros([BLOCK_M], dtype=tl.float32)
    e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    # ---- Pass 1: prefix tokens, gathered from K_Buffer / V_Buffer ----
    for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_seq_len_prefix
        final_mask = mask_m[:, None] & mask_n[None, :]
        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
            # Mask row = query row within the sequence; mask column = prefix
            # position shifted by window_kv_offset.
            custom_mask = tl.load(
                mask_ptr
                + cur_seq_mask_start_idx
                + (cur_block_m * BLOCK_M + offs_m[:, None])
                * (cur_seq_len + window_kv_offset)
                + window_kv_offset
                + start_n
                + offs_n[None, :],
                mask=(mask_m[:, None] & mask_n[None, :]),
                other=0,
            )
            final_mask &= custom_mask
        if SLIDING_WINDOW_SIZE > 0:
            # Keep entries where (prefix_len + query row) <= key position + window size.
            window_mask = (
                cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
            ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
            final_mask &= window_mask
        # Skip the whole tile when the combined mask has no set entry.
        SKIP_TILE = False
        if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
        if not SKIP_TILE:
            # Row indices into K_Buffer / V_Buffer for this tile of prefix tokens.
            offs_kv_loc = tl.load(
                kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
                mask=mask_n,
                other=0,
            )
            # K is loaded as [BLOCK_DMODEL, BLOCK_N] so it can be the right operand of tl.dot.
            offs_buf_k = (
                offs_kv_loc[None, :] * stride_buf_kbs
                + cur_kv_head * stride_buf_kh
                + offs_d[:, None]
            )
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(mask_n[None, :]) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q.to(k.dtype), k)
            if BLOCK_DPE > 0:
                offs_kpe = (
                    offs_kv_loc[None, :] * stride_buf_kbs
                    + cur_kv_head * stride_buf_kh
                    + offs_dpe[:, None]
                )
                kpe = tl.load(
                    K_Buffer + offs_kpe,
                    mask=mask_n[None, :],
                    other=0.0,
                )
                qk += tl.dot(qpe.to(kpe.dtype), kpe)
            # Buffer keys carry k_scale in addition to the softmax scale.
            qk *= sm_scale * k_scale
            if logit_cap > 0:
                qk = logit_cap * _tanh(qk / logit_cap)
            if xai_temperature_len > 0:
                qk *= xai_temperature_reg[:, None]
            qk = tl.where(final_mask, qk, float("-inf"))
            # row_max_fixed avoids exp(-inf - (-inf)) when a row is all -inf in this tile;
            # only needed under sliding window or custom mask (plain causal matches v1).
            if SLIDING_WINDOW_SIZE > 0 or (
                USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK
            ):
                row_max = tl.max(qk, 1)
                row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
                n_e_max = tl.maximum(row_max_fixed, e_max)
            else:
                n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            # Rescale the previous denominator / accumulator to the new running max.
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)
            offs_buf_v = (
                offs_kv_loc[:, None] * stride_buf_vbs
                + cur_kv_head * stride_buf_vh
                + offs_dv[None, :]
            )
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=mask_n[:, None] & mask_dv[None, :],
                other=0.0,
            )
            p = p.to(v.dtype)
            # Buffer values carry v_scale, applied to the P @ V product.
            acc = acc * re_scale[:, None] + tl.dot(p, v) * v_scale
            e_max = n_e_max
    # ---- Pass 2: extend tokens, read from K_Extend / V_Extend ----
    # Under causal masking, keys beyond the last row of this query tile are never needed.
    cur_block_m_end = (
        cur_seq_len_extend
        if not IS_CAUSAL
        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
    )
    for start_n in range(0, cur_block_m_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        mask_n = (start_n + offs_n) < cur_block_m_end
        final_mask = mask_m[:, None] & mask_n[None, :]
        if USE_CUSTOM_MASK:
            # Same mask row as in pass 1; columns are additionally shifted past the prefix.
            custom_mask = tl.load(
                mask_ptr
                + cur_seq_mask_start_idx
                + (cur_block_m * BLOCK_M + offs_m[:, None])
                * (cur_seq_len + window_kv_offset)
                + window_kv_offset
                + cur_seq_len_prefix
                + start_n
                + offs_n[None, :],
                mask=(mask_m[:, None] & mask_n[None, :]),
                other=0,
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
            final_mask &= custom_mask
        elif IS_CAUSAL:
            # Query row i may attend to extend keys j with j <= i.
            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                start_n + offs_n[None, :]
            )
            mask_causual &= mask_m[:, None] & mask_n[None, :]
            final_mask &= mask_causual
        else:
            mask_non_causal = mask_m[:, None] & mask_n[None, :]
            final_mask &= mask_non_causal
        if SLIDING_WINDOW_SIZE > 0:
            # Keep entries where query row <= key position + window size
            # (both measured within the extend segment).
            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
            )
            final_mask &= window_mask
        # Skip the whole tile when the combined mask has no set entry.
        SKIP_TILE = False
        if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
            SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
        if not SKIP_TILE:
            offs_k = (
                (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                + cur_kv_head * stride_kh
                + offs_d[:, None]
            )
            k = tl.load(
                K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
            )
            qk = tl.dot(q.to(k.dtype), k, out_dtype=tl.float32)
            if BLOCK_DPE > 0:
                offs_kpe = (
                    (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
                    + cur_kv_head * stride_kh
                    + offs_dpe[:, None]
                )
                kpe = tl.load(
                    K_Extend + offs_kpe,
                    mask=mask_n[None, :],
                    other=0.0,
                )
                qk += tl.dot(qpe.to(kpe.dtype), kpe)
            # Extend keys are scaled by sm_scale only (no k_scale in this pass).
            qk *= sm_scale
            if logit_cap > 0:
                qk = logit_cap * _tanh(qk / logit_cap)
            if xai_temperature_len > 0:
                qk *= xai_temperature_reg[:, None]
            qk = tl.where(final_mask, qk, float("-inf"))
            # Same -inf row-max guard as in pass 1, needed whenever a row of
            # this tile can be fully masked.
            if SLIDING_WINDOW_SIZE > 0 or USE_CUSTOM_MASK:
                row_max = tl.max(qk, 1)
                row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
                n_e_max = tl.maximum(row_max_fixed, e_max)
            else:
                n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)
            offs_v = (
                (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
                + cur_kv_head * stride_vh
                + offs_dv[None, :]
            )
            v = tl.load(
                V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
            )
            p = p.to(v.dtype)
            # Extend values are used as-is (no v_scale in this pass).
            acc = acc * re_scale[:, None] + tl.dot(p, v)
            e_max = n_e_max
    if HAS_SINK:
        # The per-head sink contributes to the softmax denominator only;
        # nothing is added to the accumulator.
        cur_sink = tl.load(sink_ptr + cur_head)
        deno += tl.exp(cur_sink - e_max)
    offs_o = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_obs
        + cur_head * stride_oh
        + offs_dv[None, :]
    )
    # Normalise and write the output tile; STORE_TRANSPOSE writes the same
    # elements through transposed offset / value / mask tiles.
    if STORE_TRANSPOSE:
        tl.store(
            O_Extend + offs_o.T,
            (acc / deno[:, None]).T,
            mask=(mask_m[:, None] & mask_dv[None, :]).T,
        )
    else:
        tl.store(
            O_Extend + offs_o,
            acc / deno[:, None],
            mask=mask_m[:, None] & mask_dv[None, :],
        )
def create_tuple(k): def create_tuple(k):
if k[0] != '(' and k[-1] != ')': if k[0] != '(' and k[-1] != ')':
return k return k
...@@ -311,8 +660,11 @@ def create_tuple(k): ...@@ -311,8 +660,11 @@ def create_tuple(k):
def _load_config(): def _load_config():
dev = arch_info.get_device() dev = arch_info.get_device()
fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION-FP16.json" fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION-FP16.json"
try:
with open(fpath, "r") as file: with open(fpath, "r") as file:
data = json.load(file) data = json.load(file)
except FileNotFoundError:
return {"config": {}, "path": {}, "key": [], "keys": []}
res = {} res = {}
res['config'] = data['config'] res['config'] = data['config']
res['path'] = data['path'] res['path'] = data['path']
...@@ -323,6 +675,39 @@ def _load_config(): ...@@ -323,6 +675,39 @@ def _load_config():
global_config = _load_config() global_config = _load_config()
def _load_config_v2():
    """Load the autotuned configs for :func:`_fwd_kernel_v2`.

    Reads ``{dev}-EXTEND_ATTENTION-V2-FP16.json`` from
    ``AITER_TRITON_CONFIGS_PATH``, where ``dev`` comes from
    ``arch_info.get_device()``.

    Each ``config`` entry key must parse to a **7-tuple** via
    :func:`create_tuple`, matching the runtime ``want7`` lookup key; 5-tuple
    keys and plain (unparenthesized) string keys are not accepted.

    Returns:
        A dict with four entries: ``"config"`` (raw key -> kernel config),
        ``"path"`` (raw key -> saved-kernel path, empty when the file has no
        ``path`` section), ``"key"`` (raw key strings) and ``"keys"`` (the
        parsed keys, index-aligned with ``"key"``). All four are empty when
        the JSON file does not exist.

    Raises:
        ValueError: if any key does not parse to a sequence of length 7.
    """
    dev = arch_info.get_device()
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION-V2-FP16.json"
    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        # No tuned table for this device: lookups will fall back to defaults.
        return {"config": {}, "path": {}, "key": [], "keys": []}

    res = {}
    res["config"] = data["config"]
    res["path"] = data.get("path", {})
    res["key"] = list(data["config"].keys())
    res["keys"] = []
    for k in res["key"]:
        tup = create_tuple(k)
        # create_tuple hands back the raw string for keys that are not
        # parenthesized; measuring len() of that string would let any
        # 7-character key (e.g. "default") pass as a "7-tuple".
        is_raw_string = isinstance(tup, str)
        if is_raw_string or len(tup) != 7:
            found = "an unparenthesized string" if is_raw_string else f"length {len(tup)}"
            raise ValueError(
                f"{dev}-EXTEND_ATTENTION-V2-FP16.json keys must be 7-tuples matching runtime "
                f"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
                f"SLIDING_WINDOW_SIZE); got {found} for {k!r}"
            )
        res["keys"].append(tup)
    return res
global_config_v2 = _load_config_v2()
default_config = { default_config = {
"BLOCK_M": 32, "BLOCK_M": 32,
"BLOCK_N": 32, "BLOCK_N": 32,
...@@ -351,6 +736,47 @@ def _get_config(kv_group_num, Lq, Lv, use_custom_mask, is_causal): ...@@ -351,6 +736,47 @@ def _get_config(kv_group_num, Lq, Lv, use_custom_mask, is_causal):
return global_config['config'][key], global_config['path'][key] return global_config['config'][key], global_config['path'][key]
@functools.lru_cache(maxsize=1024)
def _get_config_v2(
    kv_group_num,
    Lq,
    Lv,
    use_custom_mask,
    is_causal,
    has_sink: bool,
    sliding_window_size: int,
):
    """Pick the ``_fwd_kernel_v2`` launch config for one problem signature.

    The seven arguments are packed, in order, into a tuple and compared
    against the parsed keys in ``global_config_v2["keys"]`` (loaded from
    ``{arch}-EXTEND_ATTENTION-V2-FP16.json``; see :func:`_load_config_v2`).

    Returns:
        ``(config, path)``. On a match, ``config`` is the table entry for the
        matching raw key and ``path`` is its saved-kernel path (``None`` when
        the table has no path for that key). Without a match a warning is
        printed and ``(default_config, None)`` is returned; the v1 table is
        not consulted.

    Typical argument sources at the call site:
    ``kv_group_num = q_extend.size(-2) // k_extend.size(-2)``,
    ``Lq = q_extend.size(-1)``, ``Lv = v_extend.size(-1)``,
    ``use_custom_mask = custom_mask is not None``,
    ``has_sink = sinks is not None``, and ``sliding_window_size`` as passed
    (``-1`` when disabled).
    """
    signature = (
        kv_group_num,
        Lq,
        Lv,
        use_custom_mask,
        is_causal,
        has_sink,
        sliding_window_size,
    )
    table = global_config_v2
    # Index of the first parsed key equal to the requested signature, if any.
    hit = next(
        (idx for idx, parsed in enumerate(table["keys"]) if parsed == signature),
        None,
    )
    if hit is None:
        print("WARNING: optimal V2 config not found, just use default config")
        return default_config, None

    raw_key = table["key"][hit]
    return table["config"][raw_key], table["path"].get(raw_key)
def has_kernel_cache(path):
    """Tell whether a saved-kernel directory exists for ``path``.

    Returns ``False`` for an empty / ``None`` ``path``; otherwise returns
    whether ``{cache_knob.dir}/{path}`` is an existing directory.
    """
    if not path:
        return False
    return os.path.isdir(f'{cache_knob.dir}/{path}')
...@@ -385,12 +811,23 @@ def extend_attention_fwd( ...@@ -385,12 +811,23 @@ def extend_attention_fwd(
sm_scale=None, sm_scale=None,
logit_cap=0.0, logit_cap=0.0,
skip_prefix_custom_mask=True, skip_prefix_custom_mask=True,
config: Optional[dict[str, any]] = None, config: Optional[dict[str, Any]] = None,
k_scale=None,
v_scale=None,
sliding_window_size=-1,
sinks=None,
window_kv_offsets=None,
xai_temperature_len=-1,
): ):
""" """
q_extend, k_extend, v_extend, o_extend: contiguous tensors q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
Through ``config`` the signature matches the original aiter API. v2 / sglang
extensions follow with defaults. ``k_scale`` / ``v_scale`` must both be
``None`` or both set (``float`` / ``int`` like sglang, or 1-element
``torch.Tensor`` on device); if both are set, :func:`_fwd_kernel_v2` is used.
""" """
Lq, Lv = ( Lq, Lv = (
q_extend.shape[-1], q_extend.shape[-1],
...@@ -422,12 +859,25 @@ def extend_attention_fwd( ...@@ -422,12 +859,25 @@ def extend_attention_fwd(
# Skip custom mask for prefix part # Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
use_v2 = k_scale is not None or v_scale is not None
if not USE_CUSTOM_MASK: if not USE_CUSTOM_MASK:
custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device) custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device)
mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device) mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device)
if config is None: if config is None:
if q_extend.dtype == torch.float16: if q_extend.dtype == torch.float16 or q_extend.dtype == torch.bfloat16:
if use_v2:
config, path = _get_config_v2(
kv_group_num,
Lq,
Lv,
USE_CUSTOM_MASK,
is_causal,
sinks is not None,
sliding_window_size,
)
else:
keys = [kv_group_num, Lq, Lv, USE_CUSTOM_MASK, is_causal] keys = [kv_group_num, Lq, Lv, USE_CUSTOM_MASK, is_causal]
config, path = _get_config(*keys) config, path = _get_config(*keys)
else: else:
...@@ -441,24 +891,7 @@ def extend_attention_fwd( ...@@ -441,24 +891,7 @@ def extend_attention_fwd(
# extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} # extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
fn = _fwd_kernel[grid] if not has_kernel_cache(path) \ stride_args = (
else functools.partial(triton.utils.run_saved_kernel,
_fwd_kernel, path, grid=grid)
fn(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
sm_scale,
kv_group_num,
q_extend.stride(0), q_extend.stride(0),
q_extend.stride(1), q_extend.stride(1),
k_extend.stride(0), k_extend.stride(0),
...@@ -471,20 +904,77 @@ def extend_attention_fwd( ...@@ -471,20 +904,77 @@ def extend_attention_fwd(
k_buffer.stride(1), k_buffer.stride(1),
v_buffer.stride(0), v_buffer.stride(0),
v_buffer.stride(1), v_buffer.stride(1),
logit_cap=logit_cap, )
block_const = dict(
BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE, BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV, BLOCK_DV=BLOCK_DV,
# BLOCK_M=BLOCK_M,
# BLOCK_N=BLOCK_N,
Lq=Lq, Lq=Lq,
Lv=Lv, Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK, USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal, IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK, SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
STORE_TRANSPOSE=True, STORE_TRANSPOSE=True,
# num_warps=num_warps, )
# num_stages=num_stages,
if use_v2:
HAS_SINK = sinks is not None
assert k_scale is not None and v_scale is not None, "k_scale and v_scale must both be set"
# k_scale / v_scale are forwarded to _fwd_kernel_v2, which applies them to the prefix (KV-buffer) QK and P@V products.
_fwd_kernel_v2[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
sinks,
window_kv_offsets,
sm_scale,
k_scale,
v_scale,
kv_group_num,
*stride_args,
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
HAS_SINK=HAS_SINK,
**block_const,
**config,
)
return
fn = (
_fwd_kernel[grid]
if not has_kernel_cache(path)
else functools.partial(
triton.utils.run_saved_kernel, _fwd_kernel, path, grid=grid
)
)
fn(
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
sm_scale,
kv_group_num,
*stride_args,
logit_cap=logit_cap,
**block_const,
**config, **config,
) )
......
...@@ -322,6 +322,7 @@ def fused_moe_kernel_gptq_awq( ...@@ -322,6 +322,7 @@ def fused_moe_kernel_gptq_awq(
USE_MLS_LOAD: tl.constexpr, USE_MLS_LOAD: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr, MUL_ROUTED_WEIGHT: tl.constexpr,
USE_ADDR_OFFSET_INT64_A: tl.constexpr, USE_ADDR_OFFSET_INT64_A: tl.constexpr,
USE_ADDR_OFFSET_INT64_B: tl.constexpr,
USE_ADDR_OFFSET_INT64_C: tl.constexpr, USE_ADDR_OFFSET_INT64_C: tl.constexpr,
top_k: tl.constexpr, top_k: tl.constexpr,
compute_type: tl.constexpr, compute_type: tl.constexpr,
...@@ -434,17 +435,45 @@ def fused_moe_kernel_gptq_awq( ...@@ -434,17 +435,45 @@ def fused_moe_kernel_gptq_awq(
if use_int4_w4a16: if use_int4_w4a16:
if group_size_divisible and has_zp: if group_size_divisible and has_zp:
offs_k_continue = tl.arange(0, BLOCK_SIZE_K // 2).to(tl.int32) offs_k_continue = tl.arange(0, BLOCK_SIZE_K // 2).to(tl.int32)
b_ptrs = b_ptr + (off_experts * stride_be + \ if USE_ADDR_OFFSET_INT64_B:
offs_bn[:, None] * stride_bn + offs_k_continue[None, :] * \ b_ptrs = b_ptr + (
stride_bk).to(tl.int32) off_experts.to(tl.int64) * stride_be
+ offs_bn[:, None].to(tl.int64) * stride_bn
+ offs_k_continue[None, :].to(tl.int64) * stride_bk
)
else: else:
b_ptrs = b_ptr + (off_experts * stride_be + \ b_ptrs = b_ptr + (
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \ off_experts * stride_be
stride_bn).to(tl.int32) + offs_bn[:, None] * stride_bn
+ offs_k_continue[None, :] * stride_bk
).to(tl.int32)
else:
if USE_ADDR_OFFSET_INT64_B:
b_ptrs = b_ptr + (
off_experts.to(tl.int64) * stride_be
+ (offs_k[:, None].to(tl.int64) // 2) * stride_bk
+ offs_bn[None, :].to(tl.int64) * stride_bn
)
else:
b_ptrs = b_ptr + (
off_experts * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
).to(tl.int32)
b_shifter = (offs_k[:, None] % 2) * 4 b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16: elif use_int8_w8a16:
b_ptrs = b_ptr + (off_experts * stride_be + \ if USE_ADDR_OFFSET_INT64_B:
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn).to(tl.int32) b_ptrs = b_ptr + (
off_experts.to(tl.int64) * stride_be
+ offs_k[:, None].to(tl.int64) * stride_bk
+ offs_bn[None, :].to(tl.int64) * stride_bn
)
else:
b_ptrs = b_ptr + (
off_experts * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
).to(tl.int32)
if not has_zp and use_int4_w4a16: if not has_zp and use_int4_w4a16:
b_zp_num = 8 b_zp_num = 8
...@@ -2552,6 +2581,7 @@ def fused_moe( ...@@ -2552,6 +2581,7 @@ def fused_moe(
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
offset_max = 2**31 - 1 offset_max = 2**31 - 1
use_addr_offset_int64_a = A.numel() * A.element_size() >= offset_max use_addr_offset_int64_a = A.numel() * A.element_size() >= offset_max
use_addr_offset_int64_b = B.numel() * B.element_size() >= offset_max
use_addr_offset_int64_c = C.numel() * C.element_size() >= offset_max use_addr_offset_int64_c = C.numel() * C.element_size() >= offset_max
if use_int4_w4a8: if use_int4_w4a8:
...@@ -2592,6 +2622,7 @@ def fused_moe( ...@@ -2592,6 +2622,7 @@ def fused_moe(
group_size=block_shape[1], group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a, USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c, USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
...@@ -2636,6 +2667,7 @@ def fused_moe( ...@@ -2636,6 +2667,7 @@ def fused_moe(
group_size=block_shape[1], group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a, USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c, USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
......
...@@ -27,6 +27,7 @@ def input_helper( ...@@ -27,6 +27,7 @@ def input_helper(
attn_impl="absorb", attn_impl="absorb",
equal_seqlens=False, equal_seqlens=False,
requires_grad=False, requires_grad=False,
kv_num_heads: int = 1,
): ):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -85,9 +86,9 @@ def input_helper( ...@@ -85,9 +86,9 @@ def input_helper(
total_extend, H, Lq, dtype=dtype, device=device total_extend, H, Lq, dtype=dtype, device=device
).requires_grad_(requires_grad) ).requires_grad_(requires_grad)
# extend parts # extend parts (``kv_num_heads`` for GQA: e.g. 2 when log shows ``k_extend [T,2,192]``)
k_extend = torch.randn( k_extend = torch.randn(
total_extend, 1, Lk, dtype=dtype, device=device total_extend, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad) ).requires_grad_(requires_grad)
v_extend = k_extend[..., :Lv] v_extend = k_extend[..., :Lv]
...@@ -96,7 +97,7 @@ def input_helper( ...@@ -96,7 +97,7 @@ def input_helper(
# prefix parts # prefix parts
k_buffer = torch.randn( k_buffer = torch.randn(
total_prefix, 1, Lk, dtype=dtype, device=device total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad) ).requires_grad_(requires_grad)
v_buffer = k_buffer[..., :Lv] v_buffer = k_buffer[..., :Lv]
...@@ -154,12 +155,20 @@ def extend_forward( ...@@ -154,12 +155,20 @@ def extend_forward(
causal, causal,
sm_scale=1.0, sm_scale=1.0,
logit_cap=0.0, logit_cap=0.0,
use_v2: bool = False,
sliding_window_size: int = -1,
sinks=None,
): ):
"""Same tensors; v1 uses ``k_scale=v_scale=None``, v2 uses ``1.0`` (sglang-style scalars)."""
out = torch.empty( out = torch.empty(
(*q_extend.shape[:-1], v_extend.shape[-1]), (*q_extend.shape[:-1], v_extend.shape[-1]),
dtype=q_extend.dtype, dtype=q_extend.dtype,
device=q_extend.device, device=q_extend.device,
) )
k_scale = v_scale = None
if use_v2:
k_scale = 1.0
v_scale = 1.0
extend_attention.extend_attention_fwd( extend_attention.extend_attention_fwd(
q_extend, q_extend,
k_extend, k_extend,
...@@ -176,6 +185,14 @@ def extend_forward( ...@@ -176,6 +185,14 @@ def extend_forward(
max_len_extend, max_len_extend,
sm_scale=sm_scale, sm_scale=sm_scale,
logit_cap=logit_cap, logit_cap=logit_cap,
skip_prefix_custom_mask=True,
config=None,
k_scale=k_scale,
v_scale=v_scale,
sliding_window_size=sliding_window_size,
sinks=sinks,
window_kv_offsets=None,
xai_temperature_len=-1,
) )
return out return out
...@@ -211,13 +228,23 @@ def get_extend_benchmark_configs(): ...@@ -211,13 +228,23 @@ def get_extend_benchmark_configs():
"qk_rope_head_dim", "qk_rope_head_dim",
"v_head_dim", "v_head_dim",
"attn_impl", "attn_impl",
"kv_num_heads",
# 与 ``{arch}-EXTEND_ATTENTION-V2-FP16.json`` 的 want7 对齐(见 extend_attention._get_config_v2)
"is_causal",
"sliding_window_size",
"with_sinks",
] ]
x_vals_list = [ x_vals_list = [
(2, 16, 1024, 1024, 256, 0, 128, "non-absorb"), # (2, 16, 1024, 1024, 256, 0, 128, "non-absorb", 1, False, -1, False),
(2, 16, 4096, 4096, 512, 64, 128, "non-absorb"), # (2, 16, 4096, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
(2, 16, 8192, 4096, 512, 64, 128, "non-absorb"), # (2, 16, 8192, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
(2, 16, 8192, 4096, 512, 64, 128, "absorb"), # (2, 16, 8192, 4096, 512, 64, 128, "absorb", 1, False, -1, False),
(2, 16, 16324, 8192, 512, 64, 128, "absorb"), # (2, 16, 16324, 8192, 512, 64, 128, "absorb", 1, False, -1, False),
# log 形状 + 命中 BW200B-EXTEND_ATTENTION-V2-FP16.json 两条 key
(2, 16, 4096, 555, 128, 64, 128, "non-absorb", 1, True, -1, False),
(16, 16, 556, 1, 128, 64, 128, "non-absorb", 1, True, -1, False),
(2, 16, 4096, 1024, 128, 64, 128, "non-absorb", 2, True, 128, True),
(16, 16, 556, 1, 128, 64, 128, "non-absorb", 2, True, 128, True),
] ]
return x_names, x_vals_list return x_names, x_vals_list
...@@ -232,13 +259,17 @@ def get_prefill_benchmark_configs(): ...@@ -232,13 +259,17 @@ def get_prefill_benchmark_configs():
"qk_rope_head_dim", "qk_rope_head_dim",
"v_head_dim", "v_head_dim",
"attn_impl", "attn_impl",
"kv_num_heads",
"is_causal",
"sliding_window_size",
"with_sinks",
] ]
x_vals_list = [ x_vals_list = [
(2, 16, 0, 1024, 256, 0, 128, "non-absorb"), (2, 16, 0, 1024, 256, 0, 128, "non-absorb", 1, False, -1, False),
(2, 16, 0, 4096, 512, 64, 128, "non-absorb"), (2, 16, 0, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
(2, 16, 0, 4096, 512, 64, 128, "non-absorb"), (2, 16, 0, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
(2, 16, 0, 4096, 512, 64, 128, "absorb"), (2, 16, 0, 4096, 512, 64, 128, "absorb", 1, False, -1, False),
(2, 16, 0, 8192, 512, 64, 128, "absorb"), (2, 16, 0, 8192, 512, 64, 128, "absorb", 1, False, -1, False),
] ]
return x_names, x_vals_list return x_names, x_vals_list
...@@ -266,6 +297,10 @@ def model_benchmark_configs(args): ...@@ -266,6 +297,10 @@ def model_benchmark_configs(args):
"qk_rope_head_dim", "qk_rope_head_dim",
"v_head_dim", "v_head_dim",
"attn_impl", "attn_impl",
"kv_num_heads",
"is_causal",
"sliding_window_size",
"with_sinks",
] ]
x_vals_list = [] x_vals_list = []
...@@ -276,7 +311,21 @@ def model_benchmark_configs(args): ...@@ -276,7 +311,21 @@ def model_benchmark_configs(args):
extend = args.extend if args.extend else 8192 extend = args.extend if args.extend else 8192
attn_impl = args.attn_impl if args.attn_impl else "non-absorb" attn_impl = args.attn_impl if args.attn_impl else "non-absorb"
x_vals_list.append( x_vals_list.append(
(model_name, batch_size, HQ, prefix, extend, 512, 64, 128, attn_impl) (
model_name,
batch_size,
HQ,
prefix,
extend,
512,
64,
128,
attn_impl,
1,
False,
-1,
False,
)
) )
return x_names, x_vals_list return x_names, x_vals_list
...@@ -296,7 +345,22 @@ def benchmark(args): ...@@ -296,7 +345,22 @@ def benchmark(args):
elif args.mode == "prefill": elif args.mode == "prefill":
x_names, x_vals_list = get_prefill_benchmark_configs() x_names, x_vals_list = get_prefill_benchmark_configs()
line_vals = ["extend_attention_fwd"] if args.mode == "prefill":
line_vals = ["context_attention_fwd"]
line_names = ["prefill/context"]
styles = [("blue", "-")]
elif args.extend_provider == "v1":
line_vals = ["extend_v1"]
line_names = ["v1 (scale None)"]
styles = [("red", "-")]
elif args.extend_provider == "v2":
line_vals = ["extend_v2"]
line_names = ["v2 (k=v=1.0)"]
styles = [("green", "-")]
else:
line_vals = ["extend_v1", "extend_v2"]
line_names = ["v1 (scale None)", "v2 (k=v=1.0)"]
styles = [("red", "-"), ("green", "-")]
plot_name = ( plot_name = (
args.plot_name + f"-causal-{args.causal}-equal_seqlens-{args.equal_seqlens}" args.plot_name + f"-causal-{args.causal}-equal_seqlens-{args.equal_seqlens}"
...@@ -308,8 +372,8 @@ def benchmark(args): ...@@ -308,8 +372,8 @@ def benchmark(args):
x_vals=x_vals_list, x_vals=x_vals_list,
line_arg="provider", line_arg="provider",
line_vals=line_vals, line_vals=line_vals,
line_names=line_vals, line_names=line_names,
styles=[("red", "-"), ("green", "-")], styles=styles,
ylabel="ms", ylabel="ms",
plot_name=plot_name, plot_name=plot_name,
args={"sm_scale": 1.0, "logit_cap": 0.0, "device": args.device}, args={"sm_scale": 1.0, "logit_cap": 0.0, "device": args.device},
...@@ -317,23 +381,34 @@ def benchmark(args): ...@@ -317,23 +381,34 @@ def benchmark(args):
) )
@triton.testing.perf_report(configs) @triton.testing.perf_report(configs)
def bench_MLA( def bench_MLA(**kwargs):
B, # perf_report 调用形如 fn(**x_args, provider=..., **bench.args),全部为关键字参数
H, warmup = 5
prefix, rep = 30
extend,
kv_lora_rank, provider = kwargs.pop("provider")
qk_rope_head_dim, sm_scale = kwargs.pop("sm_scale")
v_head_dim, logit_cap = kwargs.pop("logit_cap")
attn_impl, device = kwargs.pop("device")
sm_scale, kwargs.pop("model", None)
logit_cap, kv_num_heads = int(kwargs.pop("kv_num_heads", 1))
device, B = kwargs.pop("B")
provider=None, H = kwargs.pop("H")
model=None, prefix = kwargs.pop("prefix")
): extend = kwargs.pop("extend")
warmup = 25 kv_lora_rank = kwargs.pop("kv_lora_rank")
rep = 100 qk_rope_head_dim = kwargs.pop("qk_rope_head_dim")
v_head_dim = kwargs.pop("v_head_dim")
attn_impl = kwargs.pop("attn_impl")
row_causal = kwargs.pop("is_causal")
sliding_window_size = int(kwargs.pop("sliding_window_size"))
with_sinks = bool(kwargs.pop("with_sinks"))
if kwargs:
raise ValueError(f"unexpected benchmark kwargs: {kwargs}")
sinks_tensor = None
if with_sinks:
sinks_tensor = torch.zeros(H, device=device, dtype=torch.float32)
( (
q_extend, q_extend,
...@@ -360,11 +435,15 @@ def benchmark(args): ...@@ -360,11 +435,15 @@ def benchmark(args):
v_head_dim, v_head_dim,
dtype, dtype,
device, device,
attn_impl=attn_impl,
equal_seqlens=args.equal_seqlens,
kv_num_heads=kv_num_heads,
) )
if provider == "extend_attention_fwd": if provider in ("extend_v1", "extend_v2"):
use_v2 = provider == "extend_v2"
def extend_attention(): def fn():
return extend_forward( return extend_forward(
q_extend, q_extend,
k_extend, k_extend,
...@@ -377,33 +456,35 @@ def benchmark(args): ...@@ -377,33 +456,35 @@ def benchmark(args):
custom_mask, custom_mask,
mask_indptr, mask_indptr,
max_len_extend, max_len_extend,
args.causal, row_causal,
sm_scale, sm_scale,
logit_cap, logit_cap,
use_v2=use_v2,
sliding_window_size=sliding_window_size,
sinks=sinks_tensor,
) )
def context_attention(): elif provider == "context_attention_fwd":
return extend_forward( assert (
prefix == 0
), "Prefix length must be 0 for context attention. Try setting -mode prefill."
def fn():
return prefill_forward(
q_extend, q_extend,
k_extend, k_extend,
v_extend, v_extend,
B_Start_Loc, B_Start_Loc,
B_Seqlen, B_Seqlen,
max_len_extend, max_len_extend,
args.causal, row_causal,
) )
if provider == "extend_attention_fwd":
fn = extend_attention
elif provider == "context_attention_fwd":
assert (
prefix == 0
), "Prefix length must be 0 for context attention. Try setting -mode prefill."
fn = context_attention
else: else:
raise ValueError(f"Unknown provider: {provider}") raise ValueError(f"Unknown provider: {provider}")
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) ms = triton.testing.do_bench_cudagraph(fn)
# ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms return ms
...@@ -474,16 +555,16 @@ def parse_args(): ...@@ -474,16 +555,16 @@ def parse_args():
default="extend", default="extend",
help="Mode of the benchmark. Options: extend, prefill", help="Mode of the benchmark. Options: extend, prefill",
) )
parser.add_argument(
"-extend_provider",
type=str,
default="both",
choices=("both", "v1", "v2"),
help="Which extend_attention path to benchmark: v1 (k_scale=v_scale=None), v2 (1.0), or both. Ignored when -mode prefill.",
)
return parser.parse_args() return parser.parse_args()
arg_to_torch_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
def run_bench(args): def run_bench(args):
torch.manual_seed(0) torch.manual_seed(0)
torch.set_default_device(args.device) torch.set_default_device(args.device)
......
# Test for get_aiter_moe_config and aiter_moe with W16A16 non-gated ReLU²
# (Nemotron-style MOE: N1 = intermediate_size, activation = relu2)
import torch
import pandas as pd
from typing import Optional, List
from aiter.fused_moe import fused_topk
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.fused_moe_asm_wna16 import fused_experts_asm_impl
from aiter.ops.shuffle import asm_shuffle_weight_b8
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Torch reference for non-gated ReLU² MOE
# ---------------------------------------------------------------------------
def torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids):
    """Plain-torch reference for a non-gated ReLU² MoE layer.

    Each token is replicated once per top-k slot. For every expert, the
    slots routed to it are pushed through ``x @ w1[e].T``, ``relu(.) ** 2``
    and ``@ w2[e].T``. Slot results are then scaled by ``topk_weights`` and
    summed over the top-k axis.

    Args:
        hidden_states: 2-D activations, unpacked as ``(B, D)``.
        w1: first-GEMM weights, ``[E, inter_dim, model_dim]`` (NOT
            ``2 * inter_dim``, because there is no gate projection).
        w2: second-GEMM weights, ``[E, model_dim, inter_dim]``.
        topk_weights: routing weights; ``shape[1]`` is used as top-k.
        topk_ids: expert id per (token, slot).

    Returns:
        A ``(B, D)`` tensor cast back to the dtype ``hidden_states`` had on
        entry.
    """
    acc_dtype = torch.float32
    result_dtype = hidden_states.dtype

    # All arithmetic runs in float32; only the final result is cast back.
    tokens = hidden_states.to(acc_dtype)
    up_weights = w1.to(acc_dtype)
    down_weights = w2.to(acc_dtype)

    B, D = tokens.shape
    topk = topk_weights.shape[1]

    # (B, D) -> (B, topk, D): one copy of each token per routing slot.
    tokens = tokens.view(B, -1, D).repeat(1, topk, 1)
    per_slot = torch.zeros((B, topk, D), dtype=acc_dtype, device=tokens.device)

    for expert_id in range(up_weights.shape[0]):
        routed = topk_ids == expert_id
        if not routed.any():
            # No (token, slot) pair selected this expert.
            continue
        gemm1 = tokens[routed] @ up_weights[expert_id].T
        activated = torch.relu(gemm1) ** 2
        per_slot[routed] = activated @ down_weights[expert_id].T

    weighted = per_slot * topk_weights.view(B, -1, 1)
    return weighted.sum(dim=1).to(result_dtype)
# ---------------------------------------------------------------------------
# Weight preparation helpers (W16A16 non-gated)
# ---------------------------------------------------------------------------
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
    """Create the tensors for one non-gated w16a16 MoE test case.

    Unlike the gated layout, ``w1`` is ``[E, n, k]`` rather than
    ``[E, 2*n, k]``. The RNG is re-seeded with 0 on every call and the random
    tensors are drawn in a fixed order (input, w1, w2, score).

    Returns:
        dict with keys ``input``, ``w1``, ``w2``, ``w1_shuffle``,
        ``w2_shuffle``, ``topk_weights``, ``topk_ids`` and ``score``.
    """
    torch.manual_seed(0)

    def _randn(*shape):
        # Every random tensor in this helper lives on CUDA in ``dtype``.
        return torch.randn(shape, device="cuda", dtype=dtype)

    # Draw order matters for reproducibility: input, w1, w2, score.
    activations = _randn(m, k) / 10
    up_proj = _randn(e, n, k) / 2
    down_proj = _randn(e, k, n) / 2
    router_logits = _randn(m, e)

    # Pre-shuffled copies of the weights for the ASM shuffle path.
    up_proj_shuffled = asm_shuffle_weight_b8(up_proj, stage=1)
    down_proj_shuffled = asm_shuffle_weight_b8(down_proj, stage=2)

    routing_weights, routing_ids = fused_topk(
        activations, router_logits, topk, True
    )

    return {
        "input": activations,
        "w1": up_proj,
        "w2": down_proj,
        "w1_shuffle": up_proj_shuffled,
        "w2_shuffle": down_proj_shuffled,
        "topk_weights": routing_weights,
        "topk_ids": routing_ids,
        "score": router_logits,
    }
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w16a16 non-gated relu2)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, dtype):
    """Check ``get_aiter_moe_config`` for W16A16, non-gated, ``relu2``.

    When a solution is reported, the config must be populated, the backend
    must be ASM or TRITON, and the quant type must be W16A16. When none is
    reported, both ``solution_type`` and ``config`` must be ``None``.

    Returns:
        The ``(status, moe_cfg)`` pair as returned by ``get_aiter_moe_config``.
    """
    # Non-gated: N1 is the intermediate size itself (not doubled);
    # N2 (down-projection output) and K are both the model dimension ``k``.
    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=n, N2=k, K=k,
        top_k=topk, block_size=0, dtype=dtype,
        quant_type=MoeQuantType.W16A16,
        activation="relu2",
        gated=False,
    )

    if not status:
        # No backend available: the config object must be empty.
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(
            f"[get_config_w16a16_nogate] {m=}, {k=}, {n=}, {e=}, {topk=}, "
            f"no solution found (expected on unsupported configs)"
        )
        return status, moe_cfg

    accepted_backends = (MoeSolutionType.ASM, MoeSolutionType.TRITON)
    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    assert moe_cfg.solution_type in accepted_backends, \
        f"Unexpected solution_type: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.W16A16
    aiter.logger.info(
        f"[get_config_w16a16_nogate] {m=}, {k=}, {n=}, {e=}, {topk=}, "
        f"solution={moe_cfg.solution_type}, "
        f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w16a16 non-gated relu2
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
    """Run the torch ReLU² reference under ``perftest``.

    Callers in this file unpack the decorated call as ``(output, timing)``.
    """
    reference = torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids)
    return reference
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
    routed_scaling_factor,
):
    """Forward every argument, positionally and in order, to ``aiter_moe``.

    The wrapper exists so the call can be timed by ``perftest``. Callers in
    this file unpack the decorated call as ``(output, timing)``.
    """
    # Positional order is kept exactly as declared; ``aiter_moe`` receives
    # the arguments by position, not by keyword.
    forwarded = (
        hidden_states, w1, w2, topk_weights, topk_ids, moe_config,
        inplace, activation, w1_scale, w2_scale, w1_zp, w2_zp,
        a1_scale, a2_scale, block_shape, global_num_experts,
        expert_map, routed_scaling_factor,
    )
    return aiter_moe(*forwarded)
def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scaling_factor):
    """End-to-end check of ``aiter_moe`` with relu2 against the torch reference.

    Looks up a W16A16 non-gated config, builds inputs, runs both the torch
    reference and ``aiter_moe``, and compares them with ``checkAllclose``.

    Returns:
        ``None`` when no backend is available; otherwise a dict with keys
        ``m``, ``backend`` and ``us`` (the timing returned by the perf wrapper).
    """
    # These three names are printed by the ``{N1=}``-style f-strings below,
    # so they must keep exactly these spellings.
    N1 = n  # non-gated: N1 = intermediate_size
    N2 = k
    K = k

    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=0, dtype=dtype,
        quant_type=MoeQuantType.W16A16,
        activation="relu2",
        gated=False,
    )
    if not status:
        aiter.logger.info(
            f"[aiter_moe_w16a16_nogate] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: "
            f"no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    # The same text is logged now and reused as the comparison message.
    msg = (f"[aiter_moe_w16a16_nogate] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
           f"backend={backend}")
    aiter.logger.info(msg)

    data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype)

    # Torch reference (timing discarded).
    ref_out, _ = _run_torch_ref(
        data["input"], data["w1"], data["w2"],
        data["topk_weights"], data["topk_ids"],
    )

    # aiter_moe dispatch with relu2 activation; no quant scales or zero
    # points are passed for W16A16.
    moe_call_kwargs = dict(
        hidden_states=data["input"],
        w1=data["w1"],
        w2=data["w2"],
        topk_weights=data["topk_weights"],
        topk_ids=data["topk_ids"],
        moe_config=moe_cfg,
        inplace=inplace,
        activation="relu2",
        w1_scale=None,
        w2_scale=None,
        w1_zp=None,
        w2_zp=None,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,
        global_num_experts=e,
        expert_map=None,
        routed_scaling_factor=routed_scaling_factor,
    )
    aiter_out, aiter_us = _run_aiter_moe_perf(**moe_call_kwargs)

    checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)
    return {"m": m, "backend": backend, "us": aiter_us}
# ---------------------------------------------------------------------------
# Test: aiter_moe w16a16 non-gated ASM shuffle vs non-shuffle
# ---------------------------------------------------------------------------
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_asm_perf(hidden_states, w1, w2, topk_weights, topk_ids,
                  dtype, global_num_experts, expert_map):
    """Time ``fused_experts_asm_impl`` with relu2 and the default weight layout.

    ``use_shuffle`` is not passed here, so the callee's default applies.
    """
    asm_options = {
        "activation": "relu2",
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
    }
    return fused_experts_asm_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, dtype, **asm_options
    )
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_asm_shuffle_perf(hidden_states, w1, w2, topk_weights, topk_ids,
                          dtype, global_num_experts, expert_map):
    """Time ``fused_experts_asm_impl`` with relu2 and ``use_shuffle=1``.

    The caller in this file passes the pre-shuffled ``w1`` / ``w2`` tensors.
    """
    asm_options = {
        "activation": "relu2",
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
        "use_shuffle": 1,
    }
    return fused_experts_asm_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, dtype, **asm_options
    )
def test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype):
    """Compare the ASM kernel on shuffled weights against plain weights.

    Runs the non-shuffled ASM path first. If that raises, the case is logged
    as skipped and ``None`` is returned. Otherwise the shuffled path is run,
    both outputs are compared with ``checkAllclose``, and the timings are
    reported.

    Returns:
        ``None`` on skip, else a dict with keys ``m``, ``asm_us``,
        ``shuffle_us`` and ``shuffle_uplift`` (a percent-formatted string).
    """
    data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype)

    try:
        baseline = _run_asm_perf(
            data["input"], data["w1"], data["w2"],
            data["topk_weights"], data["topk_ids"],
            dtype, e, None)
        asm_out, asm_us = baseline
    except Exception as exc:
        # Deliberate best-effort: any failure of the baseline ASM run is
        # treated as "ASM not available" and the case is skipped.
        aiter.logger.info(
            f"[w16a16_nogate_shuffle] SKIP {m=}: ASM not available ({exc})")
        return None

    shuffle_out, shuffle_us = _run_asm_shuffle_perf(
        data["input"], data["w1_shuffle"], data["w2_shuffle"],
        data["topk_weights"], data["topk_ids"],
        dtype, e, None)

    msg = (f"[w16a16_nogate_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
           f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
    checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)

    # Relative speed-up of the shuffled path; guard against a zero timing.
    if shuffle_us > 0:
        uplift = asm_us / shuffle_us - 1
    else:
        uplift = 0

    return {
        "m": m,
        "asm_us": asm_us,
        "shuffle_us": shuffle_us,
        "shuffle_uplift": f"{uplift:.1%}",
    }
if __name__ == "__main__":
    # Driver: runs the three test parts over a sweep of token counts.
    dtype = dtypes.bf16

    # Nemotron-style MoE parameters (non-gated, ReLU²)
    e = 256
    topk = 8
    k = 3072  # model_dim / hidden_size
    n = 128  # intermediate_size (NOT multiplied by 2)
    inplace = False
    routed_scaling_factor = 1.0

    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]

    # --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) ---
    for banner_line in (
        "=" * 60,
        "Part 1: Testing get_aiter_moe_config for w16a16 non-gated relu2",
        "=" * 60,
    ):
        aiter.logger.info(banner_line)
    for m in test_tokens:
        test_get_config(m, k, n, e, topk, dtype)

    # --- Part 2: test aiter_moe end-to-end (w16a16 non-gated relu2) ---
    for banner_line in (
        "=" * 60,
        "Part 2: Testing aiter_moe end-to-end for w16a16 non-gated relu2",
        "=" * 60,
    ):
        aiter.logger.info(banner_line)
    # Skipped cases return None and are left out of the summary.
    df = [
        row
        for row in (
            test_aiter_moe_w16a16_nogate(
                m, k, n, e, topk, dtype, inplace, routed_scaling_factor
            )
            for m in test_tokens
        )
        if row is not None
    ]
    if df:
        df = pd.DataFrame(df)
        aiter.logger.info(f"aiter_moe non-gated relu2 summary:\n{df}")

    # --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) ---
    for banner_line in (
        "=" * 60,
        "Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2",
        "=" * 60,
    ):
        aiter.logger.info(banner_line)
    df_shuffle = [
        row
        for row in (
            test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype)
            for m in test_tokens
        )
        if row is not None
    ]
    if df_shuffle:
        df_shuffle = pd.DataFrame(df_shuffle)
        aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}")
# Test for get_aiter_moe_config and aiter_moe with w8a8 channel-wise quantization # Test for get_aiter_moe_config and aiter_moe with w8a8 channel-wise quantization
import argparse
import torch import torch
import pandas as pd import pandas as pd
...@@ -13,6 +14,7 @@ from aiter.moe import ( ...@@ -13,6 +14,7 @@ from aiter.moe import (
MoeQuantType, MoeQuantType,
) )
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2 from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from aiter.ops.quant import pertoken_quant
import aiter import aiter
...@@ -74,9 +76,11 @@ def _run_aiter_moe_perf( ...@@ -74,9 +76,11 @@ def _run_aiter_moe_perf(
) )
def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype): def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
"""Prepare channel-wise quantized w8a8 inputs. """Prepare channel-wise quantized w8a8 inputs.
For int8 (W8A8): weights quantized to torch.int8, scales = max_val / 127.
For fp8 (FP8_W8A8): weights quantized to float8 via pertoken_quant.
Scale shape: (e, out_dim, 1) — one scale per output channel. Scale shape: (e, out_dim, 1) — one scale per output channel.
block_shape is None for channel-wise. block_shape is None for channel-wise.
""" """
...@@ -86,7 +90,12 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype): ...@@ -86,7 +90,12 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda") w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda") w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda")
# Channel-wise quantization: max per output channel (last dim of weight row) if quant_type == MoeQuantType.FP8_W8A8:
# FP8 channel-wise quantization via pertoken_quant
w1_qweight, w1_scales = pertoken_quant(w1_fp, quant_dtype=dtypes.fp8)
w2_qweight, w2_scales = pertoken_quant(w2_fp, quant_dtype=dtypes.fp8)
else:
# INT8 channel-wise quantization: max per output channel
max_vals_w1 = torch.abs(w1_fp.to(torch.float32)).max(dim=-1, keepdim=True)[0] max_vals_w1 = torch.abs(w1_fp.to(torch.float32)).max(dim=-1, keepdim=True)[0]
max_vals_w1 = max_vals_w1.clamp(min=1e-5) max_vals_w1 = max_vals_w1.clamp(min=1e-5)
w1_scales = max_vals_w1 / 127.0 # (e, 2*n, 1) w1_scales = max_vals_w1 / 127.0 # (e, 2*n, 1)
...@@ -119,7 +128,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype): ...@@ -119,7 +128,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
} }
def test_get_config(m, k, n, e, topk, dtype): def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
"""Test get_aiter_moe_config for channel-wise w8a8 (block_size=0).""" """Test get_aiter_moe_config for channel-wise w8a8 (block_size=0)."""
status, moe_cfg = get_aiter_moe_config( status, moe_cfg = get_aiter_moe_config(
M=m, M=m,
...@@ -130,11 +139,12 @@ def test_get_config(m, k, n, e, topk, dtype): ...@@ -130,11 +139,12 @@ def test_get_config(m, k, n, e, topk, dtype):
top_k=topk, top_k=topk,
block_size=0, block_size=0,
dtype=dtype, dtype=dtype,
quant_type=MoeQuantType.W8A8, quant_type=quant_type,
) )
tag = f"get_config_{quant_type}_cw"
if status: if status:
assert moe_cfg.quant_type == MoeQuantType.W8A8 assert moe_cfg.quant_type == quant_type
assert moe_cfg.solution_type in ( assert moe_cfg.solution_type in (
MoeSolutionType.ASM, MoeSolutionType.ASM,
MoeSolutionType.MOE_C, MoeSolutionType.MOE_C,
...@@ -143,19 +153,19 @@ def test_get_config(m, k, n, e, topk, dtype): ...@@ -143,19 +153,19 @@ def test_get_config(m, k, n, e, topk, dtype):
) )
assert moe_cfg.config is not None assert moe_cfg.config is not None
aiter.logger.info( aiter.logger.info(
f"[get_config_w8a8_cw] {m=}, solution={moe_cfg.solution_type}, " f"[{tag}] {m=}, solution={moe_cfg.solution_type}, "
f"config keys={list(moe_cfg.config.keys())}" f"config keys={list(moe_cfg.config.keys())}"
) )
else: else:
assert moe_cfg.solution_type is None assert moe_cfg.solution_type is None
assert moe_cfg.config is None assert moe_cfg.config is None
aiter.logger.info(f"[get_config_w8a8_cw] {m=}, no solution found") aiter.logger.info(f"[{tag}] {m=}, no solution found")
return status, moe_cfg return status, moe_cfg
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype): def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
"""End-to-end test of aiter_moe with channel-wise w8a8.""" """End-to-end test of aiter_moe with channel-wise w8a8 (int8 or fp8)."""
status, moe_cfg = get_aiter_moe_config( status, moe_cfg = get_aiter_moe_config(
M=m, M=m,
E=e, E=e,
...@@ -165,14 +175,15 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype): ...@@ -165,14 +175,15 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
top_k=topk, top_k=topk,
block_size=0, block_size=0,
dtype=dtype, dtype=dtype,
quant_type=MoeQuantType.W8A8, quant_type=quant_type,
) )
tag = f"aiter_moe_{quant_type}_cw"
if not status: if not status:
aiter.logger.info(f"[aiter_moe_w8a8_cw] SKIP {m=}: no backend available") aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available")
return None return None
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype) data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type)
# Torch reference uses original fp weights directly (no scales needed) # Torch reference uses original fp weights directly (no scales needed)
ref_out, _ = _run_torch_ref( ref_out, _ = _run_torch_ref(
...@@ -216,31 +227,46 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype): ...@@ -216,31 +227,46 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
print("ref_out",ref_out) print("ref_out",ref_out)
msg = f"[aiter_moe_w8a8_cw] {m=}, backend={moe_cfg.solution_type}" msg = f"[{tag}] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg) checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
return {"m": m, "backend": moe_cfg.solution_type, "us": aiter_us} return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us}
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test aiter_moe with channel-wise w8a8 quantization",
)
parser.add_argument(
"--quant",
choices=["int8", "fp8"],
default="int8",
help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)",
)
args = parser.parse_args()
quant_type = MoeQuantType.FP8_W8A8 if args.quant == "fp8" else MoeQuantType.W8A8
# for moe_c backend, it does not support n=320 for now;
# for triton backend, it can run with n=320 in NMZ;
dtype = dtypes.bf16 dtype = dtypes.bf16
e = 256 e = 256
topk = 8 topk = 8
k = 6144 k = 6144
n = 256 n = 320
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w8a8 channel-wise") aiter.logger.info(f"Part 1: Testing get_aiter_moe_config for {quant_type} channel-wise")
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384] test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384]
for m in test_tokens: for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype) test_get_config(m, k, n, e, topk, dtype, quant_type)
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w8a8 channel-wise") aiter.logger.info(f"Part 2: Testing aiter_moe end-to-end for {quant_type} channel-wise")
aiter.logger.info("=" * 60) aiter.logger.info("=" * 60)
df = [] df = []
for m in test_tokens: for m in test_tokens:
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype) ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type)
if ret is not None: if ret is not None:
df.append(ret) df.append(ret)
if df: if df:
......
...@@ -514,6 +514,7 @@ def fused_moe( ...@@ -514,6 +514,7 @@ def fused_moe(
assert B_zp is None or B_zp.ndim == 3 assert B_zp is None or B_zp.ndim == 3
offset_max = 2**31 - 1 offset_max = 2**31 - 1
use_addr_offset_int64_a = A.numel() * A.element_size() >= offset_max use_addr_offset_int64_a = A.numel() * A.element_size() >= offset_max
use_addr_offset_int64_b = B.numel() * B.element_size() >= offset_max
use_addr_offset_int64_c = C.numel() * C.element_size() >= offset_max use_addr_offset_int64_c = C.numel() * C.element_size() >= offset_max
if use_int4_w4a8: if use_int4_w4a8:
...@@ -554,6 +555,7 @@ def fused_moe( ...@@ -554,6 +555,7 @@ def fused_moe(
group_size=block_shape[1], group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a, USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c, USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
...@@ -599,6 +601,7 @@ def fused_moe( ...@@ -599,6 +601,7 @@ def fused_moe(
group_size=block_shape[1], group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight, MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a, USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c, USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k, top_k=top_k,
compute_type=compute_type, compute_type=compute_type,
......
import os import os
import sys
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1" os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"
# :class:`Hcutuner` reads ``TRITON_HCUTUNE_PERF_MODE`` in ``__init__``. Module-level
# ``fn`` / ``fn_v2 = triton.utils.hcutune(...)`` runs at import time, *before*
# ``if __name__ == "__main__"``, so ``--perf`` must be applied here (or set in the shell).
if "--perf" in sys.argv:
os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"
# GPU / ROCm tuning: run **inside** the ``zww_tl_1`` container (not the bare host), e.g.:
# docker exec zww_tl_1 bash -lc 'cd /data/zhouweiwang/aiter/op_tests/triton_autotune && python tune_extend_attention.py --perf'
# ``do_bench`` timing during tuning (smaller => faster iteration; raise for stable prod numbers).
TUNE_DO_BENCH_WARMUP = 5
TUNE_DO_BENCH_REP = 20
import json import json
import torch import torch
...@@ -8,10 +22,27 @@ import random ...@@ -8,10 +22,27 @@ import random
import itertools import itertools
import argparse import argparse
from aiter.ops.triton.extend_attention import _fwd_kernel from aiter.ops.triton.extend_attention import _fwd_kernel, _fwd_kernel_v2
_is_hip = True _is_hip = True
# hcutune key for :func:`_fwd_kernel_v2`.
#
# Per the original author's note, the JSON block-size lookup in :func:`_get_config_v2`
# uses its ``want7`` key only; the list below additionally carries the kernel constexprs
# ``SKIP_PREFIX_CUSTOM_MASK`` and ``xai_temperature_len``, which are used for autotune
# but are not part of the V2 JSON key.
#
# How the runtime values line up with the reference log
# (e.g. ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` ~2291):
#   kv_group_num    = q.size(-2) // k.size(-2)
#   Lq / Lv         = last dims of q / v
#   USE_CUSTOM_MASK = custom_mask is not None
#   HAS_SINK        = sinks is not None
HCUTUNE_KEY_V2 = [
    "kv_group_num",
    "Lq",
    "Lv",
    "USE_CUSTOM_MASK",
    "IS_CAUSAL",
    "SKIP_PREFIX_CUSTOM_MASK",
    "HAS_SINK",
    "SLIDING_WINDOW_SIZE",
    "xai_temperature_len",
]
version = triton.__version__.split(".") version = triton.__version__.split(".")
major_version, minor_version = eval(version[0]), eval(version[1]) major_version, minor_version = eval(version[0]), eval(version[1])
...@@ -29,6 +60,7 @@ def input_helper( ...@@ -29,6 +60,7 @@ def input_helper(
attn_impl="normal", attn_impl="normal",
equal_seqlens=False, equal_seqlens=False,
requires_grad=False, requires_grad=False,
kv_num_heads: int = 1,
): ):
torch.manual_seed(0) torch.manual_seed(0)
...@@ -82,9 +114,9 @@ def input_helper( ...@@ -82,9 +114,9 @@ def input_helper(
total_extend, H, Lq, dtype=dtype, device=device total_extend, H, Lq, dtype=dtype, device=device
).requires_grad_(requires_grad) ).requires_grad_(requires_grad)
# extend parts # extend parts (``kv_num_heads`` for GQA: e.g. 2 when q has 16 heads and kv_group_num is 8)
k_extend = torch.randn( k_extend = torch.randn(
total_extend, 1, Lk, dtype=dtype, device=device total_extend, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad) ).requires_grad_(requires_grad)
v_extend = k_extend[..., :Lv] v_extend = k_extend[..., :Lv]
...@@ -93,7 +125,7 @@ def input_helper( ...@@ -93,7 +125,7 @@ def input_helper(
# prefix parts # prefix parts
k_buffer = torch.randn( k_buffer = torch.randn(
total_prefix, 1, Lk, dtype=dtype, device=device total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad) ).requires_grad_(requires_grad)
v_buffer = k_buffer[..., :Lv] v_buffer = k_buffer[..., :Lv]
...@@ -169,18 +201,42 @@ def generate_configs(config): ...@@ -169,18 +201,42 @@ def generate_configs(config):
return configs_list return configs_list
# def get_triton_configs():
# config = {
# "BLOCK_M": [16, 32, 64],
# "BLOCK_N": [16, 32, 64],
# "waves_per_eu": [1],
# "num_warps": [4, 8, 16],
# # "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
# # "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
# "matrix_instr_nonkdim": [16],
# "num_stages": [1, 2, 3],
# "sched_latency": ["none", "mmac5-ds10"],
# "kpack": [1, 2],
# }
# tt_configs = []
# for c in generate_configs(config):
# num_warps = c['num_warps']
# num_stages = c['num_stages']
# del c['num_warps']
# del c['num_stages']
# tt_configs.append(triton.Config(c, num_warps=num_warps, num_stages=num_stages))
# return tt_configs
def get_triton_configs(): def get_triton_configs():
config = { config = {
"BLOCK_M": [16, 32, 64], "BLOCK_M": [16, 32, 64],
"BLOCK_N": [16, 32, 64], "BLOCK_N": [16, 32, 64],
"waves_per_eu": [1], "waves_per_eu": [1],
"num_warps": [4, 8, 16], "num_warps": [4, 8],
# "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"], # "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
# "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"], # "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
"matrix_instr_nonkdim": [16], "matrix_instr_nonkdim": [16],
"num_stages": [1, 2, 3], "num_stages": [1, 2],
"sched_latency": ["none", "mmac5-ds10"], "sched_latency": ["none", "mmac5-ds10"],
"kpack": [1, 2], "kpack": [1],
} }
tt_configs = [] tt_configs = []
...@@ -193,7 +249,6 @@ def get_triton_configs(): ...@@ -193,7 +249,6 @@ def get_triton_configs():
return tt_configs return tt_configs
def prune_configs(configs, nargs, **kwargs): def prune_configs(configs, nargs, **kwargs):
def _prune(config): def _prune(config):
c = config.all_kwargs() c = config.all_kwargs()
...@@ -216,8 +271,23 @@ key = [ ...@@ -216,8 +271,23 @@ key = [
'SKIP_PREFIX_CUSTOM_MASK', 'SKIP_PREFIX_CUSTOM_MASK',
'STORE_TRANSPOSE', 'STORE_TRANSPOSE',
] ]
fn = triton.utils.hcutune(configs=get_triton_configs(), key=key, perf_debug=True, fn = triton.utils.hcutune(
prune_configs_by={"early_config_prune": prune_configs})(_fwd_kernel) configs=get_triton_configs(),
key=key,
perf_debug=True,
prune_configs_by={"early_config_prune": prune_configs},
warmup=TUNE_DO_BENCH_WARMUP,
rep=TUNE_DO_BENCH_REP,
)(_fwd_kernel)
# hcutune-wrapped ``_fwd_kernel_v2``. It is built at import time with a fresh
# ``get_triton_configs()`` list and the same pruning hook and warmup/rep settings
# as the v1 wrapper, but keyed on ``HCUTUNE_KEY_V2``.
fn_v2 = triton.utils.hcutune(
    configs=get_triton_configs(),
    key=HCUTUNE_KEY_V2,
    perf_debug=True,
    prune_configs_by={"early_config_prune": prune_configs},
    warmup=TUNE_DO_BENCH_WARMUP,
    rep=TUNE_DO_BENCH_REP,
)(_fwd_kernel_v2)
def extend_attention_fwd( def extend_attention_fwd(
...@@ -329,6 +399,139 @@ def extend_attention_fwd( ...@@ -329,6 +399,139 @@ def extend_attention_fwd(
) )
def extend_attention_fwd_v2(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale,
    k_scale,
    v_scale,
    sliding_window_size,
    sinks,
    window_kv_offsets,
    xai_temperature_len,
    skip_prefix_custom_mask,
    logit_cap,
):
    """Launch the hcutune-wrapped ``_fwd_kernel_v2`` (``fn_v2``) once.

    The launch constants are derived from the tensor shapes:

    * ``Lq`` / ``Lv`` are the last dimensions of ``q_extend`` / ``v_extend``.
    * Query head sizes 576, 288 and 192 are split into a
      ``(BLOCK_DMODEL, BLOCK_DPE)`` pair; any other size uses the next power
      of two with ``BLOCK_DPE = 0``.
    * ``kv_group_num`` is the number of query heads per KV head.
    * A falsy ``sm_scale`` is replaced by ``1 / sqrt(Lq)``.

    When ``custom_mask`` is ``None`` both ``custom_mask`` and ``mask_indptr``
    are replaced by one-element placeholder tensors on ``q_extend``'s device.
    The grid is ``(batch, query heads, cdiv(max_len_extend, BLOCK_M))``.

    Nothing is returned; the kernel is presumably expected to write its
    result into ``o_extend`` (not visible from this function).
    """
    q_head_dim = q_extend.shape[-1]
    v_head_dim = v_extend.shape[-1]

    # Query head sizes with a dedicated (BLOCK_DMODEL, BLOCK_DPE) split.
    split_head_dims = {
        576: (512, 64),
        288: (256, 32),
        192: (128, 64),
    }
    if q_head_dim in split_head_dims:
        block_dmodel, block_dpe = split_head_dims[q_head_dim]
    else:
        block_dmodel, block_dpe = triton.next_power_of_2(q_head_dim), 0
    block_dv = triton.next_power_of_2(v_head_dim)

    if not sm_scale:
        sm_scale = 1.0 / (q_head_dim**0.5)

    num_seqs = qo_indptr.shape[0] - 1
    num_q_heads = q_extend.shape[1]
    kv_group_num = num_q_heads // k_extend.shape[1]

    has_custom_mask = custom_mask is not None
    if not has_custom_mask:
        # The kernel still receives mask arguments; hand it tiny placeholders.
        custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device)
        mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device)

    def launch_grid(meta):
        return (
            num_seqs,
            num_q_heads,
            triton.cdiv(max_len_extend, meta["BLOCK_M"]),
        )

    # stride(0), stride(1) for each tensor, in the order the kernel takes them.
    strides = tuple(
        tensor.stride(axis)
        for tensor in (q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer)
        for axis in (0, 1)
    )

    constexprs = dict(
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        HAS_SINK=sinks is not None,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_DPE=block_dpe,
        BLOCK_DV=block_dv,
        Lq=q_head_dim,
        Lv=v_head_dim,
        USE_CUSTOM_MASK=has_custom_mask,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=skip_prefix_custom_mask,
        STORE_TRANSPOSE=True,
    )

    fn_v2[launch_grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sinks,
        window_kv_offsets,
        sm_scale,
        k_scale,
        v_scale,
        kv_group_num,
        *strides,
        **constexprs,
    )
def get_bench_inputs_v2():
    """Return ``(names, rows)`` for tuning :func:`_fwd_kernel_v2` / ``_get_config_v2`` keys.

    After checking ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` entry by entry,
    only two keyed combinations relevant to extend attention occur:

    - **MHA** (``k_extend[..., 1, ...]``): ``sliding_window_size=-1`` and
      ``sinks=None`` (e.g. the ``k_extend [1152,1,192]`` / ``[3952,1,192]``
      sections of the log).
    - **GQA** (``k_extend[..., 2, ...]``): ``sliding_window_size=128`` and
      ``sinks`` of shape ``[16]`` (no record in the whole file pairs GQA with
      ``-1`` / no sinks).

    The log contains **no** ``sliding_window_size=64`` and no
    "GQA + no SWA + no sinks" case; to cover other models, add rows
    separately and note their source.
    """
    column_names = [
        "B",
        "H",
        "prefix",
        "extend",
        "kv_lora_rank",
        "qk_rope_head_dim",
        "v_head_dim",
        "causal",
        "custom_mask",
        "sliding_window_size",
        "has_sink",
        "kv_num_heads",
    ]

    # (prefix, extend) only affect memory traffic / grid size; want7 is decided
    # by head / dim / SWA / sinks. prefix=8192 and extend=1024 follow the usual
    # long-KV benchmark setup.
    # Columns B .. custom_mask are shared by both rows.
    shared_columns = (4, 16, 8192, 1024, 128, 64, 128, True, False)

    # (1) GQA Q16/KV2: matches ``[3952,2,192]`` + ``sliding_window_size=128`` +
    #     ``sinks [16]`` in the log; want7 (8,192,128,F,T,T,128)
    gqa_row = shared_columns + (128, True, 2)

    # (2) MHA Q16/KV1: matches ``[...,1,192]`` + ``sliding_window_size=-1`` +
    #     ``sinks=None`` in the log; want7 (16,192,128,F,T,F,-1)
    mha_row = shared_columns + (-1, False, 1)

    return column_names, [gqa_row, mha_row]
x_names, x_vals = get_bench_inputs() x_names, x_vals = get_bench_inputs()
configs = [ configs = [
triton.testing.Benchmark( triton.testing.Benchmark(
...@@ -401,14 +604,136 @@ def bench_extend_attention(B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim, ...@@ -401,14 +604,136 @@ def bench_extend_attention(B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim,
sm_scale=sm_scale, sm_scale=sm_scale,
logit_cap=logit_cap, logit_cap=logit_cap,
) )
return triton.testing.do_bench(fn) return triton.testing.do_bench_cudagraph(fn)
# Benchmark sweep for the v2 kernel: one row per entry of get_bench_inputs_v2().
x_names_v2, x_vals_v2 = get_bench_inputs_v2()
configs_v2 = [
    triton.testing.Benchmark(
        x_names=x_names_v2,
        x_vals=x_vals_v2,
        # Single provider line: only the hcutune-wrapped v2 kernel is measured.
        line_arg="provider",
        line_vals=["triton_v2"],
        line_names=["triton_v2"],
        styles=[("blue", "-")],
        ylabel="ms",
        plot_name="extend_attention_v2_hcutune",
        # Fixed extra argument passed to every benchmark call.
        args={"dtype": torch.bfloat16},
    )
]
@triton.utils.dist_perf_report(configs_v2)
def bench_extend_attention_v2(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
    provider,
    dtype,
):
    """Benchmark one ``configs_v2`` row through :func:`extend_attention_fwd_v2`.

    Inputs are built with ``input_helper`` (equal sequence lengths,
    ``kv_num_heads`` KV heads). When ``has_sink`` is set, a random ``sinks``
    vector of length ``H`` is created. The launch closure is timed with
    ``triton.testing.do_bench_cudagraph`` and that result is returned.

    Tensors are created on ``"cpu"`` when ``TRITON_HCUTUNE_COMPILE_ONLY=1``
    and on ``"cuda"`` otherwise. ``provider`` is accepted because the
    benchmark harness passes it, but it is not consulted here.

    Raises:
        NotImplementedError: if ``custom_mask`` is truthy. The check runs
            before any tensors are allocated, so an unsupported row fails
            immediately instead of after building the full benchmark inputs.
    """
    # Fail fast: reject the unsupported mode before allocating benchmark tensors.
    if custom_mask:
        raise NotImplementedError(
            "tune v2 with custom_mask requires mask tensors; use custom_mask=False for hcutune key matching"
        )

    torch.manual_seed(42)
    device = "cpu" if os.getenv("TRITON_HCUTUNE_COMPILE_ONLY", "") == "1" else "cuda"
    ref_attn_impl = "normal"

    # Fixed launch options for the tuning run.
    logit_cap = 0.0
    k_scale = 1.0
    v_scale = 1.0
    xai_temperature_len = -1
    skip_prefix_custom_mask = True

    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = input_helper(
        B,
        H,
        prefix,
        extend,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        ref_attn_impl,
        equal_seqlens=True,
        kv_num_heads=kv_num_heads,
    )

    sm_scale = float(1.0 / (q_extend.shape[-1] ** 0.5))
    sinks = (
        torch.randn(H, dtype=q_extend.dtype, device=device)
        if has_sink
        else None
    )
    window_kv_offsets = None

    # Output buffer: q's leading dims with v's head dimension.
    tri_out = torch.empty(
        (*q_extend.shape[:-1], v_extend.shape[-1]),
        dtype=q_extend.dtype,
        device=q_extend.device,
    )

    def run_once():
        extend_attention_fwd_v2(
            q_extend,
            k_extend,
            v_extend,
            tri_out,
            k_buffer,
            v_buffer,
            qo_indptr,
            kv_indptr,
            kv_indices,
            custom_mask_t,
            causal,
            mask_indptr,
            max_len_extend,
            sm_scale,
            k_scale,
            v_scale,
            sliding_window_size,
            sinks,
            window_kv_offsets,
            xai_temperature_len,
            skip_prefix_custom_mask,
            logit_cap,
        )

    return triton.testing.do_bench_cudagraph(run_once)
def parse_args():
    """Parse the tuner's command-line flags.

    Both flags are boolean switches that default to ``False``:
    ``--perf`` selects hcutuner perf mode, ``--v1`` tunes the v1 kernel
    instead of the default v2 kernel.
    """
    switch_help = {
        "--perf": "benchmark with hcutuner perf mode",
        "--v1": "Tune v1 ``_fwd_kernel`` only. Default: v2 ``_fwd_kernel_v2`` (JSON: ``_get_config_v2`` / EXTEND_ATTENTION-V2-FP16).",
    }
    cli = argparse.ArgumentParser()
    for switch, help_text in switch_help.items():
        cli.add_argument(switch, action="store_true", default=False, help=help_text)
    return cli.parse_args()
...@@ -417,6 +742,11 @@ if __name__ == "__main__": ...@@ -417,6 +742,11 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
if args.perf: if args.perf:
os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1" os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1" # idempotent; real set is at top for Hcutuner init
bench_extend_attention.run(print_data=True, save_path='./tune_extend_attention_out') if args.v1:
bench_extend_attention.run(print_data=True, save_path="./tune_extend_attention_out")
else:
bench_extend_attention_v2.run(
print_data=True, save_path="./tune_extend_attention_v2_out"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment