Commit f233de81 authored by Xiaowei.zhang's avatar Xiaowei.zhang
Browse files

[SYNC] Code sync.

parent 1893a1e0
[submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel
url = ../composable_kernel
branch = rel-5.7.1
branch = main
[submodule "3rdparty/moe_c"]
path = 3rdparty/moe_c
url = ../Moe
branch = W8A8
branch = master
......@@ -1199,3 +1199,72 @@ gfx938,f8_w8a8_block,torch.float16,12288,1536,3072,64,8,0,0,asm,13001+23000,1087
gfx938,f8_w8a8_block,torch.float16,16384,1536,3072,64,8,0,0,asm,13001+23000,14314.7352
gfx938,f8_w8a8_block,torch.float16,24576,1536,3072,64,8,0,0,asm,13001+23000,21336.6809
gfx938,f8_w8a8_block,torch.float16,32768,1536,3072,64,8,0,0,asm,13001+23000,28266.1463
gfx938,f8_w8a8_block,torch.float16,1,256,4096,256,8,0,0,asm,10007+20200,66.0562
gfx938,f8_w8a8_block,torch.float16,2,256,4096,256,8,0,0,asm,10001+20000,83.8161
gfx938,f8_w8a8_block,torch.float16,4,256,4096,256,8,0,0,asm,10002+20000,112.3382
gfx938,f8_w8a8_block,torch.float16,6,256,4096,256,8,0,0,asm,10002+20000,141.2728
gfx938,f8_w8a8_block,torch.float16,8,256,4096,256,8,0,0,asm,10007+20000,165.0033
gfx938,f8_w8a8_block,torch.float16,10,256,4096,256,8,0,0,asm,10002+20000,186.1823
gfx938,f8_w8a8_block,torch.float16,12,256,4096,256,8,0,0,asm,10002+20000,205.2475
gfx938,f8_w8a8_block,torch.float16,14,256,4096,256,8,0,0,asm,10002+20000,226.2159
gfx938,f8_w8a8_block,torch.float16,16,256,4096,256,8,0,0,asm,10002+20000,237.2221
gfx938,f8_w8a8_block,torch.float16,20,256,4096,256,8,0,0,asm,10002+20000,264.3938
gfx938,f8_w8a8_block,torch.float16,24,256,4096,256,8,0,0,asm,10002+20000,293.3708
gfx938,f8_w8a8_block,torch.float16,28,256,4096,256,8,0,0,asm,10002+20000,343.0409
gfx938,f8_w8a8_block,torch.float16,32,256,4096,256,8,0,0,asm,10002+20000,359.6472
gfx938,f8_w8a8_block,torch.float16,36,256,4096,256,8,0,0,asm,10002+20000,367.9137
gfx938,f8_w8a8_block,torch.float16,40,256,4096,256,8,0,0,asm,10001+20000,378.7264
gfx938,f8_w8a8_block,torch.float16,44,256,4096,256,8,0,0,asm,10002+20000,389.2864
gfx938,f8_w8a8_block,torch.float16,48,256,4096,256,8,0,0,asm,10002+20000,398.3053
gfx938,f8_w8a8_block,torch.float16,56,256,4096,256,8,0,0,asm,10002+20000,414.7348
gfx938,f8_w8a8_block,torch.float16,64,256,4096,256,8,0,0,asm,10002+20000,430.7348
gfx938,f8_w8a8_block,torch.float16,80,256,4096,256,8,0,0,asm,10002+20000,454.9452
gfx938,f8_w8a8_block,torch.float16,96,256,4096,256,8,0,0,asm,10002+20000,473.8084
gfx938,f8_w8a8_block,torch.float16,112,256,4096,256,8,0,0,asm,10002+20000,489.219
gfx938,f8_w8a8_block,torch.float16,128,256,4096,256,8,0,0,asm,10002+20000,494.3979
gfx938,f8_w8a8_block,torch.float16,160,256,4096,256,8,0,0,asm,10002+20000,499.5009
gfx938,f8_w8a8_block,torch.float16,192,256,4096,256,8,0,0,asm,10002+20000,511.6526
gfx938,f8_w8a8_block,torch.float16,224,256,4096,256,8,0,0,asm,10002+20000,519.9809
gfx938,f8_w8a8_block,torch.float16,256,256,4096,256,8,0,0,asm,10002+20000,520.7221
gfx938,f8_w8a8_block,torch.float16,320,256,4096,256,8,0,0,asm,10002+20000,535.021
gfx938,f8_w8a8_block,torch.float16,384,256,4096,256,8,0,0,asm,10002+20000,565.7914
gfx938,f8_w8a8_block,torch.float16,448,256,4096,256,8,0,0,asm,10002+20000,598.9198
gfx938,f8_w8a8_block,torch.float16,512,256,4096,256,8,0,0,asm,11007+21000,610.5915
gfx938,f8_w8a8_block,torch.float16,576,256,4096,256,8,0,0,asm,11010+21000,639.2314
gfx938,f8_w8a8_block,torch.float16,640,256,4096,256,8,0,0,asm,11009+21000,631.0208
gfx938,f8_w8a8_block,torch.float16,704,256,4096,256,8,0,0,asm,11006+21000,640.8651
gfx938,f8_w8a8_block,torch.float16,768,256,4096,256,8,0,0,asm,11007+21000,659.7198
gfx938,f8_w8a8_block,torch.float16,832,256,4096,256,8,0,0,asm,11009+21200,658.3555
gfx938,f8_w8a8_block,torch.float16,896,256,4096,256,8,0,0,asm,11010+21000,689.7156
gfx938,f8_w8a8_block,torch.float16,960,256,4096,256,8,0,0,asm,11010+21200,722.3639
gfx938,f8_w8a8_block,torch.float16,1024,256,4096,256,8,0,0,asm,11010+21000,751.2649
gfx938,f8_w8a8_block,torch.float16,1152,256,4096,256,8,0,0,asm,11008+21000,870.549
gfx938,f8_w8a8_block,torch.float16,1280,256,4096,256,8,0,0,asm,12002+22000,867.3911
gfx938,f8_w8a8_block,torch.float16,1408,256,4096,256,8,0,0,asm,12003+22000,875.1133
gfx938,f8_w8a8_block,torch.float16,1536,256,4096,256,8,0,0,asm,12003+22000,902.4816
gfx938,f8_w8a8_block,torch.float16,1664,256,4096,256,8,0,0,asm,12004+22000,926.5489
gfx938,f8_w8a8_block,torch.float16,1792,256,4096,256,8,0,0,asm,12003+22000,942.0857
gfx938,f8_w8a8_block,torch.float16,1920,256,4096,256,8,0,0,asm,12005+22000,1018.5236
gfx938,f8_w8a8_block,torch.float16,2048,256,4096,256,8,0,0,asm,12003+22000,1094.3972
gfx938,f8_w8a8_block,torch.float16,2304,256,4096,256,8,0,0,asm,12005+22000,1257.5887
gfx938,f8_w8a8_block,torch.float16,2560,256,4096,256,8,0,0,asm,11010+21200,1374.7673
gfx938,f8_w8a8_block,torch.float16,2816,256,4096,256,8,0,0,asm,13001+23000,1400.5189
gfx938,f8_w8a8_block,torch.float16,3072,256,4096,256,8,0,0,asm,12005+22000,1439.1798
gfx938,f8_w8a8_block,torch.float16,3328,256,4096,256,8,0,0,asm,12005+22000,1456.0893
gfx938,f8_w8a8_block,torch.float16,3584,256,4096,256,8,0,0,asm,12005+22000,1487.5504
gfx938,f8_w8a8_block,torch.float16,3840,256,4096,256,8,0,0,asm,12005+22000,1595.1459
gfx938,f8_w8a8_block,torch.float16,4096,256,4096,256,8,0,0,asm,12006+22000,1756.2657
gfx938,f8_w8a8_block,torch.float16,4608,256,4096,256,8,0,0,asm,12005+22000,2012.6698
gfx938,f8_w8a8_block,torch.float16,5120,256,4096,256,8,0,0,asm,12005+22000,2134.1435
gfx938,f8_w8a8_block,torch.float16,5632,256,4096,256,8,0,0,asm,12005+22000,2246.8753
gfx938,f8_w8a8_block,torch.float16,6144,256,4096,256,8,0,0,asm,12005+22000,2440.6189
gfx938,f8_w8a8_block,torch.float16,6656,256,4096,256,8,0,0,asm,13001+23000,2562.9843
gfx938,f8_w8a8_block,torch.float16,7168,256,4096,256,8,0,0,asm,13001+23001,2768.8287
gfx938,f8_w8a8_block,torch.float16,7680,256,4096,256,8,0,0,asm,13001+23000,2792.2731
gfx938,f8_w8a8_block,torch.float16,8192,256,4096,256,8,0,0,asm,13001+23000,3082.3107
gfx938,f8_w8a8_block,torch.float16,10240,256,4096,256,8,0,0,asm,13001+23000,3801.2909
gfx938,f8_w8a8_block,torch.float16,12288,256,4096,256,8,0,0,asm,13001+23000,4383.3187
gfx938,f8_w8a8_block,torch.float16,14336,256,4096,256,8,0,0,asm,13001+23000,5030.8137
gfx938,f8_w8a8_block,torch.float16,16384,256,4096,256,8,0,0,asm,13001+23000,5608.4473
gfx938,f8_w8a8_block,torch.float16,17408,256,4096,256,8,0,0,asm,13001+23000,6038.0465
gfx938,f8_w8a8_block,torch.float16,24576,256,4096,256,8,0,0,asm,13001+23000,8143.5178
......@@ -15,7 +15,8 @@ from aiter import silu_and_mul,gelu_and_mul
from aiter.ops.triton.fused_moe import (
triton_moe_sum,
triton_silu_and_mul,
triton_gelu_and_mul
triton_gelu_and_mul,
triton_relu2,
)
from aiter.jit.core import AITER_ROOT_DIR
......@@ -754,8 +755,11 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
use_shuffle)
#
else:
# For gated activations (silu/gelu): w1 has 2*inter_dim cols, so inter_dim = N/2
# For non-gated activations (relu2): w1 has inter_dim cols, so inter_dim = N
asm_inter_dim = N/2 if activation in ("silu", "gelu") else N
if solution_id is None:
solution_id = get_moe_asm_solution(arch, tokens_in_chunk, N/2, w1.size(2), E, top_k_num, MoeQuantType.NO_QUANT, use_shuffle)
solution_id = get_moe_asm_solution(arch, tokens_in_chunk, asm_inter_dim, w1.size(2), E, top_k_num, MoeQuantType.NO_QUANT, use_shuffle)
config = decode_sol_w8a8_c(solution_id)
if persist_cu == cu_num:
calculate_persist_groups(persist_cu, config, MoeQuantType.NO_QUANT)
......@@ -767,7 +771,7 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
moe_sorting_ck(curr_topk_ids, curr_topk_weights, global_num_experts, model_dim, out_hidden_states[begin_chunk_idx:end_chunk_idx], config["BLOCK_SIZE_M"], expert_map)
)
if print_log():
print(f"Asm Moe Size: chunk:{chunk}, arch:{arch}, quant:{MoeQuantType.NO_QUANT}, tokens:{tokens_in_chunk}, inter_dim:{int(N/2)}, model_dim:{w1.size(2)}, expert:{E}, topk:{top_k_num}")
print(f"Asm Moe Size: chunk:{chunk}, arch:{arch}, quant:{MoeQuantType.NO_QUANT}, tokens:{tokens_in_chunk}, inter_dim:{int(asm_inter_dim)}, model_dim:{w1.size(2)}, expert:{E}, topk:{top_k_num}")
print(f"solution:{solution_id}, shuffle:{use_shuffle}, persist:{persist_cu}")
if solution_id== "default":
print(f">>> Warning: No matching config pattern found, using default asm solution.")
......@@ -797,6 +801,8 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
elif activation == "gelu":
triton_gelu_and_mul(d_silu,d_w1_out)
# gelu_and_mul(d_silu,d_w1_out)
elif activation == "relu2":
triton_relu2(d_silu,d_w1_out)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
......
......@@ -23,6 +23,7 @@ class MoeQuantType:
W16A16 = "w16a16"
W4A16 = "w4a16"
W8A8 = "w8a8"
FP8_W8A8 = "fp8_w8a8"
W4A8 = "w4a8"
......@@ -53,9 +54,9 @@ def _try_get_moe_c_config(
block_size: int,
) -> Optional[Dict[str, Any]]:
try:
if quant_type == MoeQuantType.W4A16:
from .fused_moe_c import get_moe_configs_marlin
if quant_type == MoeQuantType.W4A16:
configs = get_moe_configs_marlin(
E=e,
N=n,
......@@ -64,8 +65,6 @@ def _try_get_moe_c_config(
use_moe_wna16_cuda=True,
)
elif quant_type == MoeQuantType.W8A8:
from .fused_moe_c import get_moe_configs_marlin
configs = get_moe_configs_marlin(
E=e,
N=n,
......@@ -73,9 +72,15 @@ def _try_get_moe_c_config(
is_bottom=False,
use_moe_wna16_cuda=True,
)
elif quant_type == MoeQuantType.FP8_W8A8:
configs = get_moe_configs_marlin(
E=e,
N=n,
dtype="fp8_w8a8",
is_bottom=False,
use_moe_wna16_cuda=True,
)
elif quant_type == MoeQuantType.W4A8:
from .fused_moe_c import get_moe_configs_marlin
configs = get_moe_configs_marlin(
E=e,
N=n,
......@@ -148,6 +153,22 @@ def _try_get_asm_config(
return None
return decode_sol_0(solution)
if quant_type == MoeQuantType.FP8_W8A8:
from .fused_moe_asm_wna16 import decode_sol_0
solution = get_moe_asm_solution(
arch=arch,
token=m,
inter_dim=n,
model_dim=k,
expert=e,
topk=top_k,
quant_type=AsmMoeQuantType.F8_W8A8,
)
if solution == "default":
return None
return decode_sol_0(solution)
if quant_type == MoeQuantType.W16A16:
from .fused_moe_asm_wna16 import decode_sol_0
......@@ -186,6 +207,7 @@ def _try_get_triton_config(
dtype_name = {
MoeQuantType.W4A16: "int4_w4a16",
MoeQuantType.W8A8: "int8_w8a8",
MoeQuantType.FP8_W8A8: "fp8_w8a8",
}.get(quant_type)
if dtype_name is None:
return None
......@@ -216,7 +238,7 @@ def _try_get_ck_config(
block_shape: Optional[List[int]],
) -> Optional[Dict[str, Any]]:
try:
if quant_type != MoeQuantType.W8A8:
if quant_type not in (MoeQuantType.W8A8, MoeQuantType.FP8_W8A8):
return None
from .fused_moe_ck import get_moe_ck_solution_id, MoeQuantType as CkMoeQuantType
......@@ -245,29 +267,43 @@ def _try_get_ck_config(
def get_aiter_moe_config(
M: int, # Number of tokens (input sequence length)
E: int, # Number of experts
N1: int, # GEMM1 output dimension, typically equal to (moe_intermediate_size / TP * 2)
N1: int, # GEMM1 output dimension: gated = (intermediate_size * 2), non-gated = intermediate_size
N2: int, # GEMM2 output dimension, typically equal to hidden_size
K: int, # GEMM1 input dimension, typically equal to hidden_size; for GEMM2, K typically equal to (moe_intermediate_size / TP)
top_k: int,
block_size: int,
dtype: torch.dtype,
quant_type: str,
activation: str = "silu", # "silu"/"gelu"/"relu2"/...
gated: Optional[bool] = None, # True=GLU-gated (N1=2*inter), False=non-gated (N1=inter); None=auto from activation
) -> Tuple[bool, AiterMoeConfig]:
"""Get the best backend config for a MOE problem.
Currently supported quant types:
- ``MoeQuantType.W16A16`` (non-quantized)
- ``MoeQuantType.W4A16``
- ``MoeQuantType.W8A8``
- ``MoeQuantType.W8A8`` (int8)
- ``MoeQuantType.FP8_W8A8`` (fp8)
- ``MoeQuantType.W4A8``
Backend priority:
- ``w16a16``: asm > triton
- ``w4a16``: moe_c > asm > triton
- ``w8a8``: asm > moe_c > triton > ck
- ``fp8_w8a8``: asm > moe_c > triton > ck
- ``w4a8``: moe_c
For non-gated MOE (e.g. Nemotron with ReLU² activation), pass
``gated=False`` (or let it auto-detect from ``activation="relu2"``)
and set ``N1 = intermediate_size`` (not ``2 * intermediate_size``).
"""
n = N1 / 2
# Determine gating: explicit > auto-detect from activation
if gated is None:
gated = activation in ("silu", "gelu")
# For gated (GLU): N1 = 2 * intermediate_size, n = N1 // 2
# For non-gated: N1 = intermediate_size, n = N1
n = N1 // 2 if gated else N1
block_shape = [0, block_size] if block_size else None
if quant_type == MoeQuantType.W4A16:
......@@ -282,7 +318,7 @@ def get_aiter_moe_config(
]
else:
raise ValueError(f"Unsupported dtype: {dtype}")
elif quant_type == MoeQuantType.W8A8:
elif quant_type in (MoeQuantType.W8A8, MoeQuantType.FP8_W8A8):
if block_size == 0: # Channel wise choose MOE_C
candidates = [
(MoeSolutionType.MOE_C, lambda: _try_get_moe_c_config(quant_type, M, E, n, block_size)),
......@@ -348,6 +384,7 @@ def aiter_moe(
use_int4_w4a16 = moe_config.quant_type == MoeQuantType.W4A16
use_int8_w8a8 = moe_config.quant_type == MoeQuantType.W8A8
use_fp8_w8a8 = moe_config.quant_type == MoeQuantType.FP8_W8A8
use_int8_w4a8 = moe_config.quant_type == MoeQuantType.W4A8
if moe_config.solution_type == MoeSolutionType.MOE_C:
......@@ -362,6 +399,7 @@ def aiter_moe(
inplace=inplace,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w4a8=use_int8_w4a8,
activation=activation,
global_num_experts=global_num_experts,
......@@ -391,6 +429,7 @@ def aiter_moe(
inplace=inplace,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
......@@ -409,7 +448,7 @@ def aiter_moe(
from .ops.triton.fused_moe import fused_experts_impl
# W8A8 channel-wise (block_shape=None) requires per_channel_quant=True
per_channel_quant = use_int8_w8a8 and block_shape is None
per_channel_quant = (use_int8_w8a8 or use_fp8_w8a8) and block_shape is None
return fused_experts_impl(
hidden_states,
......@@ -421,6 +460,7 @@ def aiter_moe(
inplace=inplace,
use_int4_w4a16=use_int4_w4a16,
use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
activation=activation,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
......@@ -448,6 +488,7 @@ def aiter_moe(
odtype=hidden_states.dtype,
inplace=inplace,
use_int8_w8a8=use_int8_w8a8,
use_fp8_w8a8=use_fp8_w8a8,
activation=activation,
global_num_experts=global_num_experts,
expert_map=expert_map,
......
{
"config": {
"(8, 192, 128, False, True, True, 128)": {
"BLOCK_M": 32,
"BLOCK_N": 64,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 1
},
"(16, 192, 128, False, True, False, -1)": {
"BLOCK_M": 32,
"BLOCK_N": 64,
"waves_per_eu": 1,
"matrix_instr_nonkdim": 16,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 1
}
},
"path": {}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"65536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"65536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
This diff is collapsed.
......@@ -322,6 +322,7 @@ def fused_moe_kernel_gptq_awq(
USE_MLS_LOAD: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
USE_ADDR_OFFSET_INT64_A: tl.constexpr,
USE_ADDR_OFFSET_INT64_B: tl.constexpr,
USE_ADDR_OFFSET_INT64_C: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
......@@ -434,17 +435,45 @@ def fused_moe_kernel_gptq_awq(
if use_int4_w4a16:
if group_size_divisible and has_zp:
offs_k_continue = tl.arange(0, BLOCK_SIZE_K // 2).to(tl.int32)
b_ptrs = b_ptr + (off_experts * stride_be + \
offs_bn[:, None] * stride_bn + offs_k_continue[None, :] * \
stride_bk).to(tl.int32)
if USE_ADDR_OFFSET_INT64_B:
b_ptrs = b_ptr + (
off_experts.to(tl.int64) * stride_be
+ offs_bn[:, None].to(tl.int64) * stride_bn
+ offs_k_continue[None, :].to(tl.int64) * stride_bk
)
else:
b_ptrs = b_ptr + (off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * \
stride_bn).to(tl.int32)
b_ptrs = b_ptr + (
off_experts * stride_be
+ offs_bn[:, None] * stride_bn
+ offs_k_continue[None, :] * stride_bk
).to(tl.int32)
else:
if USE_ADDR_OFFSET_INT64_B:
b_ptrs = b_ptr + (
off_experts.to(tl.int64) * stride_be
+ (offs_k[:, None].to(tl.int64) // 2) * stride_bk
+ offs_bn[None, :].to(tl.int64) * stride_bn
)
else:
b_ptrs = b_ptr + (
off_experts * stride_be
+ (offs_k[:, None] // 2) * stride_bk
+ offs_bn[None, :] * stride_bn
).to(tl.int32)
b_shifter = (offs_k[:, None] % 2) * 4
elif use_int8_w8a16:
b_ptrs = b_ptr + (off_experts * stride_be + \
offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn).to(tl.int32)
if USE_ADDR_OFFSET_INT64_B:
b_ptrs = b_ptr + (
off_experts.to(tl.int64) * stride_be
+ offs_k[:, None].to(tl.int64) * stride_bk
+ offs_bn[None, :].to(tl.int64) * stride_bn
)
else:
b_ptrs = b_ptr + (
off_experts * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
).to(tl.int32)
if not has_zp and use_int4_w4a16:
b_zp_num = 8
......@@ -2552,6 +2581,7 @@ def fused_moe(
assert B_zp is None or B_zp.ndim == 3
offset_max = 2**31 - 1
use_addr_offset_int64_a = A.numel() * A.element_size() >= offset_max
use_addr_offset_int64_b = B.numel() * B.element_size() >= offset_max
use_addr_offset_int64_c = C.numel() * C.element_size() >= offset_max
if use_int4_w4a8:
......@@ -2592,6 +2622,7 @@ def fused_moe(
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k,
compute_type=compute_type,
......@@ -2636,6 +2667,7 @@ def fused_moe(
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k,
compute_type=compute_type,
......
......@@ -27,6 +27,7 @@ def input_helper(
attn_impl="absorb",
equal_seqlens=False,
requires_grad=False,
kv_num_heads: int = 1,
):
torch.manual_seed(0)
......@@ -85,9 +86,9 @@ def input_helper(
total_extend, H, Lq, dtype=dtype, device=device
).requires_grad_(requires_grad)
# extend parts
# extend parts (``kv_num_heads`` for GQA: e.g. 2 when log shows ``k_extend [T,2,192]``)
k_extend = torch.randn(
total_extend, 1, Lk, dtype=dtype, device=device
total_extend, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad)
v_extend = k_extend[..., :Lv]
......@@ -96,7 +97,7 @@ def input_helper(
# prefix parts
k_buffer = torch.randn(
total_prefix, 1, Lk, dtype=dtype, device=device
total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad)
v_buffer = k_buffer[..., :Lv]
......@@ -154,12 +155,20 @@ def extend_forward(
causal,
sm_scale=1.0,
logit_cap=0.0,
use_v2: bool = False,
sliding_window_size: int = -1,
sinks=None,
):
"""Same tensors; v1 uses ``k_scale=v_scale=None``, v2 uses ``1.0`` (sglang-style scalars)."""
out = torch.empty(
(*q_extend.shape[:-1], v_extend.shape[-1]),
dtype=q_extend.dtype,
device=q_extend.device,
)
k_scale = v_scale = None
if use_v2:
k_scale = 1.0
v_scale = 1.0
extend_attention.extend_attention_fwd(
q_extend,
k_extend,
......@@ -176,6 +185,14 @@ def extend_forward(
max_len_extend,
sm_scale=sm_scale,
logit_cap=logit_cap,
skip_prefix_custom_mask=True,
config=None,
k_scale=k_scale,
v_scale=v_scale,
sliding_window_size=sliding_window_size,
sinks=sinks,
window_kv_offsets=None,
xai_temperature_len=-1,
)
return out
......@@ -211,13 +228,23 @@ def get_extend_benchmark_configs():
"qk_rope_head_dim",
"v_head_dim",
"attn_impl",
"kv_num_heads",
# Aligned with the want7 entries of ``{arch}-EXTEND_ATTENTION-V2-FP16.json`` (see extend_attention._get_config_v2)
"is_causal",
"sliding_window_size",
"with_sinks",
]
x_vals_list = [
(2, 16, 1024, 1024, 256, 0, 128, "non-absorb"),
(2, 16, 4096, 4096, 512, 64, 128, "non-absorb"),
(2, 16, 8192, 4096, 512, 64, 128, "non-absorb"),
(2, 16, 8192, 4096, 512, 64, 128, "absorb"),
(2, 16, 16324, 8192, 512, 64, 128, "absorb"),
# (2, 16, 1024, 1024, 256, 0, 128, "non-absorb", 1, False, -1, False),
# (2, 16, 4096, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
# (2, 16, 8192, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
# (2, 16, 8192, 4096, 512, 64, 128, "absorb", 1, False, -1, False),
# (2, 16, 16324, 8192, 512, 64, 128, "absorb", 1, False, -1, False),
# Shapes taken from logs; they hit two keys in BW200B-EXTEND_ATTENTION-V2-FP16.json
(2, 16, 4096, 555, 128, 64, 128, "non-absorb", 1, True, -1, False),
(16, 16, 556, 1, 128, 64, 128, "non-absorb", 1, True, -1, False),
(2, 16, 4096, 1024, 128, 64, 128, "non-absorb", 2, True, 128, True),
(16, 16, 556, 1, 128, 64, 128, "non-absorb", 2, True, 128, True),
]
return x_names, x_vals_list
......@@ -232,13 +259,17 @@ def get_prefill_benchmark_configs():
"qk_rope_head_dim",
"v_head_dim",
"attn_impl",
"kv_num_heads",
"is_causal",
"sliding_window_size",
"with_sinks",
]
x_vals_list = [
(2, 16, 0, 1024, 256, 0, 128, "non-absorb"),
(2, 16, 0, 4096, 512, 64, 128, "non-absorb"),
(2, 16, 0, 4096, 512, 64, 128, "non-absorb"),
(2, 16, 0, 4096, 512, 64, 128, "absorb"),
(2, 16, 0, 8192, 512, 64, 128, "absorb"),
(2, 16, 0, 1024, 256, 0, 128, "non-absorb", 1, False, -1, False),
(2, 16, 0, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
(2, 16, 0, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
(2, 16, 0, 4096, 512, 64, 128, "absorb", 1, False, -1, False),
(2, 16, 0, 8192, 512, 64, 128, "absorb", 1, False, -1, False),
]
return x_names, x_vals_list
......@@ -266,6 +297,10 @@ def model_benchmark_configs(args):
"qk_rope_head_dim",
"v_head_dim",
"attn_impl",
"kv_num_heads",
"is_causal",
"sliding_window_size",
"with_sinks",
]
x_vals_list = []
......@@ -276,7 +311,21 @@ def model_benchmark_configs(args):
extend = args.extend if args.extend else 8192
attn_impl = args.attn_impl if args.attn_impl else "non-absorb"
x_vals_list.append(
(model_name, batch_size, HQ, prefix, extend, 512, 64, 128, attn_impl)
(
model_name,
batch_size,
HQ,
prefix,
extend,
512,
64,
128,
attn_impl,
1,
False,
-1,
False,
)
)
return x_names, x_vals_list
......@@ -296,7 +345,22 @@ def benchmark(args):
elif args.mode == "prefill":
x_names, x_vals_list = get_prefill_benchmark_configs()
line_vals = ["extend_attention_fwd"]
if args.mode == "prefill":
line_vals = ["context_attention_fwd"]
line_names = ["prefill/context"]
styles = [("blue", "-")]
elif args.extend_provider == "v1":
line_vals = ["extend_v1"]
line_names = ["v1 (scale None)"]
styles = [("red", "-")]
elif args.extend_provider == "v2":
line_vals = ["extend_v2"]
line_names = ["v2 (k=v=1.0)"]
styles = [("green", "-")]
else:
line_vals = ["extend_v1", "extend_v2"]
line_names = ["v1 (scale None)", "v2 (k=v=1.0)"]
styles = [("red", "-"), ("green", "-")]
plot_name = (
args.plot_name + f"-causal-{args.causal}-equal_seqlens-{args.equal_seqlens}"
......@@ -308,8 +372,8 @@ def benchmark(args):
x_vals=x_vals_list,
line_arg="provider",
line_vals=line_vals,
line_names=line_vals,
styles=[("red", "-"), ("green", "-")],
line_names=line_names,
styles=styles,
ylabel="ms",
plot_name=plot_name,
args={"sm_scale": 1.0, "logit_cap": 0.0, "device": args.device},
......@@ -317,23 +381,34 @@ def benchmark(args):
)
@triton.testing.perf_report(configs)
def bench_MLA(
B,
H,
prefix,
extend,
kv_lora_rank,
qk_rope_head_dim,
v_head_dim,
attn_impl,
sm_scale,
logit_cap,
device,
provider=None,
model=None,
):
warmup = 25
rep = 100
def bench_MLA(**kwargs):
# perf_report calls this as fn(**x_args, provider=..., **bench.args), so every argument arrives as a keyword
warmup = 5
rep = 30
provider = kwargs.pop("provider")
sm_scale = kwargs.pop("sm_scale")
logit_cap = kwargs.pop("logit_cap")
device = kwargs.pop("device")
kwargs.pop("model", None)
kv_num_heads = int(kwargs.pop("kv_num_heads", 1))
B = kwargs.pop("B")
H = kwargs.pop("H")
prefix = kwargs.pop("prefix")
extend = kwargs.pop("extend")
kv_lora_rank = kwargs.pop("kv_lora_rank")
qk_rope_head_dim = kwargs.pop("qk_rope_head_dim")
v_head_dim = kwargs.pop("v_head_dim")
attn_impl = kwargs.pop("attn_impl")
row_causal = kwargs.pop("is_causal")
sliding_window_size = int(kwargs.pop("sliding_window_size"))
with_sinks = bool(kwargs.pop("with_sinks"))
if kwargs:
raise ValueError(f"unexpected benchmark kwargs: {kwargs}")
sinks_tensor = None
if with_sinks:
sinks_tensor = torch.zeros(H, device=device, dtype=torch.float32)
(
q_extend,
......@@ -360,11 +435,15 @@ def benchmark(args):
v_head_dim,
dtype,
device,
attn_impl=attn_impl,
equal_seqlens=args.equal_seqlens,
kv_num_heads=kv_num_heads,
)
if provider == "extend_attention_fwd":
if provider in ("extend_v1", "extend_v2"):
use_v2 = provider == "extend_v2"
def extend_attention():
def fn():
return extend_forward(
q_extend,
k_extend,
......@@ -377,33 +456,35 @@ def benchmark(args):
custom_mask,
mask_indptr,
max_len_extend,
args.causal,
row_causal,
sm_scale,
logit_cap,
use_v2=use_v2,
sliding_window_size=sliding_window_size,
sinks=sinks_tensor,
)
def context_attention():
return extend_forward(
elif provider == "context_attention_fwd":
assert (
prefix == 0
), "Prefix length must be 0 for context attention. Try setting -mode prefill."
def fn():
return prefill_forward(
q_extend,
k_extend,
v_extend,
B_Start_Loc,
B_Seqlen,
max_len_extend,
args.causal,
row_causal,
)
if provider == "extend_attention_fwd":
fn = extend_attention
elif provider == "context_attention_fwd":
assert (
prefix == 0
), "Prefix length must be 0 for context attention. Try setting -mode prefill."
fn = context_attention
else:
raise ValueError(f"Unknown provider: {provider}")
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
ms = triton.testing.do_bench_cudagraph(fn)
# ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return ms
......@@ -474,16 +555,16 @@ def parse_args():
default="extend",
help="Mode of the benchmark. Options: extend, prefill",
)
parser.add_argument(
"-extend_provider",
type=str,
default="both",
choices=("both", "v1", "v2"),
help="Which extend_attention path to benchmark: v1 (k_scale=v_scale=None), v2 (1.0), or both. Ignored when -mode prefill.",
)
return parser.parse_args()
arg_to_torch_dtype = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}
def run_bench(args):
torch.manual_seed(0)
torch.set_default_device(args.device)
......
# Test for get_aiter_moe_config and aiter_moe with W16A16 non-gated ReLU²
# (Nemotron-style MOE: N1 = intermediate_size, activation = relu2)
import torch
import pandas as pd
from typing import Optional, List
from aiter.fused_moe import fused_topk
from aiter import dtypes
from aiter.test_common import checkAllclose, perftest
from aiter.moe import (
get_aiter_moe_config,
aiter_moe,
MoeSolutionType,
MoeQuantType,
)
from aiter.fused_moe_asm_wna16 import fused_experts_asm_impl
from aiter.ops.shuffle import asm_shuffle_weight_b8
import aiter
torch.set_default_device("cuda")
# ---------------------------------------------------------------------------
# Torch reference for non-gated ReLU² MOE
# ---------------------------------------------------------------------------
def torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids):
    """Compute a non-gated ReLU² mixture-of-experts output with plain torch ops.

    All math is done in float32 and the result is cast back to the dtype of
    ``hidden_states``.

    Args:
        hidden_states: 2-D tensor ``[B, D]`` of token activations.
        w1: per-expert first GEMM weights, ``[E, inter_dim, model_dim]``
            (NOT ``2 * inter_dim``, since there is no gate projection).
        w2: per-expert second GEMM weights, ``[E, model_dim, inter_dim]``.
        topk_weights: ``[B, topk]`` routing weights applied to each expert output.
        topk_ids: ``[B, topk]`` expert index chosen for each (token, slot).

    Returns:
        Tensor ``[B, D]``: the routing-weighted sum over the ``topk`` slots.
    """
    compute_dtype = torch.float32
    out_dtype = hidden_states.dtype
    num_tokens, model_dim = hidden_states.shape
    topk = topk_weights.shape[1]

    w1_f32 = w1.to(compute_dtype)
    w2_f32 = w2.to(compute_dtype)

    # One copy of every token per routing slot: [B, topk, D].
    tokens = hidden_states.to(compute_dtype)
    tokens = tokens.view(num_tokens, -1, model_dim).repeat(1, topk, 1)

    routed = torch.zeros(
        (num_tokens, topk, model_dim),
        dtype=compute_dtype,
        device=tokens.device,
    )

    for expert_id in range(w1_f32.shape[0]):
        selected = topk_ids == expert_id
        if not selected.sum():
            # No (token, slot) pair was routed to this expert.
            continue
        up = tokens[selected] @ w1_f32[expert_id].T  # GEMM1
        activated = torch.relu(up) ** 2  # ReLU²
        routed[selected] = activated @ w2_f32[expert_id].T  # GEMM2

    weighted = routed * topk_weights.view(num_tokens, -1, 1)
    return weighted.sum(dim=1).to(out_dtype)
# ---------------------------------------------------------------------------
# Weight preparation helpers (W16A16 non-gated)
# ---------------------------------------------------------------------------
def prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype):
    """Create every tensor a non-gated w16a16 MOE test needs.

    Unlike the gated layout, ``w1`` is ``[E, n, k]`` rather than ``[E, 2*n, k]``.
    The RNG is re-seeded with 0 on each call, so the generated data is
    identical for identical arguments.

    Returns:
        dict with keys ``input``, ``w1``, ``w2``, ``w1_shuffle``,
        ``w2_shuffle``, ``topk_weights``, ``topk_ids`` and ``score``.
    """
    torch.manual_seed(0)

    # Draw order (input, w1, w2, score) is fixed so the seeded values stay the same.
    activations = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    up_weight = torch.randn((e, n, k), device="cuda", dtype=dtype) / 2
    down_weight = torch.randn((e, k, n), device="cuda", dtype=dtype) / 2
    router_score = torch.randn((m, e), device="cuda", dtype=dtype)

    # Shuffled weight copies for the ASM path (stage 1 for w1, stage 2 for w2).
    up_weight_shuffled = asm_shuffle_weight_b8(up_weight, stage=1)
    down_weight_shuffled = asm_shuffle_weight_b8(down_weight, stage=2)

    # NOTE(review): the trailing True is presumably the renormalize flag of
    # fused_topk — confirm against its signature.
    topk_weights, topk_ids = fused_topk(activations, router_score, topk, True)

    return dict(
        input=activations,
        w1=up_weight,
        w2=down_weight,
        w1_shuffle=up_weight_shuffled,
        w2_shuffle=down_weight_shuffled,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        score=router_score,
    )
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w16a16 non-gated relu2)
# ---------------------------------------------------------------------------
def test_get_config(m, k, n, e, topk, dtype):
    """Check get_aiter_moe_config for the non-gated w16a16 ``relu2`` case.

    The config is requested with ``N1=n`` (no gate doubling), ``N2=k``,
    ``K=k`` and ``block_size=0``.  If a solution is reported it must be ASM or
    TRITON, carry a non-None config and the W16A16 quant type; if none is
    reported, both ``solution_type`` and ``config`` must be None.

    Returns:
        The ``(status, moe_cfg)`` pair as returned by ``get_aiter_moe_config``.
    """
    status, moe_cfg = get_aiter_moe_config(
        M=m,
        E=e,
        N1=n,  # non-gated: intermediate_size (NOT 2 * intermediate_size)
        N2=k,  # down / hidden_size
        K=k,  # model dimension
        top_k=topk,
        block_size=0,
        dtype=dtype,
        quant_type=MoeQuantType.W16A16,
        activation="relu2",
        gated=False,
    )
    log_prefix = f"[get_config_w16a16_nogate] {m=}, {k=}, {n=}, {e=}, {topk=}, "

    if not status:
        # No backend: the config object must be empty as well.
        assert moe_cfg.solution_type is None, \
            "status=False but solution_type is not None"
        assert moe_cfg.config is None, \
            "status=False but config is not None"
        aiter.logger.info(
            log_prefix + "no solution found (expected on unsupported configs)"
        )
        return status, moe_cfg

    accepted_solutions = (MoeSolutionType.ASM, MoeSolutionType.TRITON)
    assert moe_cfg.solution_type is not None, \
        "status=True but solution_type is None"
    assert moe_cfg.config is not None, \
        "status=True but config is None"
    assert moe_cfg.solution_type in accepted_solutions, \
        f"Unexpected solution_type: {moe_cfg.solution_type}"
    assert moe_cfg.quant_type == MoeQuantType.W16A16
    aiter.logger.info(
        log_prefix
        + f"solution={moe_cfg.solution_type}, "
        + f"config keys={list(moe_cfg.config.keys())}"
    )
    return status, moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w16a16 non-gated relu2
# ---------------------------------------------------------------------------
@perftest(num_warmup=1, num_iters=2)
def _run_torch_ref(hidden_states, w1, w2, topk_weights, topk_ids):
    """Run the torch ReLU² reference under ``perftest`` (1 warmup, 2 iterations).

    Callers in this file unpack the decorated call as ``(output, timing)``.
    """
    reference = torch_moe_relu2(hidden_states, w1, w2, topk_weights, topk_ids)
    return reference
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_aiter_moe_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    moe_config,
    inplace,
    activation,
    w1_scale,
    w2_scale,
    w1_zp,
    w2_zp,
    a1_scale,
    a2_scale,
    block_shape,
    global_num_experts,
    expert_map,
    routed_scaling_factor,
):
    """Forward every argument, in order and positionally, to ``aiter_moe``.

    Decorated with ``perftest`` (10 warmup runs, 100 iterations); callers in
    this file unpack the decorated call as ``(output, timing)``.
    """
    # Kept positional on purpose: the order below is the order aiter_moe receives.
    forwarded = (
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        moe_config,
        inplace,
        activation,
        w1_scale,
        w2_scale,
        w1_zp,
        w2_zp,
        a1_scale,
        a2_scale,
        block_shape,
        global_num_experts,
        expert_map,
        routed_scaling_factor,
    )
    return aiter_moe(*forwarded)
def test_aiter_moe_w16a16_nogate(m, k, n, e, topk, dtype, inplace, routed_scaling_factor):
    """End-to-end check of ``aiter_moe`` with ``relu2`` against the torch reference.

    Steps: query ``get_aiter_moe_config`` for a non-gated W16A16 config,
    build inputs, run the torch reference and the ``aiter_moe`` dispatch, and
    compare both outputs with ``checkAllclose`` (rtol=0.01, atol=0.5).

    Args:
        m: number of tokens.
        k: model dimension (used as both ``K`` and ``N2``).
        n: intermediate size (used as ``N1``; not doubled, since non-gated).
        e: number of experts.
        topk: experts selected per token.
        dtype: dtype of the activations and weights.
        inplace: forwarded unchanged to ``aiter_moe``.
        routed_scaling_factor: forwarded unchanged to ``aiter_moe``.

    Returns:
        ``None`` when no backend is available for this shape, otherwise a dict
        with keys ``m``, ``backend`` and ``us`` (the timing value returned by
        the ``perftest``-wrapped run).
    """
    # The upper-case names are kept because the log messages print them via
    # the f-string ``=`` specifier.
    N1 = n  # non-gated: N1 = intermediate_size
    N2 = k
    K = k

    status, moe_cfg = get_aiter_moe_config(
        M=m, E=e, N1=N1, N2=N2, K=K,
        top_k=topk, block_size=0, dtype=dtype,
        quant_type=MoeQuantType.W16A16,
        activation="relu2",
        gated=False,
    )

    if not status:
        aiter.logger.info(
            f"[aiter_moe_w16a16_nogate] SKIP {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}: "
            f"no backend available"
        )
        return None

    backend = moe_cfg.solution_type
    # The same text is used for the log line and the comparison message.
    msg = (f"[aiter_moe_w16a16_nogate] {m=}, {N1=}, {N2=}, {K=}, {e=}, {topk=}, "
           f"backend={backend}")
    aiter.logger.info(msg)

    data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype)

    # Torch reference
    ref_out, _ = _run_torch_ref(
        data["input"], data["w1"], data["w2"],
        data["topk_weights"], data["topk_ids"],
    )

    # aiter_moe dispatch with relu2 activation
    aiter_out, aiter_us = _run_aiter_moe_perf(
        hidden_states=data["input"],
        w1=data["w1"],
        w2=data["w2"],
        topk_weights=data["topk_weights"],
        topk_ids=data["topk_ids"],
        moe_config=moe_cfg,
        inplace=inplace,
        activation="relu2",
        w1_scale=None,
        w2_scale=None,
        w1_zp=None,
        w2_zp=None,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,
        global_num_experts=e,
        expert_map=None,
        routed_scaling_factor=routed_scaling_factor,
    )

    checkAllclose(ref_out, aiter_out, rtol=0.01, atol=0.5, msg=msg)

    return {"m": m, "backend": backend, "us": aiter_us}
# ---------------------------------------------------------------------------
# Test: aiter_moe w16a16 non-gated ASM shuffle vs non-shuffle
# ---------------------------------------------------------------------------
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_asm_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    dtype,
    global_num_experts,
    expert_map,
):
    """Run ``fused_experts_asm_impl`` with ``relu2`` and no ``use_shuffle`` argument.

    Decorated with ``perftest`` (10 warmup runs, 100 iterations); callers in
    this file unpack the decorated call as ``(output, timing)``.
    """
    asm_kwargs = {
        "activation": "relu2",
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
    }
    return fused_experts_asm_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, dtype, **asm_kwargs
    )
@perftest(num_warmup=10, num_iters=100, num_rotate_args=1)
def _run_asm_shuffle_perf(
    hidden_states,
    w1,
    w2,
    topk_weights,
    topk_ids,
    dtype,
    global_num_experts,
    expert_map,
):
    """Run ``fused_experts_asm_impl`` with ``relu2`` and ``use_shuffle=1``.

    Decorated with ``perftest`` (10 warmup runs, 100 iterations); callers in
    this file unpack the decorated call as ``(output, timing)`` and pass the
    shuffled weight copies as ``w1``/``w2``.
    """
    asm_kwargs = {
        "activation": "relu2",
        "global_num_experts": global_num_experts,
        "expert_map": expert_map,
        "use_shuffle": 1,
    }
    return fused_experts_asm_impl(
        hidden_states, w1, w2, topk_weights, topk_ids, dtype, **asm_kwargs
    )
def test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype):
    """Compare the w16a16 non-gated ASM kernel on shuffled vs. plain weights.

    The plain-weight run goes first; if it raises, the case is logged as
    skipped and ``None`` is returned.  Otherwise the shuffled-weight run is
    compared to it with ``checkAllclose`` (rtol=0.01, atol=0.01).

    Returns:
        ``None`` on skip, else a dict with keys ``m``, ``asm_us``,
        ``shuffle_us`` and ``shuffle_uplift`` (formatted as a percentage).
    """
    data = prepare_w16a16_nogate_inputs(m, k, n, e, topk, dtype)
    routing = (data["topk_weights"], data["topk_ids"])

    try:
        asm_out, asm_us = _run_asm_perf(
            data["input"], data["w1"], data["w2"], *routing, dtype, e, None
        )
    except Exception as exc:
        # Deliberate best-effort: a failing ASM path means "skip", not "fail".
        aiter.logger.info(
            f"[w16a16_nogate_shuffle] SKIP {m=}: ASM not available ({exc})")
        return None

    shuffle_out, shuffle_us = _run_asm_shuffle_perf(
        data["input"], data["w1_shuffle"], data["w2_shuffle"], *routing, dtype, e, None
    )

    msg = (f"[w16a16_nogate_shuffle] {m=}, {k=}, {n=}, {e=}, {topk=}, "
           f"asm_us={asm_us:.2f}, shuffle_us={shuffle_us:.2f}")
    checkAllclose(asm_out, shuffle_out, rtol=0.01, atol=0.01, msg=msg)

    # Relative speed-up of the shuffled run; guarded against a zero timing.
    if shuffle_us > 0:
        uplift = asm_us / shuffle_us - 1
    else:
        uplift = 0

    result = {"m": m, "asm_us": asm_us, "shuffle_us": shuffle_us}
    result["shuffle_uplift"] = f"{uplift:.1%}"
    return result
if __name__ == "__main__":
    # Script entry point: runs the three test parts over a sweep of token counts.
    # Nemotron-style MoE parameters (non-gated, ReLU²)
    dtype = dtypes.bf16
    e, topk = 256, 8
    k = 3072  # model_dim / hidden_size
    n = 128  # intermediate_size (NOT multiplied by 2)
    inplace = False
    routed_scaling_factor = 1.0

    test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    separator = "=" * 60

    # --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) ---
    for line in (
        separator,
        "Part 1: Testing get_aiter_moe_config for w16a16 non-gated relu2",
        separator,
    ):
        aiter.logger.info(line)

    for m in test_tokens:
        test_get_config(m, k, n, e, topk, dtype)

    # --- Part 2: test aiter_moe end-to-end (w16a16 non-gated relu2) ---
    for line in (
        separator,
        "Part 2: Testing aiter_moe end-to-end for w16a16 non-gated relu2",
        separator,
    ):
        aiter.logger.info(line)

    # Skipped cases return None and are left out of the summary.
    rows = [
        ret
        for m in test_tokens
        if (
            ret := test_aiter_moe_w16a16_nogate(
                m, k, n, e, topk, dtype, inplace, routed_scaling_factor
            )
        )
        is not None
    ]
    if rows:
        df = pd.DataFrame(rows)
        aiter.logger.info(f"aiter_moe non-gated relu2 summary:\n{df}")

    # --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) ---
    for line in (
        separator,
        "Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2",
        separator,
    ):
        aiter.logger.info(line)

    shuffle_rows = [
        ret
        for m in test_tokens
        if (ret := test_aiter_moe_w16a16_nogate_shuffle(m, k, n, e, topk, dtype))
        is not None
    ]
    if shuffle_rows:
        df_shuffle = pd.DataFrame(shuffle_rows)
        aiter.logger.info(f"shuffle summary (non-gated relu2):\n{df_shuffle}")
# Test for get_aiter_moe_config and aiter_moe with w8a8 channel-wise quantization
import argparse
import torch
import pandas as pd
......@@ -13,6 +14,7 @@ from aiter.moe import (
MoeQuantType,
)
from aiter.ops.shuffle import moe_layout_shuffle_gemm1, moe_layout_shuffle_gemm2
from aiter.ops.quant import pertoken_quant
import aiter
......@@ -74,9 +76,11 @@ def _run_aiter_moe_perf(
)
def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
"""Prepare channel-wise quantized w8a8 inputs.
For int8 (W8A8): weights quantized to torch.int8, scales = max_val / 127.
For fp8 (FP8_W8A8): weights quantized to float8 via pertoken_quant.
Scale shape: (e, out_dim, 1) — one scale per output channel.
block_shape is None for channel-wise.
"""
......@@ -86,7 +90,12 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
w1_fp = torch.randn((e, 2 * n, k), dtype=dtype, device="cuda")
w2_fp = torch.randn((e, k, n), dtype=dtype, device="cuda")
# Channel-wise quantization: max per output channel (last dim of weight row)
if quant_type == MoeQuantType.FP8_W8A8:
# FP8 channel-wise quantization via pertoken_quant
w1_qweight, w1_scales = pertoken_quant(w1_fp, quant_dtype=dtypes.fp8)
w2_qweight, w2_scales = pertoken_quant(w2_fp, quant_dtype=dtypes.fp8)
else:
# INT8 channel-wise quantization: max per output channel
max_vals_w1 = torch.abs(w1_fp.to(torch.float32)).max(dim=-1, keepdim=True)[0]
max_vals_w1 = max_vals_w1.clamp(min=1e-5)
w1_scales = max_vals_w1 / 127.0 # (e, 2*n, 1)
......@@ -119,7 +128,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
}
def test_get_config(m, k, n, e, topk, dtype):
def test_get_config(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
"""Test get_aiter_moe_config for channel-wise w8a8 (block_size=0)."""
status, moe_cfg = get_aiter_moe_config(
M=m,
......@@ -130,11 +139,12 @@ def test_get_config(m, k, n, e, topk, dtype):
top_k=topk,
block_size=0,
dtype=dtype,
quant_type=MoeQuantType.W8A8,
quant_type=quant_type,
)
tag = f"get_config_{quant_type}_cw"
if status:
assert moe_cfg.quant_type == MoeQuantType.W8A8
assert moe_cfg.quant_type == quant_type
assert moe_cfg.solution_type in (
MoeSolutionType.ASM,
MoeSolutionType.MOE_C,
......@@ -143,19 +153,19 @@ def test_get_config(m, k, n, e, topk, dtype):
)
assert moe_cfg.config is not None
aiter.logger.info(
f"[get_config_w8a8_cw] {m=}, solution={moe_cfg.solution_type}, "
f"[{tag}] {m=}, solution={moe_cfg.solution_type}, "
f"config keys={list(moe_cfg.config.keys())}"
)
else:
assert moe_cfg.solution_type is None
assert moe_cfg.config is None
aiter.logger.info(f"[get_config_w8a8_cw] {m=}, no solution found")
aiter.logger.info(f"[{tag}] {m=}, no solution found")
return status, moe_cfg
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
"""End-to-end test of aiter_moe with channel-wise w8a8."""
def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type=MoeQuantType.W8A8):
"""End-to-end test of aiter_moe with channel-wise w8a8 (int8 or fp8)."""
status, moe_cfg = get_aiter_moe_config(
M=m,
E=e,
......@@ -165,14 +175,15 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
top_k=topk,
block_size=0,
dtype=dtype,
quant_type=MoeQuantType.W8A8,
quant_type=quant_type,
)
tag = f"aiter_moe_{quant_type}_cw"
if not status:
aiter.logger.info(f"[aiter_moe_w8a8_cw] SKIP {m=}: no backend available")
aiter.logger.info(f"[{tag}] SKIP {m=}: no backend available")
return None
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype)
data = prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype, quant_type)
# Torch reference uses original fp weights directly (no scales needed)
ref_out, _ = _run_torch_ref(
......@@ -216,31 +227,46 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
print("ref_out",ref_out)
msg = f"[aiter_moe_w8a8_cw] {m=}, backend={moe_cfg.solution_type}"
msg = f"[{tag}] {m=}, backend={moe_cfg.solution_type}"
checkAllclose(ref_out, aiter_out, rtol=0.01, atol=100, msg=msg)
return {"m": m, "backend": moe_cfg.solution_type, "us": aiter_us}
return {"m": m, "quant_type": quant_type, "backend": moe_cfg.solution_type, "us": aiter_us}
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test aiter_moe with channel-wise w8a8 quantization",
)
parser.add_argument(
"--quant",
choices=["int8", "fp8"],
default="int8",
help="Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)",
)
args = parser.parse_args()
quant_type = MoeQuantType.FP8_W8A8 if args.quant == "fp8" else MoeQuantType.W8A8
# for moe_c backend, it does not support n=320 for now;
# for triton backend, it can run with n=320 in NMZ;
dtype = dtypes.bf16
e = 256
topk = 8
k = 6144
n = 256
n = 320
aiter.logger.info("=" * 60)
aiter.logger.info("Part 1: Testing get_aiter_moe_config for w8a8 channel-wise")
aiter.logger.info(f"Part 1: Testing get_aiter_moe_config for {quant_type} channel-wise")
aiter.logger.info("=" * 60)
test_tokens = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 , 4096, 6144 , 8192 , 16384]
for m in test_tokens:
test_get_config(m, k, n, e, topk, dtype)
test_get_config(m, k, n, e, topk, dtype, quant_type)
aiter.logger.info("=" * 60)
aiter.logger.info("Part 2: Testing aiter_moe end-to-end for w8a8 channel-wise")
aiter.logger.info(f"Part 2: Testing aiter_moe end-to-end for {quant_type} channel-wise")
aiter.logger.info("=" * 60)
df = []
for m in test_tokens:
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype)
ret = test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype, quant_type)
if ret is not None:
df.append(ret)
if df:
......
......@@ -514,6 +514,7 @@ def fused_moe(
assert B_zp is None or B_zp.ndim == 3
offset_max = 2**31 - 1
use_addr_offset_int64_a = A.numel() * A.element_size() >= offset_max
use_addr_offset_int64_b = B.numel() * B.element_size() >= offset_max
use_addr_offset_int64_c = C.numel() * C.element_size() >= offset_max
if use_int4_w4a8:
......@@ -554,6 +555,7 @@ def fused_moe(
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k,
compute_type=compute_type,
......@@ -599,6 +601,7 @@ def fused_moe(
group_size=block_shape[1],
MUL_ROUTED_WEIGHT=mul_routed_weight,
USE_ADDR_OFFSET_INT64_A=use_addr_offset_int64_a,
USE_ADDR_OFFSET_INT64_B=use_addr_offset_int64_b,
USE_ADDR_OFFSET_INT64_C=use_addr_offset_int64_c,
top_k=top_k,
compute_type=compute_type,
......
import os
import sys
os.environ["AMDGCN_USE_BUFFER_OPS"] = "1"
# :class:`Hcutuner` reads ``TRITON_HCUTUNE_PERF_MODE`` in ``__init__``. Module-level
# ``fn`` / ``fn_v2 = triton.utils.hcutune(...)`` runs at import time, *before*
# ``if __name__ == "__main__"``, so ``--perf`` must be applied here (or set in the shell).
if "--perf" in sys.argv:
os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"
# GPU / ROCm tuning: run **inside** the ``zww_tl_1`` container (not the bare host), e.g.:
# docker exec zww_tl_1 bash -lc 'cd /data/zhouweiwang/aiter/op_tests/triton_autotune && python tune_extend_attention.py --perf'
# ``do_bench`` timing during tuning (smaller => faster iteration; raise for stable prod numbers).
TUNE_DO_BENCH_WARMUP = 5
TUNE_DO_BENCH_REP = 20
import json
import torch
......@@ -8,10 +22,27 @@ import random
import itertools
import argparse
from aiter.ops.triton.extend_attention import _fwd_kernel
from aiter.ops.triton.extend_attention import _fwd_kernel, _fwd_kernel_v2
_is_hip = True
# Autotune (hcutune) key for :func:`_fwd_kernel_v2`.
#
# The JSON block-size lookup uses only the ``want7`` key of :func:`_get_config_v2`.
# This list additionally carries the kernel constexprs ``SKIP_PREFIX_CUSTOM_MASK`` and
# ``xai_temperature_len``, which are used for autotune but are not part of the V2 JSON key.
#
# Values are aligned with the log (e.g. ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` ~2291).
HCUTUNE_KEY_V2 = [
    "kv_group_num",  # q.size(-2) // k.size(-2)
    "Lq",  # last dim of q
    "Lv",  # last dim of v
    "USE_CUSTOM_MASK",  # custom_mask is not None
    "IS_CAUSAL",
    "SKIP_PREFIX_CUSTOM_MASK",  # autotune-only constexpr
    "HAS_SINK",  # sinks is not None
    "SLIDING_WINDOW_SIZE",
    "xai_temperature_len",  # autotune-only constexpr
]
# Split the Triton version string into its dot-separated components and keep the
# first two as integers. ``int()`` is used instead of ``eval()`` so the version
# text is only ever parsed as a number, never executed as an expression.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])
......@@ -29,6 +60,7 @@ def input_helper(
attn_impl="normal",
equal_seqlens=False,
requires_grad=False,
kv_num_heads: int = 1,
):
torch.manual_seed(0)
......@@ -82,9 +114,9 @@ def input_helper(
total_extend, H, Lq, dtype=dtype, device=device
).requires_grad_(requires_grad)
# extend parts
# extend parts (``kv_num_heads`` for GQA: e.g. 2 when q has 16 heads and kv_group_num is 8)
k_extend = torch.randn(
total_extend, 1, Lk, dtype=dtype, device=device
total_extend, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad)
v_extend = k_extend[..., :Lv]
......@@ -93,7 +125,7 @@ def input_helper(
# prefix parts
k_buffer = torch.randn(
total_prefix, 1, Lk, dtype=dtype, device=device
total_prefix, kv_num_heads, Lk, dtype=dtype, device=device
).requires_grad_(requires_grad)
v_buffer = k_buffer[..., :Lv]
......@@ -169,18 +201,42 @@ def generate_configs(config):
return configs_list
# def get_triton_configs():
# config = {
# "BLOCK_M": [16, 32, 64],
# "BLOCK_N": [16, 32, 64],
# "waves_per_eu": [1],
# "num_warps": [4, 8, 16],
# # "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
# # "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
# "matrix_instr_nonkdim": [16],
# "num_stages": [1, 2, 3],
# "sched_latency": ["none", "mmac5-ds10"],
# "kpack": [1, 2],
# }
# tt_configs = []
# for c in generate_configs(config):
# num_warps = c['num_warps']
# num_stages = c['num_stages']
# del c['num_warps']
# del c['num_stages']
# tt_configs.append(triton.Config(c, num_warps=num_warps, num_stages=num_stages))
# return tt_configs
def get_triton_configs():
config = {
"BLOCK_M": [16, 32, 64],
"BLOCK_N": [16, 32, 64],
"waves_per_eu": [1],
"num_warps": [4, 8, 16],
"num_warps": [4, 8],
# "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
# "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
"matrix_instr_nonkdim": [16],
"num_stages": [1, 2, 3],
"num_stages": [1, 2],
"sched_latency": ["none", "mmac5-ds10"],
"kpack": [1, 2],
"kpack": [1],
}
tt_configs = []
......@@ -193,7 +249,6 @@ def get_triton_configs():
return tt_configs
def prune_configs(configs, nargs, **kwargs):
def _prune(config):
c = config.all_kwargs()
......@@ -216,8 +271,23 @@ key = [
'SKIP_PREFIX_CUSTOM_MASK',
'STORE_TRANSPOSE',
]
fn = triton.utils.hcutune(configs=get_triton_configs(), key=key, perf_debug=True,
prune_configs_by={"early_config_prune": prune_configs})(_fwd_kernel)
fn = triton.utils.hcutune(
configs=get_triton_configs(),
key=key,
perf_debug=True,
prune_configs_by={"early_config_prune": prune_configs},
warmup=TUNE_DO_BENCH_WARMUP,
rep=TUNE_DO_BENCH_REP,
)(_fwd_kernel)
# hcutune-wrapped launcher for ``_fwd_kernel_v2``. It is built at module import time
# from the config space returned by ``get_triton_configs()``, keyed on
# ``HCUTUNE_KEY_V2``, with ``prune_configs`` installed as the early config-prune hook
# and the ``TUNE_DO_BENCH_WARMUP`` / ``TUNE_DO_BENCH_REP`` values forwarded as
# ``warmup`` / ``rep``.
fn_v2 = triton.utils.hcutune(
configs=get_triton_configs(),
key=HCUTUNE_KEY_V2,
perf_debug=True,
prune_configs_by={"early_config_prune": prune_configs},
warmup=TUNE_DO_BENCH_WARMUP,
rep=TUNE_DO_BENCH_REP,
)(_fwd_kernel_v2)
def extend_attention_fwd(
......@@ -329,6 +399,139 @@ def extend_attention_fwd(
)
def extend_attention_fwd_v2(
    q_extend,
    k_extend,
    v_extend,
    o_extend,
    k_buffer,
    v_buffer,
    qo_indptr,
    kv_indptr,
    kv_indices,
    custom_mask,
    is_causal,
    mask_indptr,
    max_len_extend,
    sm_scale,
    k_scale,
    v_scale,
    sliding_window_size,
    sinks,
    window_kv_offsets,
    xai_temperature_len,
    skip_prefix_custom_mask,
    logit_cap,
):
    """Launch :data:`fn_v2`, the hcutune-wrapped :func:`_fwd_kernel_v2`.

    Derives the kernel's compile-time block sizes from the head dimensions,
    substitutes one-element placeholder tensors when no custom mask is given,
    and launches over a ``(batch, q_heads, cdiv(max_len_extend, BLOCK_M))`` grid.

    Derived values:
        * ``Lq`` / ``Lv`` are the last dims of ``q_extend`` / ``v_extend``.
        * ``Lq`` of 576, 288 or 192 is split into ``(BLOCK_DMODEL, BLOCK_DPE)`` =
          ``(512, 64)``, ``(256, 32)`` or ``(128, 64)``; any other ``Lq`` is rounded
          up to a power of two with ``BLOCK_DPE = 0``.
        * A falsy ``sm_scale`` (``None`` or ``0``) is replaced by ``1 / sqrt(Lq)``.
        * The batch size is ``qo_indptr.shape[0] - 1``; the query head count is
          ``q_extend.shape[1]``; ``kv_group_num`` is the query head count integer-
          divided by ``k_extend.shape[1]``.
        * ``USE_CUSTOM_MASK`` / ``HAS_SINK`` reflect whether ``custom_mask`` /
          ``sinks`` is not ``None``.
        * ``STORE_TRANSPOSE`` is always passed as ``True``.

    Returns:
        None. (Presumably the kernel writes its result into ``o_extend`` --
        confirm against ``_fwd_kernel_v2``.)
    """
    qk_dim = q_extend.shape[-1]
    v_dim = v_extend.shape[-1]

    # Head dims that are split into a main part and a positional-embedding part.
    split_head_dims = {
        576: (512, 64),
        288: (256, 32),
        192: (128, 64),
    }
    if qk_dim in split_head_dims:
        block_dmodel, block_dpe = split_head_dims[qk_dim]
    else:
        block_dmodel = triton.next_power_of_2(qk_dim)
        block_dpe = 0
    block_dv = triton.next_power_of_2(v_dim)

    if not sm_scale:
        sm_scale = 1.0 / (qk_dim**0.5)

    num_seqs = qo_indptr.shape[0] - 1
    num_q_heads = q_extend.shape[1]
    kv_group_num = num_q_heads // k_extend.shape[1]

    has_custom_mask = custom_mask is not None
    has_sink = sinks is not None

    if custom_mask is None:
        # The kernel still receives mask arguments; hand it one-element placeholders.
        custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device)
        mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device)

    def launch_grid(meta):
        # BLOCK_M comes from whichever config the tuner is currently trying.
        return (
            num_seqs,
            num_q_heads,
            triton.cdiv(max_len_extend, meta["BLOCK_M"]),
        )

    # (stride(0), stride(1)) for each tensor, in the order the kernel expects them.
    strided_tensors = (q_extend, k_extend, v_extend, o_extend, k_buffer, v_buffer)
    strides = [t.stride(axis) for t in strided_tensors for axis in (0, 1)]

    fn_v2[launch_grid](
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        qo_indptr,
        kv_indptr,
        kv_indices,
        custom_mask,
        mask_indptr,
        sinks,
        window_kv_offsets,
        sm_scale,
        k_scale,
        v_scale,
        kv_group_num,
        *strides,
        SLIDING_WINDOW_SIZE=sliding_window_size,
        logit_cap=logit_cap,
        xai_temperature_len=xai_temperature_len,
        HAS_SINK=has_sink,
        BLOCK_DMODEL=block_dmodel,
        BLOCK_DPE=block_dpe,
        BLOCK_DV=block_dv,
        Lq=qk_dim,
        Lv=v_dim,
        USE_CUSTOM_MASK=has_custom_mask,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=skip_prefix_custom_mask,
        STORE_TRANSPOSE=True,
    )
def get_bench_inputs_v2():
    """Return the benchmark parameter names and value rows for the v2 tuning run.

    Cases for :func:`_fwd_kernel_v2` / ``_get_config_v2`` keys.

    After checking ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` entry by entry, only two
    **keyed combinations** relevant to extend appear:

    - **MHA** (``k_extend[...,1,...]``): ``sliding_window_size=-1``, ``sinks=None``
      (e.g. the ``k_extend [1152,1,192]`` / ``[3952,1,192]`` sections of the log).
    - **GQA** (``k_extend[...,2,...]``): ``sliding_window_size=128``, ``sinks shape [16]``
      (the log never records GQA together with ``-1`` / no sinks).

    The log contains **no** ``sliding_window_size=64`` and no "GQA + no SWA + no sinks"
    case; to cover other models, add rows separately and note where they come from.

    Returns:
        tuple: ``(names, vals)``. ``names`` is the list of parameter names and ``vals``
        is a list of tuples holding one value per name, in the same order.

    Raises:
        ValueError: if a row in ``vals`` does not have exactly one value per name.
    """
    names = [
        "B",
        "H",
        "prefix",
        "extend",
        "kv_lora_rank",
        "qk_rope_head_dim",
        "v_head_dim",
        "causal",
        "custom_mask",
        "sliding_window_size",
        "has_sink",
        "kv_num_heads",
    ]
    vals = [
        # (prefix, extend) only affect memory access / grid size; want7 is determined by
        # head/dim/SWA/sinks. prefix=8192, extend=1024 follow the usual long-KV bench setup.
        # (1) GQA Q16/KV2: matches ``[3952,2,192]`` + ``sliding_window_size=128`` +
        #     ``sinks [16]`` in the log; want7 (8,192,128,F,T,T,128)
        (4, 16, 8192, 1024, 128, 64, 128, True, False, 128, True, 2),
        # (2) MHA Q16/KV1: matches ``[...,1,192]`` + ``sliding_window_size=-1`` +
        #     ``sinks=None`` in the log; want7 (16,192,128,F,T,F,-1)
        (4, 16, 8192, 1024, 128, 64, 128, True, False, -1, False, 1),
    ]
    # Rows are long positional tuples; catch a misaligned row here rather than letting
    # it surface later as a confusing benchmark-argument error.
    for index, row in enumerate(vals):
        if len(row) != len(names):
            raise ValueError(
                f"bench row {index} has {len(row)} values, expected {len(names)}"
            )
    return names, vals
x_names, x_vals = get_bench_inputs()
configs = [
triton.testing.Benchmark(
......@@ -401,14 +604,136 @@ def bench_extend_attention(B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim,
sm_scale=sm_scale,
logit_cap=logit_cap,
)
return triton.testing.do_bench(fn)
return triton.testing.do_bench_cudagraph(fn)
# Benchmark sweep for the v2 kernel: one "triton_v2" line over the rows returned by
# get_bench_inputs_v2(), run in bfloat16 and reported in milliseconds.
x_names_v2, x_vals_v2 = get_bench_inputs_v2()
configs_v2 = [
    triton.testing.Benchmark(
        x_names=x_names_v2,
        x_vals=x_vals_v2,
        line_arg="provider",
        line_vals=["triton_v2"],
        line_names=["triton_v2"],
        styles=[("blue", "-")],
        ylabel="ms",
        plot_name="extend_attention_v2_hcutune",
        args={"dtype": torch.bfloat16},
    )
]
@triton.utils.dist_perf_report(configs_v2)
def bench_extend_attention_v2(
    B,
    H,
    prefix,
    extend,
    kv_lora_rank,
    qk_rope_head_dim,
    v_head_dim,
    causal,
    custom_mask,
    sliding_window_size,
    has_sink,
    kv_num_heads,
    provider,
    dtype,
):
    """Benchmark one row of ``configs_v2`` through :func:`extend_attention_fwd_v2`.

    Builds inputs with ``input_helper`` (equal sequence lengths, ``kv_num_heads`` KV
    heads), optionally creates a random per-head ``sinks`` tensor, allocates the output
    buffer, and times the launch with ``triton.testing.do_bench_cudagraph``.

    Tensors are created on ``"cpu"`` when ``TRITON_HCUTUNE_COMPILE_ONLY`` is ``"1"``,
    otherwise on ``"cuda"``. ``provider`` is accepted for the benchmark harness but is
    not read here.

    Returns:
        Whatever ``triton.testing.do_bench_cudagraph`` returns for the launch closure.

    Raises:
        NotImplementedError: if ``custom_mask`` is truthy; tuning with a custom mask
            is not supported by this benchmark.
    """
    torch.manual_seed(42)

    compile_only = os.getenv("TRITON_HCUTUNE_COMPILE_ONLY", "") == "1"
    device = "cpu" if compile_only else "cuda"

    # Fixed settings for this tuning run.
    ref_attn_impl = "normal"
    logit_cap = 0.0
    k_scale = 1.0
    v_scale = 1.0
    xai_temperature_len = -1
    skip_prefix_custom_mask = True

    helper_out = input_helper(
        B,
        H,
        prefix,
        extend,
        kv_lora_rank,
        qk_rope_head_dim,
        v_head_dim,
        dtype,
        device,
        ref_attn_impl,
        equal_seqlens=True,
        kv_num_heads=kv_num_heads,
    )
    (
        q_extend,
        k_extend,
        v_extend,
        k_buffer,
        v_buffer,
        kv_indptr,
        kv_indices,
        qo_indptr,
        custom_mask_t,
        mask_indptr,
        max_len_extend,
    ) = helper_out

    if custom_mask:
        raise NotImplementedError(
            "tune v2 with custom_mask requires mask tensors; use custom_mask=False for hcutune key matching"
        )

    sm_scale = float(1.0 / (q_extend.shape[-1] ** 0.5))

    if has_sink:
        sinks = torch.randn(H, dtype=q_extend.dtype, device=device)
    else:
        sinks = None
    window_kv_offsets = None

    # Output keeps q's leading dims and takes v's head dim as its last dim.
    out_shape = (*q_extend.shape[:-1], v_extend.shape[-1])
    tri_out = torch.empty(out_shape, dtype=q_extend.dtype, device=q_extend.device)

    def launch():
        extend_attention_fwd_v2(
            q_extend=q_extend,
            k_extend=k_extend,
            v_extend=v_extend,
            o_extend=tri_out,
            k_buffer=k_buffer,
            v_buffer=v_buffer,
            qo_indptr=qo_indptr,
            kv_indptr=kv_indptr,
            kv_indices=kv_indices,
            custom_mask=custom_mask_t,
            is_causal=causal,
            mask_indptr=mask_indptr,
            max_len_extend=max_len_extend,
            sm_scale=sm_scale,
            k_scale=k_scale,
            v_scale=v_scale,
            sliding_window_size=sliding_window_size,
            sinks=sinks,
            window_kv_offsets=window_kv_offsets,
            xai_temperature_len=xai_temperature_len,
            skip_prefix_custom_mask=skip_prefix_custom_mask,
            logit_cap=logit_cap,
        )

    return triton.testing.do_bench_cudagraph(launch)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--perf", action='store_true', default=False,
help='benchmark with hcutuner perf mode')
parser.add_argument("--perf", action="store_true", default=False,
help="benchmark with hcutuner perf mode")
parser.add_argument(
"--v1",
action="store_true",
default=False,
help="Tune v1 ``_fwd_kernel`` only. Default: v2 ``_fwd_kernel_v2`` (JSON: ``_get_config_v2`` / EXTEND_ATTENTION-V2-FP16).",
)
return parser.parse_args()
......@@ -417,6 +742,11 @@ if __name__ == "__main__":
args = parse_args()
if args.perf:
os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1"
os.environ["TRITON_HCUTUNE_PERF_MODE"] = "1" # idempotent; real set is at top for Hcutuner init
bench_extend_attention.run(print_data=True, save_path='./tune_extend_attention_out')
if args.v1:
bench_extend_attention.run(print_data=True, save_path="./tune_extend_attention_out")
else:
bench_extend_attention_v2.run(
print_data=True, save_path="./tune_extend_attention_v2_out"
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment