Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
aiter
Commits
f233de81
Commit
f233de81
authored
Apr 17, 2026
by
Xiaowei.zhang
Browse files
[SYNC] Code sync.
parent
1893a1e0
Changes
23
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2912 additions
and
155 deletions
+2912
-155
.gitmodules
.gitmodules
+2
-2
aiter/configs/tuned_fmoe_asm_w8a8_group_shuffle.csv
aiter/configs/tuned_fmoe_asm_w8a8_group_shuffle.csv
+69
-0
aiter/fused_moe_asm_wna16.py
aiter/fused_moe_asm_wna16.py
+9
-3
aiter/moe.py
aiter/moe.py
+53
-12
aiter/ops/triton/configs/BW200B-EXTEND_ATTENTION-V2-FP16.json
...r/ops/triton/configs/BW200B-EXTEND_ATTENTION-V2-FP16.json
+27
-0
aiter/ops/triton/configs/moe/E=256,N=320,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
...320,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
+236
-0
aiter/ops/triton/configs/moe/E=256,N=320,device_name=BW200B,dtype=fp8_w8a8.json
...gs/moe/E=256,N=320,device_name=BW200B,dtype=fp8_w8a8.json
+236
-0
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200B,dtype=int4_w4a16,is_bottom=True.json
...4,device_name=BW200B,dtype=int4_w4a16,is_bottom=True.json
+210
-0
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200B,dtype=int4_w4a16.json
...moe/E=384,N=1024,device_name=BW200B,dtype=int4_w4a16.json
+210
-0
aiter/ops/triton/configs/moe/E=384,N=2048,device_name=BW200B,dtype=int4_w4a16,is_bottom=True.json
...8,device_name=BW200B,dtype=int4_w4a16,is_bottom=True.json
+210
-0
aiter/ops/triton/configs/moe/E=384,N=2048,device_name=BW200B,dtype=int4_w4a16.json
...moe/E=384,N=2048,device_name=BW200B,dtype=int4_w4a16.json
+210
-0
aiter/ops/triton/configs/sage_attention/_attn_fwd-device=gfx936cu_80-dtype=f16_f16_f16_f32_f32_f16_f32.json
...evice=gfx936cu_80-dtype=f16_f16_f16_f32_f32_f16_f32.json
+0
-0
aiter/ops/triton/configs/sage_attention/quant_per_block_int8_kernel-device=gfx936cu_80-dtype=f16_i8_f32.json
...ock_int8_kernel-device=gfx936cu_80-dtype=f16_i8_f32.json
+0
-0
aiter/ops/triton/extend_attention.py
aiter/ops/triton/extend_attention.py
+521
-31
aiter/ops/triton/moe_op.py
aiter/ops/triton/moe_op.py
+40
-8
op_tests/op_benchmarks/triton/bench_extend_attention.py
op_tests/op_benchmarks/triton/bench_extend_attention.py
+138
-57
op_tests/test_aiter_moe_with_config_w16a16_nogate.py
op_tests/test_aiter_moe_with_config_w16a16_nogate.py
+340
-0
op_tests/test_aiter_moe_with_config_w8a8_channelwise.py
op_tests/test_aiter_moe_with_config_w8a8_channelwise.py
+53
-27
op_tests/triton_autotune/fused_moe/autotune_patches.py
op_tests/triton_autotune/fused_moe/autotune_patches.py
+3
-0
op_tests/triton_autotune/tune_extend_attention.py
op_tests/triton_autotune/tune_extend_attention.py
+345
-15
No files found.
.gitmodules
View file @
f233de81
[submodule "3rdparty/composable_kernel"]
path = 3rdparty/composable_kernel
url = ../composable_kernel
branch =
rel-5.7.1
branch =
main
[submodule "3rdparty/moe_c"]
path = 3rdparty/moe_c
url = ../Moe
branch =
W8A8
branch =
master
aiter/configs/tuned_fmoe_asm_w8a8_group_shuffle.csv
View file @
f233de81
...
...
@@ -1199,3 +1199,72 @@ gfx938,f8_w8a8_block,torch.float16,12288,1536,3072,64,8,0,0,asm,13001+23000,1087
gfx938,f8_w8a8_block,torch.float16,16384,1536,3072,64,8,0,0,asm,13001+23000,14314.7352
gfx938,f8_w8a8_block,torch.float16,24576,1536,3072,64,8,0,0,asm,13001+23000,21336.6809
gfx938,f8_w8a8_block,torch.float16,32768,1536,3072,64,8,0,0,asm,13001+23000,28266.1463
gfx938,f8_w8a8_block,torch.float16,1,256,4096,256,8,0,0,asm,10007+20200,66.0562
gfx938,f8_w8a8_block,torch.float16,2,256,4096,256,8,0,0,asm,10001+20000,83.8161
gfx938,f8_w8a8_block,torch.float16,4,256,4096,256,8,0,0,asm,10002+20000,112.3382
gfx938,f8_w8a8_block,torch.float16,6,256,4096,256,8,0,0,asm,10002+20000,141.2728
gfx938,f8_w8a8_block,torch.float16,8,256,4096,256,8,0,0,asm,10007+20000,165.0033
gfx938,f8_w8a8_block,torch.float16,10,256,4096,256,8,0,0,asm,10002+20000,186.1823
gfx938,f8_w8a8_block,torch.float16,12,256,4096,256,8,0,0,asm,10002+20000,205.2475
gfx938,f8_w8a8_block,torch.float16,14,256,4096,256,8,0,0,asm,10002+20000,226.2159
gfx938,f8_w8a8_block,torch.float16,16,256,4096,256,8,0,0,asm,10002+20000,237.2221
gfx938,f8_w8a8_block,torch.float16,20,256,4096,256,8,0,0,asm,10002+20000,264.3938
gfx938,f8_w8a8_block,torch.float16,24,256,4096,256,8,0,0,asm,10002+20000,293.3708
gfx938,f8_w8a8_block,torch.float16,28,256,4096,256,8,0,0,asm,10002+20000,343.0409
gfx938,f8_w8a8_block,torch.float16,32,256,4096,256,8,0,0,asm,10002+20000,359.6472
gfx938,f8_w8a8_block,torch.float16,36,256,4096,256,8,0,0,asm,10002+20000,367.9137
gfx938,f8_w8a8_block,torch.float16,40,256,4096,256,8,0,0,asm,10001+20000,378.7264
gfx938,f8_w8a8_block,torch.float16,44,256,4096,256,8,0,0,asm,10002+20000,389.2864
gfx938,f8_w8a8_block,torch.float16,48,256,4096,256,8,0,0,asm,10002+20000,398.3053
gfx938,f8_w8a8_block,torch.float16,56,256,4096,256,8,0,0,asm,10002+20000,414.7348
gfx938,f8_w8a8_block,torch.float16,64,256,4096,256,8,0,0,asm,10002+20000,430.7348
gfx938,f8_w8a8_block,torch.float16,80,256,4096,256,8,0,0,asm,10002+20000,454.9452
gfx938,f8_w8a8_block,torch.float16,96,256,4096,256,8,0,0,asm,10002+20000,473.8084
gfx938,f8_w8a8_block,torch.float16,112,256,4096,256,8,0,0,asm,10002+20000,489.219
gfx938,f8_w8a8_block,torch.float16,128,256,4096,256,8,0,0,asm,10002+20000,494.3979
gfx938,f8_w8a8_block,torch.float16,160,256,4096,256,8,0,0,asm,10002+20000,499.5009
gfx938,f8_w8a8_block,torch.float16,192,256,4096,256,8,0,0,asm,10002+20000,511.6526
gfx938,f8_w8a8_block,torch.float16,224,256,4096,256,8,0,0,asm,10002+20000,519.9809
gfx938,f8_w8a8_block,torch.float16,256,256,4096,256,8,0,0,asm,10002+20000,520.7221
gfx938,f8_w8a8_block,torch.float16,320,256,4096,256,8,0,0,asm,10002+20000,535.021
gfx938,f8_w8a8_block,torch.float16,384,256,4096,256,8,0,0,asm,10002+20000,565.7914
gfx938,f8_w8a8_block,torch.float16,448,256,4096,256,8,0,0,asm,10002+20000,598.9198
gfx938,f8_w8a8_block,torch.float16,512,256,4096,256,8,0,0,asm,11007+21000,610.5915
gfx938,f8_w8a8_block,torch.float16,576,256,4096,256,8,0,0,asm,11010+21000,639.2314
gfx938,f8_w8a8_block,torch.float16,640,256,4096,256,8,0,0,asm,11009+21000,631.0208
gfx938,f8_w8a8_block,torch.float16,704,256,4096,256,8,0,0,asm,11006+21000,640.8651
gfx938,f8_w8a8_block,torch.float16,768,256,4096,256,8,0,0,asm,11007+21000,659.7198
gfx938,f8_w8a8_block,torch.float16,832,256,4096,256,8,0,0,asm,11009+21200,658.3555
gfx938,f8_w8a8_block,torch.float16,896,256,4096,256,8,0,0,asm,11010+21000,689.7156
gfx938,f8_w8a8_block,torch.float16,960,256,4096,256,8,0,0,asm,11010+21200,722.3639
gfx938,f8_w8a8_block,torch.float16,1024,256,4096,256,8,0,0,asm,11010+21000,751.2649
gfx938,f8_w8a8_block,torch.float16,1152,256,4096,256,8,0,0,asm,11008+21000,870.549
gfx938,f8_w8a8_block,torch.float16,1280,256,4096,256,8,0,0,asm,12002+22000,867.3911
gfx938,f8_w8a8_block,torch.float16,1408,256,4096,256,8,0,0,asm,12003+22000,875.1133
gfx938,f8_w8a8_block,torch.float16,1536,256,4096,256,8,0,0,asm,12003+22000,902.4816
gfx938,f8_w8a8_block,torch.float16,1664,256,4096,256,8,0,0,asm,12004+22000,926.5489
gfx938,f8_w8a8_block,torch.float16,1792,256,4096,256,8,0,0,asm,12003+22000,942.0857
gfx938,f8_w8a8_block,torch.float16,1920,256,4096,256,8,0,0,asm,12005+22000,1018.5236
gfx938,f8_w8a8_block,torch.float16,2048,256,4096,256,8,0,0,asm,12003+22000,1094.3972
gfx938,f8_w8a8_block,torch.float16,2304,256,4096,256,8,0,0,asm,12005+22000,1257.5887
gfx938,f8_w8a8_block,torch.float16,2560,256,4096,256,8,0,0,asm,11010+21200,1374.7673
gfx938,f8_w8a8_block,torch.float16,2816,256,4096,256,8,0,0,asm,13001+23000,1400.5189
gfx938,f8_w8a8_block,torch.float16,3072,256,4096,256,8,0,0,asm,12005+22000,1439.1798
gfx938,f8_w8a8_block,torch.float16,3328,256,4096,256,8,0,0,asm,12005+22000,1456.0893
gfx938,f8_w8a8_block,torch.float16,3584,256,4096,256,8,0,0,asm,12005+22000,1487.5504
gfx938,f8_w8a8_block,torch.float16,3840,256,4096,256,8,0,0,asm,12005+22000,1595.1459
gfx938,f8_w8a8_block,torch.float16,4096,256,4096,256,8,0,0,asm,12006+22000,1756.2657
gfx938,f8_w8a8_block,torch.float16,4608,256,4096,256,8,0,0,asm,12005+22000,2012.6698
gfx938,f8_w8a8_block,torch.float16,5120,256,4096,256,8,0,0,asm,12005+22000,2134.1435
gfx938,f8_w8a8_block,torch.float16,5632,256,4096,256,8,0,0,asm,12005+22000,2246.8753
gfx938,f8_w8a8_block,torch.float16,6144,256,4096,256,8,0,0,asm,12005+22000,2440.6189
gfx938,f8_w8a8_block,torch.float16,6656,256,4096,256,8,0,0,asm,13001+23000,2562.9843
gfx938,f8_w8a8_block,torch.float16,7168,256,4096,256,8,0,0,asm,13001+23001,2768.8287
gfx938,f8_w8a8_block,torch.float16,7680,256,4096,256,8,0,0,asm,13001+23000,2792.2731
gfx938,f8_w8a8_block,torch.float16,8192,256,4096,256,8,0,0,asm,13001+23000,3082.3107
gfx938,f8_w8a8_block,torch.float16,10240,256,4096,256,8,0,0,asm,13001+23000,3801.2909
gfx938,f8_w8a8_block,torch.float16,12288,256,4096,256,8,0,0,asm,13001+23000,4383.3187
gfx938,f8_w8a8_block,torch.float16,14336,256,4096,256,8,0,0,asm,13001+23000,5030.8137
gfx938,f8_w8a8_block,torch.float16,16384,256,4096,256,8,0,0,asm,13001+23000,5608.4473
gfx938,f8_w8a8_block,torch.float16,17408,256,4096,256,8,0,0,asm,13001+23000,6038.0465
gfx938,f8_w8a8_block,torch.float16,24576,256,4096,256,8,0,0,asm,13001+23000,8143.5178
aiter/fused_moe_asm_wna16.py
View file @
f233de81
...
...
@@ -15,7 +15,8 @@ from aiter import silu_and_mul,gelu_and_mul
from
aiter.ops.triton.fused_moe
import
(
triton_moe_sum
,
triton_silu_and_mul
,
triton_gelu_and_mul
triton_gelu_and_mul
,
triton_relu2
,
)
from
aiter.jit.core
import
AITER_ROOT_DIR
...
...
@@ -754,8 +755,11 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
use_shuffle
)
#
else
:
# For gated activations (silu/gelu): w1 has 2*inter_dim cols, so inter_dim = N/2
# For non-gated activations (relu2): w1 has inter_dim cols, so inter_dim = N
asm_inter_dim
=
N
/
2
if
activation
in
(
"silu"
,
"gelu"
)
else
N
if
solution_id
is
None
:
solution_id
=
get_moe_asm_solution
(
arch
,
tokens_in_chunk
,
N
/
2
,
w1
.
size
(
2
),
E
,
top_k_num
,
MoeQuantType
.
NO_QUANT
,
use_shuffle
)
solution_id
=
get_moe_asm_solution
(
arch
,
tokens_in_chunk
,
asm_inter_dim
,
w1
.
size
(
2
),
E
,
top_k_num
,
MoeQuantType
.
NO_QUANT
,
use_shuffle
)
config
=
decode_sol_w8a8_c
(
solution_id
)
if
persist_cu
==
cu_num
:
calculate_persist_groups
(
persist_cu
,
config
,
MoeQuantType
.
NO_QUANT
)
...
...
@@ -767,7 +771,7 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
moe_sorting_ck
(
curr_topk_ids
,
curr_topk_weights
,
global_num_experts
,
model_dim
,
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
],
config
[
"BLOCK_SIZE_M"
],
expert_map
)
)
if
print_log
():
print
(
f
"Asm Moe Size: chunk:
{
chunk
}
, arch:
{
arch
}
, quant:
{
MoeQuantType
.
NO_QUANT
}
, tokens:
{
tokens_in_chunk
}
, inter_dim:
{
int
(
N
/
2
)
}
, model_dim:
{
w1
.
size
(
2
)
}
, expert:
{
E
}
, topk:
{
top_k_num
}
"
)
print
(
f
"Asm Moe Size: chunk:
{
chunk
}
, arch:
{
arch
}
, quant:
{
MoeQuantType
.
NO_QUANT
}
, tokens:
{
tokens_in_chunk
}
, inter_dim:
{
int
(
asm_inter_dim
)
}
, model_dim:
{
w1
.
size
(
2
)
}
, expert:
{
E
}
, topk:
{
top_k_num
}
"
)
print
(
f
"solution:
{
solution_id
}
, shuffle:
{
use_shuffle
}
, persist:
{
persist_cu
}
"
)
if
solution_id
==
"default"
:
print
(
f
">>> Warning: No matching config pattern found, using default asm solution."
)
...
...
@@ -797,6 +801,8 @@ def fused_experts_asm_impl(hidden_states: torch.Tensor,
elif
activation
==
"gelu"
:
triton_gelu_and_mul
(
d_silu
,
d_w1_out
)
# gelu_and_mul(d_silu,d_w1_out)
elif
activation
==
"relu2"
:
triton_relu2
(
d_silu
,
d_w1_out
)
else
:
raise
ValueError
(
f
"Unsupported FusedMoe activation:
{
activation
}
"
)
...
...
aiter/moe.py
View file @
f233de81
...
...
@@ -23,6 +23,7 @@ class MoeQuantType:
W16A16
=
"w16a16"
W4A16
=
"w4a16"
W8A8
=
"w8a8"
FP8_W8A8
=
"fp8_w8a8"
W4A8
=
"w4a8"
...
...
@@ -53,9 +54,9 @@ def _try_get_moe_c_config(
block_size
:
int
,
)
->
Optional
[
Dict
[
str
,
Any
]]:
try
:
if
quant_type
==
MoeQuantType
.
W4A16
:
from
.fused_moe_c
import
get_moe_configs_marlin
if
quant_type
==
MoeQuantType
.
W4A16
:
configs
=
get_moe_configs_marlin
(
E
=
e
,
N
=
n
,
...
...
@@ -64,8 +65,6 @@ def _try_get_moe_c_config(
use_moe_wna16_cuda
=
True
,
)
elif
quant_type
==
MoeQuantType
.
W8A8
:
from
.fused_moe_c
import
get_moe_configs_marlin
configs
=
get_moe_configs_marlin
(
E
=
e
,
N
=
n
,
...
...
@@ -73,9 +72,15 @@ def _try_get_moe_c_config(
is_bottom
=
False
,
use_moe_wna16_cuda
=
True
,
)
elif
quant_type
==
MoeQuantType
.
FP8_W8A8
:
configs
=
get_moe_configs_marlin
(
E
=
e
,
N
=
n
,
dtype
=
"fp8_w8a8"
,
is_bottom
=
False
,
use_moe_wna16_cuda
=
True
,
)
elif
quant_type
==
MoeQuantType
.
W4A8
:
from
.fused_moe_c
import
get_moe_configs_marlin
configs
=
get_moe_configs_marlin
(
E
=
e
,
N
=
n
,
...
...
@@ -148,6 +153,22 @@ def _try_get_asm_config(
return
None
return
decode_sol_0
(
solution
)
if
quant_type
==
MoeQuantType
.
FP8_W8A8
:
from
.fused_moe_asm_wna16
import
decode_sol_0
solution
=
get_moe_asm_solution
(
arch
=
arch
,
token
=
m
,
inter_dim
=
n
,
model_dim
=
k
,
expert
=
e
,
topk
=
top_k
,
quant_type
=
AsmMoeQuantType
.
F8_W8A8
,
)
if
solution
==
"default"
:
return
None
return
decode_sol_0
(
solution
)
if
quant_type
==
MoeQuantType
.
W16A16
:
from
.fused_moe_asm_wna16
import
decode_sol_0
...
...
@@ -186,6 +207,7 @@ def _try_get_triton_config(
dtype_name
=
{
MoeQuantType
.
W4A16
:
"int4_w4a16"
,
MoeQuantType
.
W8A8
:
"int8_w8a8"
,
MoeQuantType
.
FP8_W8A8
:
"fp8_w8a8"
,
}.
get
(
quant_type
)
if
dtype_name
is
None
:
return
None
...
...
@@ -216,7 +238,7 @@ def _try_get_ck_config(
block_shape
:
Optional
[
List
[
int
]],
)
->
Optional
[
Dict
[
str
,
Any
]]:
try
:
if
quant_type
!=
MoeQuantType
.
W8A8
:
if
quant_type
not
in
(
MoeQuantType
.
W8A8
,
MoeQuantType
.
FP8_
W8A8
)
:
return
None
from
.fused_moe_ck
import
get_moe_ck_solution_id
,
MoeQuantType
as
CkMoeQuantType
...
...
@@ -245,29 +267,43 @@ def _try_get_ck_config(
def
get_aiter_moe_config
(
M
:
int
,
# Number of tokens (input sequence length)
E
:
int
,
# Number of experts
N1
:
int
,
# GEMM1 output dimension
, typically equal to (moe_intermediate_size / TP * 2)
N1
:
int
,
# GEMM1 output dimension
: gated = (intermediate_size * 2), non-gated = intermediate_size
N2
:
int
,
# GEMM2 output dimension, typically equal to hidden_size
K
:
int
,
# GEMM1 input dimension, typically equal to hidden_size; for GEMM2, K typically equal to (moe_intermediate_size / TP)
top_k
:
int
,
block_size
:
int
,
dtype
:
torch
.
dtype
,
quant_type
:
str
,
activation
:
str
=
"silu"
,
# "silu"/"gelu"/"relu2"/...
gated
:
Optional
[
bool
]
=
None
,
# True=GLU-gated (N1=2*inter), False=non-gated (N1=inter); None=auto from activation
)
->
Tuple
[
bool
,
AiterMoeConfig
]:
"""Get the best backend config for a MOE problem.
Currently supported quant types:
- ``MoeQuantType.W16A16`` (non-quantized)
- ``MoeQuantType.W4A16``
- ``MoeQuantType.W8A8``
- ``MoeQuantType.W8A8`` (int8)
- ``MoeQuantType.FP8_W8A8`` (fp8)
- ``MoeQuantType.W4A8``
Backend priority:
- ``w16a16``: asm > triton
- ``w4a16``: moe_c > asm > triton
- ``w8a8``: asm > moe_c > triton > ck
- ``fp8_w8a8``: asm > moe_c > triton > ck
- ``w4a8``: moe_c
For non-gated MOE (e.g. Nemotron with ReLU² activation), pass
``gated=False`` (or let it auto-detect from ``activation="relu2"``)
and set ``N1 = intermediate_size`` (not ``2 * intermediate_size``).
"""
n
=
N1
/
2
# Determine gating: explicit > auto-detect from activation
if
gated
is
None
:
gated
=
activation
in
(
"silu"
,
"gelu"
)
# For gated (GLU): N1 = 2 * intermediate_size, n = N1 // 2
# For non-gated: N1 = intermediate_size, n = N1
n
=
N1
//
2
if
gated
else
N1
block_shape
=
[
0
,
block_size
]
if
block_size
else
None
if
quant_type
==
MoeQuantType
.
W4A16
:
...
...
@@ -282,7 +318,7 @@ def get_aiter_moe_config(
]
else
:
raise
ValueError
(
f
"Unsupported dtype:
{
dtype
}
"
)
elif
quant_type
==
MoeQuantType
.
W8A8
:
elif
quant_type
in
(
MoeQuantType
.
W8A8
,
MoeQuantType
.
FP8_
W8A8
)
:
if
block_size
==
0
:
# Channel wise choose MOE_C
candidates
=
[
(
MoeSolutionType
.
MOE_C
,
lambda
:
_try_get_moe_c_config
(
quant_type
,
M
,
E
,
n
,
block_size
)),
...
...
@@ -348,6 +384,7 @@ def aiter_moe(
use_int4_w4a16
=
moe_config
.
quant_type
==
MoeQuantType
.
W4A16
use_int8_w8a8
=
moe_config
.
quant_type
==
MoeQuantType
.
W8A8
use_fp8_w8a8
=
moe_config
.
quant_type
==
MoeQuantType
.
FP8_W8A8
use_int8_w4a8
=
moe_config
.
quant_type
==
MoeQuantType
.
W4A8
if
moe_config
.
solution_type
==
MoeSolutionType
.
MOE_C
:
...
...
@@ -362,6 +399,7 @@ def aiter_moe(
inplace
=
inplace
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a8
=
use_int8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w4a8
=
use_int8_w4a8
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
...
...
@@ -391,6 +429,7 @@ def aiter_moe(
inplace
=
inplace
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a8
=
use_int8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
@@ -409,7 +448,7 @@ def aiter_moe(
from
.ops.triton.fused_moe
import
fused_experts_impl
# W8A8 channel-wise (block_shape=None) requires per_channel_quant=True
per_channel_quant
=
use_int8_w8a8
and
block_shape
is
None
per_channel_quant
=
(
use_int8_w8a8
or
use_fp8_w8a8
)
and
block_shape
is
None
return
fused_experts_impl
(
hidden_states
,
...
...
@@ -421,6 +460,7 @@ def aiter_moe(
inplace
=
inplace
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a8
=
use_int8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
activation
=
activation
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
...
...
@@ -448,6 +488,7 @@ def aiter_moe(
odtype
=
hidden_states
.
dtype
,
inplace
=
inplace
,
use_int8_w8a8
=
use_int8_w8a8
,
use_fp8_w8a8
=
use_fp8_w8a8
,
activation
=
activation
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
...
...
aiter/ops/triton/configs/BW200B-EXTEND_ATTENTION-V2-FP16.json
0 → 100644
View file @
f233de81
{
"config"
:
{
"(8, 192, 128, False, True, True, 128)"
:
{
"BLOCK_M"
:
32
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
1
},
"(16, 192, 128, False, True, False, -1)"
:
{
"BLOCK_M"
:
32
,
"BLOCK_N"
:
64
,
"waves_per_eu"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_ctas"
:
1
,
"num_stages"
:
1
}
},
"path"
:
{}
}
aiter/ops/triton/configs/moe/E=256,N=320,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
0 → 100644
View file @
f233de81
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"65536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=256,N=320,device_name=BW200B,dtype=fp8_w8a8.json
0 → 100644
View file @
f233de81
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"65536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200B,dtype=int4_w4a16,is_bottom=True.json
0 → 100644
View file @
f233de81
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200B,dtype=int4_w4a16.json
0 → 100644
View file @
f233de81
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=384,N=2048,device_name=BW200B,dtype=int4_w4a16,is_bottom=True.json
0 → 100644
View file @
f233de81
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=384,N=2048,device_name=BW200B,dtype=int4_w4a16.json
0 → 100644
View file @
f233de81
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/sage_attention/_attn_fwd-device=gfx936
:
cu_80-dtype=f16_f16_f16_f32_f32_f16_f32.json
→
aiter/ops/triton/configs/sage_attention/_attn_fwd-device=gfx936
cu_80-dtype=f16_f16_f16_f32_f32_f16_f32.json
View file @
f233de81
File moved
aiter/ops/triton/configs/sage_attention/quant_per_block_int8_kernel-device=gfx936
:
cu_80-dtype=f16_i8_f32.json
→
aiter/ops/triton/configs/sage_attention/quant_per_block_int8_kernel-device=gfx936
cu_80-dtype=f16_i8_f32.json
View file @
f233de81
File moved
aiter/ops/triton/extend_attention.py
View file @
f233de81
...
...
@@ -17,15 +17,21 @@ Memory-efficient attention for prefill.
It supports page size = 1 and prefill with KV cache (i.e. extend).
"""
from
typing
import
Optional
import
functools
import
json
from
typing
import
Any
,
Optional
import
torch
import
triton
import
triton.language
as
tl
import
os
from
triton.knobs
import
cache
as
cache_knob
import
types
try
:
from
triton.knobs
import
cache
as
cache_knob
except
ImportError
:
# Triton builds without `triton.knobs` (e.g. 3.2.x in some images): disable saved-kernel path.
cache_knob
=
types
.
SimpleNamespace
(
dir
=
"__triton_knobs_unavailable__"
)
from
aiter.ops.triton.prefill_attention
import
context_attention_fwd
from
aiter.ops.triton.activation
import
_tanh
...
...
@@ -76,6 +82,10 @@ def _fwd_kernel(
cur_seq
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_block_m
=
tl
.
program_id
(
2
)
tl
.
assume
(
Q_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
K_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
V_Extend
.
to
(
tl
.
int64
)
>=
0
)
cur_kv_head
=
cur_head
//
kv_group_num
cur_seq_extend_start_idx
=
tl
.
load
(
qo_indptr
+
cur_seq
)
...
...
@@ -292,6 +302,345 @@ def _fwd_kernel(
)
@
triton
.
jit
def
_fwd_kernel_v2
(
Q_Extend
,
K_Extend
,
V_Extend
,
O_Extend
,
K_Buffer
,
V_Buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
mask_ptr
,
mask_indptr
,
sink_ptr
,
window_kv_offset_ptr
,
sm_scale
,
k_scale
,
v_scale
,
kv_group_num
,
stride_qbs
,
stride_qh
,
stride_kbs
,
stride_kh
,
stride_vbs
,
stride_vh
,
stride_obs
,
stride_oh
,
stride_buf_kbs
,
stride_buf_kh
,
stride_buf_vbs
,
stride_buf_vh
,
SLIDING_WINDOW_SIZE
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
xai_temperature_len
:
tl
.
constexpr
,
Lq
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
HAS_SINK
:
tl
.
constexpr
,
):
cur_seq
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_block_m
=
tl
.
program_id
(
2
)
tl
.
assume
(
Q_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
K_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
V_Extend
.
to
(
tl
.
int64
)
>=
0
)
cur_kv_head
=
cur_head
//
kv_group_num
cur_seq_extend_start_idx
=
tl
.
load
(
qo_indptr
+
cur_seq
)
cur_seq_len_extend
=
tl
.
load
(
qo_indptr
+
cur_seq
+
1
)
-
cur_seq_extend_start_idx
cur_seq_kv_start_idx
=
tl
.
load
(
kv_indptr
+
cur_seq
)
cur_seq_len_prefix
=
tl
.
load
(
kv_indptr
+
cur_seq
+
1
)
-
cur_seq_kv_start_idx
cur_seq_len
=
cur_seq_len_prefix
+
cur_seq_len_extend
if
USE_CUSTOM_MASK
:
cur_seq_mask_start_idx
=
tl
.
load
(
mask_indptr
+
cur_seq
)
window_kv_offset
=
0
if
USE_CUSTOM_MASK
and
SLIDING_WINDOW_SIZE
>
0
:
window_kv_offset
=
tl
.
load
(
window_kv_offset_ptr
+
cur_seq
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
mask_m
=
(
cur_block_m
*
BLOCK_M
+
offs_m
)
<
cur_seq_len_extend
mask_d
=
offs_d
<
Lq
mask_dv
=
offs_dv
<
Lv
if
xai_temperature_len
>
0
:
offs_qidx
=
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
xai_temperature_scale
=
1.0
/
tl
.
log2
(
float
(
xai_temperature_len
))
xai_temperature_reg
=
tl
.
where
(
offs_qidx
>
xai_temperature_len
,
tl
.
log2
(
offs_qidx
.
to
(
tl
.
float32
))
*
xai_temperature_scale
,
1.0
,
)
offs_q
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
)
q
=
tl
.
load
(
Q_Extend
+
offs_q
,
mask
=
(
mask_m
[:,
None
])
&
(
mask_d
[
None
,
:]),
other
=
0.0
)
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
offs_qpe
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_dpe
[
None
,
:]
)
qpe
=
tl
.
load
(
Q_Extend
+
offs_qpe
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
deno
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
e_max
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
for
start_n
in
range
(
0
,
cur_seq_len_prefix
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
custom_mask
=
tl
.
load
(
mask_ptr
+
cur_seq_mask_start_idx
+
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
(
cur_seq_len
+
window_kv_offset
)
+
window_kv_offset
+
start_n
+
offs_n
[
None
,
:],
mask
=
(
mask_m
[:,
None
]
&
mask_n
[
None
,
:]),
other
=
0
,
)
final_mask
&=
custom_mask
if
SLIDING_WINDOW_SIZE
>
0
:
window_mask
=
(
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
]
)
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
final_mask
&=
window_mask
SKIP_TILE
=
False
if
(
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
)
or
SLIDING_WINDOW_SIZE
>
0
:
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
if
not
SKIP_TILE
:
offs_kv_loc
=
tl
.
load
(
kv_indices
+
cur_seq_kv_start_idx
+
start_n
+
offs_n
,
mask
=
mask_n
,
other
=
0
,
)
offs_buf_k
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
,
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
)
if
BLOCK_DPE
>
0
:
offs_kpe
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Buffer
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
*
k_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
_tanh
(
qk
/
logit_cap
)
if
xai_temperature_len
>
0
:
qk
*=
xai_temperature_reg
[:,
None
]
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
# row_max_fixed avoids exp(-inf - (-inf)) when a row is all -inf in this tile;
# only needed under sliding window or custom mask (plain causal matches v1).
if
SLIDING_WINDOW_SIZE
>
0
or
(
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
):
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
else
:
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
offs_buf_v
=
(
offs_kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
,
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
*
v_scale
e_max
=
n_e_max
cur_block_m_end
=
(
cur_seq_len_extend
if
not
IS_CAUSAL
else
tl
.
minimum
(
cur_seq_len_extend
,
(
cur_block_m
+
1
)
*
BLOCK_M
)
)
for
start_n
in
range
(
0
,
cur_block_m_end
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
:
custom_mask
=
tl
.
load
(
mask_ptr
+
cur_seq_mask_start_idx
+
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
(
cur_seq_len
+
window_kv_offset
)
+
window_kv_offset
+
cur_seq_len_prefix
+
start_n
+
offs_n
[
None
,
:],
mask
=
(
mask_m
[:,
None
]
&
mask_n
[
None
,
:]),
other
=
0
,
)
custom_mask
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
custom_mask
elif
IS_CAUSAL
:
mask_causual
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
>=
(
start_n
+
offs_n
[
None
,
:]
)
mask_causual
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
mask_causual
else
:
mask_non_causal
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
mask_non_causal
if
SLIDING_WINDOW_SIZE
>
0
:
window_mask
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
final_mask
&=
window_mask
SKIP_TILE
=
False
if
USE_CUSTOM_MASK
or
SLIDING_WINDOW_SIZE
>
0
:
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
if
not
SKIP_TILE
:
offs_k
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Extend
+
offs_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
,
out_dtype
=
tl
.
float32
)
if
BLOCK_DPE
>
0
:
offs_kpe
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Extend
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
_tanh
(
qk
/
logit_cap
)
if
xai_temperature_len
>
0
:
qk
*=
xai_temperature_reg
[:,
None
]
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
if
SLIDING_WINDOW_SIZE
>
0
or
USE_CUSTOM_MASK
:
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
else
:
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
offs_v
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[:,
None
])
*
stride_vbs
+
cur_kv_head
*
stride_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Extend
+
offs_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
if
HAS_SINK
:
cur_sink
=
tl
.
load
(
sink_ptr
+
cur_head
)
deno
+=
tl
.
exp
(
cur_sink
-
e_max
)
offs_o
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
stride_obs
+
cur_head
*
stride_oh
+
offs_dv
[
None
,
:]
)
if
STORE_TRANSPOSE
:
tl
.
store
(
O_Extend
+
offs_o
.
T
,
(
acc
/
deno
[:,
None
]).
T
,
mask
=
(
mask_m
[:,
None
]
&
mask_dv
[
None
,
:]).
T
,
)
else
:
tl
.
store
(
O_Extend
+
offs_o
,
acc
/
deno
[:,
None
],
mask
=
mask_m
[:,
None
]
&
mask_dv
[
None
,
:],
)
def
create_tuple
(
k
):
if
k
[
0
]
!=
'('
and
k
[
-
1
]
!=
')'
:
return
k
...
...
@@ -311,8 +660,11 @@ def create_tuple(k):
def
_load_config
():
dev
=
arch_info
.
get_device
()
fpath
=
f
"
{
AITER_TRITON_CONFIGS_PATH
}
/
{
dev
}
-EXTEND_ATTENTION-FP16.json"
try
:
with
open
(
fpath
,
"r"
)
as
file
:
data
=
json
.
load
(
file
)
except
FileNotFoundError
:
return
{
"config"
:
{},
"path"
:
{},
"key"
:
[],
"keys"
:
[]}
res
=
{}
res
[
'config'
]
=
data
[
'config'
]
res
[
'path'
]
=
data
[
'path'
]
...
...
@@ -323,6 +675,39 @@ def _load_config():
global_config
=
_load_config
()
def
_load_config_v2
():
"""Autotuned configs for :func:`_fwd_kernel_v2` (fp8 / sglang-style scale path).
Each ``config`` entry key must parse to a **7-tuple** via :func:`create_tuple`, matching
runtime ``want7``; 5-tuple keys are not accepted.
"""
dev
=
arch_info
.
get_device
()
fpath
=
f
"
{
AITER_TRITON_CONFIGS_PATH
}
/
{
dev
}
-EXTEND_ATTENTION-V2-FP16.json"
try
:
with
open
(
fpath
,
"r"
)
as
file
:
data
=
json
.
load
(
file
)
except
FileNotFoundError
:
return
{
"config"
:
{},
"path"
:
{},
"key"
:
[],
"keys"
:
[]}
res
=
{}
res
[
"config"
]
=
data
[
"config"
]
res
[
"path"
]
=
data
.
get
(
"path"
,
{})
res
[
"key"
]
=
list
(
data
[
"config"
].
keys
())
res
[
"keys"
]
=
[]
for
k
in
res
[
"key"
]:
tup
=
create_tuple
(
k
)
if
len
(
tup
)
!=
7
:
raise
ValueError
(
f
"
{
dev
}
-EXTEND_ATTENTION-V2-FP16.json keys must be 7-tuples matching runtime "
f
"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
f
"SLIDING_WINDOW_SIZE); got length
{
len
(
tup
)
}
for
{
k
!
r
}
"
)
res
[
"keys"
].
append
(
tup
)
return
res
global_config_v2
=
_load_config_v2
()
default_config
=
{
"BLOCK_M"
:
32
,
"BLOCK_N"
:
32
,
...
...
@@ -351,6 +736,47 @@ def _get_config(kv_group_num, Lq, Lv, use_custom_mask, is_causal):
return
global_config
[
'config'
][
key
],
global_config
[
'path'
][
key
]
@
functools
.
lru_cache
(
maxsize
=
1024
)
def
_get_config_v2
(
kv_group_num
,
Lq
,
Lv
,
use_custom_mask
,
is_causal
,
has_sink
:
bool
,
sliding_window_size
:
int
,
):
"""
Lookup order for ``_fwd_kernel_v2`` block sizes:
1. ``want7 = (kv_group_num, Lq, Lv, use_custom_mask, is_causal, has_sink, sliding_window_size)``
against ``{arch}-EXTEND_ATTENTION-V2-FP16.json``. JSON keys must be **7-tuple** strings,
same shape as ``want7`` (see :func:`_load_config_v2`).
2. If no V2 entry matches, :data:`default_config` (no fallback to v1 JSON).
Log field mapping (typical): ``kv_group_num = q_extend.size(-2) // k_extend.size(-2)``,
``Lq = q_extend.size(-1)``, ``Lv = v_extend.size(-1)``,
``use_custom_mask = custom_mask is not None``, ``is_causal`` as passed,
``has_sink = sinks is not None``, ``sliding_window_size`` as passed (use ``-1`` if disabled).
"""
want7
=
(
kv_group_num
,
Lq
,
Lv
,
use_custom_mask
,
is_causal
,
has_sink
,
sliding_window_size
,
)
for
i
,
keys
in
enumerate
(
global_config_v2
[
"keys"
]):
if
keys
==
want7
:
key
=
global_config_v2
[
"key"
][
i
]
return
global_config_v2
[
"config"
][
key
],
global_config_v2
[
"path"
].
get
(
key
)
print
(
"WARNING: optimal V2 config not found, just use default config"
)
return
default_config
,
None
def
has_kernel_cache
(
path
):
return
False
if
not
path
or
not
os
.
path
.
isdir
(
f
'
{
cache_knob
.
dir
}
/
{
path
}
'
)
else
True
...
...
@@ -385,12 +811,23 @@ def extend_attention_fwd(
sm_scale
=
None
,
logit_cap
=
0.0
,
skip_prefix_custom_mask
=
True
,
config
:
Optional
[
dict
[
str
,
any
]]
=
None
,
config
:
Optional
[
dict
[
str
,
Any
]]
=
None
,
k_scale
=
None
,
v_scale
=
None
,
sliding_window_size
=-
1
,
sinks
=
None
,
window_kv_offsets
=
None
,
xai_temperature_len
=-
1
,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
Through ``config`` the signature matches the original aiter API. v2 / sglang
extensions follow with defaults. ``k_scale`` / ``v_scale`` must both be
``None`` or both set (``float`` / ``int`` like sglang, or 1-element
``torch.Tensor`` on device); if both are set, :func:`_fwd_kernel_v2` is used.
"""
Lq
,
Lv
=
(
q_extend
.
shape
[
-
1
],
...
...
@@ -422,12 +859,25 @@ def extend_attention_fwd(
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK
=
skip_prefix_custom_mask
use_v2
=
k_scale
is
not
None
or
v_scale
is
not
None
if
not
USE_CUSTOM_MASK
:
custom_mask
=
torch
.
tensor
([
0
],
dtype
=
torch
.
bool
,
device
=
q_extend
.
device
)
mask_indptr
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
q_extend
.
device
)
if
config
is
None
:
if
q_extend
.
dtype
==
torch
.
float16
:
if
q_extend
.
dtype
==
torch
.
float16
or
q_extend
.
dtype
==
torch
.
bfloat16
:
if
use_v2
:
config
,
path
=
_get_config_v2
(
kv_group_num
,
Lq
,
Lv
,
USE_CUSTOM_MASK
,
is_causal
,
sinks
is
not
None
,
sliding_window_size
,
)
else
:
keys
=
[
kv_group_num
,
Lq
,
Lv
,
USE_CUSTOM_MASK
,
is_causal
]
config
,
path
=
_get_config
(
*
keys
)
else
:
...
...
@@ -441,24 +891,7 @@ def extend_attention_fwd(
# extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
fn
=
_fwd_kernel
[
grid
]
if
not
has_kernel_cache
(
path
)
\
else
functools
.
partial
(
triton
.
utils
.
run_saved_kernel
,
_fwd_kernel
,
path
,
grid
=
grid
)
fn
(
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sm_scale
,
kv_group_num
,
stride_args
=
(
q_extend
.
stride
(
0
),
q_extend
.
stride
(
1
),
k_extend
.
stride
(
0
),
...
...
@@ -471,20 +904,77 @@ def extend_attention_fwd(
k_buffer
.
stride
(
1
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
logit_cap
=
logit_cap
,
)
block_const
=
dict
(
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DV
=
BLOCK_DV
,
# BLOCK_M=BLOCK_M,
# BLOCK_N=BLOCK_N,
Lq
=
Lq
,
Lv
=
Lv
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
IS_CAUSAL
=
is_causal
,
SKIP_PREFIX_CUSTOM_MASK
=
SKIP_PREFIX_CUSTOM_MASK
,
STORE_TRANSPOSE
=
True
,
# num_warps=num_warps,
# num_stages=num_stages,
)
if
use_v2
:
HAS_SINK
=
sinks
is
not
None
assert
k_scale
is
not
None
and
v_scale
is
not
None
,
"k_scale and v_scale must both be set"
# k_scale / v_scale kept in Python API; v2 kernel TEMP omits them for perf vs v1.
_fwd_kernel_v2
[
grid
](
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sinks
,
window_kv_offsets
,
sm_scale
,
k_scale
,
v_scale
,
kv_group_num
,
*
stride_args
,
SLIDING_WINDOW_SIZE
=
sliding_window_size
,
logit_cap
=
logit_cap
,
xai_temperature_len
=
xai_temperature_len
,
HAS_SINK
=
HAS_SINK
,
**
block_const
,
**
config
,
)
return
fn
=
(
_fwd_kernel
[
grid
]
if
not
has_kernel_cache
(
path
)
else
functools
.
partial
(
triton
.
utils
.
run_saved_kernel
,
_fwd_kernel
,
path
,
grid
=
grid
)
)
fn
(
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sm_scale
,
kv_group_num
,
*
stride_args
,
logit_cap
=
logit_cap
,
**
block_const
,
**
config
,
)
...
...
aiter/ops/triton/moe_op.py
View file @
f233de81
...
...
@@ -322,6 +322,7 @@ def fused_moe_kernel_gptq_awq(
USE_MLS_LOAD
:
tl
.
constexpr
,
MUL_ROUTED_WEIGHT
:
tl
.
constexpr
,
USE_ADDR_OFFSET_INT64_A
:
tl
.
constexpr
,
USE_ADDR_OFFSET_INT64_B
:
tl
.
constexpr
,
USE_ADDR_OFFSET_INT64_C
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
compute_type
:
tl
.
constexpr
,
...
...
@@ -434,17 +435,45 @@ def fused_moe_kernel_gptq_awq(
if
use_int4_w4a16
:
if
group_size_divisible
and
has_zp
:
offs_k_continue
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
2
).
to
(
tl
.
int32
)
b_ptrs
=
b_ptr
+
(
off_experts
*
stride_be
+
\
offs_bn
[:,
None
]
*
stride_bn
+
offs_k_continue
[
None
,
:]
*
\
stride_bk
).
to
(
tl
.
int32
)
if
USE_ADDR_OFFSET_INT64_B
:
b_ptrs
=
b_ptr
+
(
off_experts
.
to
(
tl
.
int64
)
*
stride_be
+
offs_bn
[:,
None
].
to
(
tl
.
int64
)
*
stride_bn
+
offs_k_continue
[
None
,
:].
to
(
tl
.
int64
)
*
stride_bk
)
else
:
b_ptrs
=
b_ptr
+
(
off_experts
*
stride_be
+
\
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
\
stride_bn
).
to
(
tl
.
int32
)
b_ptrs
=
b_ptr
+
(
off_experts
*
stride_be
+
offs_bn
[:,
None
]
*
stride_bn
+
offs_k_continue
[
None
,
:]
*
stride_bk
).
to
(
tl
.
int32
)
else
:
if
USE_ADDR_OFFSET_INT64_B
:
b_ptrs
=
b_ptr
+
(
off_experts
.
to
(
tl
.
int64
)
*
stride_be
+
(
offs_k
[:,
None
].
to
(
tl
.
int64
)
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:].
to
(
tl
.
int64
)
*
stride_bn
)
else
:
b_ptrs
=
b_ptr
+
(
off_experts
*
stride_be
+
(
offs_k
[:,
None
]
//
2
)
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
).
to
(
tl
.
int32
)
b_shifter
=
(
offs_k
[:,
None
]
%
2
)
*
4
elif
use_int8_w8a16
:
b_ptrs
=
b_ptr
+
(
off_experts
*
stride_be
+
\
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
).
to
(
tl
.
int32
)
if
USE_ADDR_OFFSET_INT64_B
:
b_ptrs
=
b_ptr
+
(
off_experts
.
to
(
tl
.
int64
)
*
stride_be
+
offs_k
[:,
None
].
to
(
tl
.
int64
)
*
stride_bk
+
offs_bn
[
None
,
:].
to
(
tl
.
int64
)
*
stride_bn
)
else
:
b_ptrs
=
b_ptr
+
(
off_experts
*
stride_be
+
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
).
to
(
tl
.
int32
)
if
not
has_zp
and
use_int4_w4a16
:
b_zp_num
=
8
...
...
@@ -2552,6 +2581,7 @@ def fused_moe(
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
offset_max
=
2
**
31
-
1
use_addr_offset_int64_a
=
A
.
numel
()
*
A
.
element_size
()
>=
offset_max
use_addr_offset_int64_b
=
B
.
numel
()
*
B
.
element_size
()
>=
offset_max
use_addr_offset_int64_c
=
C
.
numel
()
*
C
.
element_size
()
>=
offset_max
if
use_int4_w4a8
:
...
...
@@ -2592,6 +2622,7 @@ def fused_moe(
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
USE_ADDR_OFFSET_INT64_A
=
use_addr_offset_int64_a
,
USE_ADDR_OFFSET_INT64_B
=
use_addr_offset_int64_b
,
USE_ADDR_OFFSET_INT64_C
=
use_addr_offset_int64_c
,
top_k
=
top_k
,
compute_type
=
compute_type
,
...
...
@@ -2636,6 +2667,7 @@ def fused_moe(
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
USE_ADDR_OFFSET_INT64_A
=
use_addr_offset_int64_a
,
USE_ADDR_OFFSET_INT64_B
=
use_addr_offset_int64_b
,
USE_ADDR_OFFSET_INT64_C
=
use_addr_offset_int64_c
,
top_k
=
top_k
,
compute_type
=
compute_type
,
...
...
op_tests/op_benchmarks/triton/bench_extend_attention.py
View file @
f233de81
...
...
@@ -27,6 +27,7 @@ def input_helper(
attn_impl
=
"absorb"
,
equal_seqlens
=
False
,
requires_grad
=
False
,
kv_num_heads
:
int
=
1
,
):
torch
.
manual_seed
(
0
)
...
...
@@ -85,9 +86,9 @@ def input_helper(
total_extend
,
H
,
Lq
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
# extend parts
# extend parts
(``kv_num_heads`` for GQA: e.g. 2 when log shows ``k_extend [T,2,192]``)
k_extend
=
torch
.
randn
(
total_extend
,
1
,
Lk
,
dtype
=
dtype
,
device
=
device
total_extend
,
kv_num_heads
,
Lk
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
v_extend
=
k_extend
[...,
:
Lv
]
...
...
@@ -96,7 +97,7 @@ def input_helper(
# prefix parts
k_buffer
=
torch
.
randn
(
total_prefix
,
1
,
Lk
,
dtype
=
dtype
,
device
=
device
total_prefix
,
kv_num_heads
,
Lk
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
v_buffer
=
k_buffer
[...,
:
Lv
]
...
...
@@ -154,12 +155,20 @@ def extend_forward(
causal
,
sm_scale
=
1.0
,
logit_cap
=
0.0
,
use_v2
:
bool
=
False
,
sliding_window_size
:
int
=
-
1
,
sinks
=
None
,
):
"""Same tensors; v1 uses ``k_scale=v_scale=None``, v2 uses ``1.0`` (sglang-style scalars)."""
out
=
torch
.
empty
(
(
*
q_extend
.
shape
[:
-
1
],
v_extend
.
shape
[
-
1
]),
dtype
=
q_extend
.
dtype
,
device
=
q_extend
.
device
,
)
k_scale
=
v_scale
=
None
if
use_v2
:
k_scale
=
1.0
v_scale
=
1.0
extend_attention
.
extend_attention_fwd
(
q_extend
,
k_extend
,
...
...
@@ -176,6 +185,14 @@ def extend_forward(
max_len_extend
,
sm_scale
=
sm_scale
,
logit_cap
=
logit_cap
,
skip_prefix_custom_mask
=
True
,
config
=
None
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
sliding_window_size
=
sliding_window_size
,
sinks
=
sinks
,
window_kv_offsets
=
None
,
xai_temperature_len
=-
1
,
)
return
out
...
...
@@ -211,13 +228,23 @@ def get_extend_benchmark_configs():
"qk_rope_head_dim"
,
"v_head_dim"
,
"attn_impl"
,
"kv_num_heads"
,
# 与 ``{arch}-EXTEND_ATTENTION-V2-FP16.json`` 的 want7 对齐(见 extend_attention._get_config_v2)
"is_causal"
,
"sliding_window_size"
,
"with_sinks"
,
]
x_vals_list
=
[
(
2
,
16
,
1024
,
1024
,
256
,
0
,
128
,
"non-absorb"
),
(
2
,
16
,
4096
,
4096
,
512
,
64
,
128
,
"non-absorb"
),
(
2
,
16
,
8192
,
4096
,
512
,
64
,
128
,
"non-absorb"
),
(
2
,
16
,
8192
,
4096
,
512
,
64
,
128
,
"absorb"
),
(
2
,
16
,
16324
,
8192
,
512
,
64
,
128
,
"absorb"
),
# (2, 16, 1024, 1024, 256, 0, 128, "non-absorb", 1, False, -1, False),
# (2, 16, 4096, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
# (2, 16, 8192, 4096, 512, 64, 128, "non-absorb", 1, False, -1, False),
# (2, 16, 8192, 4096, 512, 64, 128, "absorb", 1, False, -1, False),
# (2, 16, 16324, 8192, 512, 64, 128, "absorb", 1, False, -1, False),
# log 形状 + 命中 BW200B-EXTEND_ATTENTION-V2-FP16.json 两条 key
(
2
,
16
,
4096
,
555
,
128
,
64
,
128
,
"non-absorb"
,
1
,
True
,
-
1
,
False
),
(
16
,
16
,
556
,
1
,
128
,
64
,
128
,
"non-absorb"
,
1
,
True
,
-
1
,
False
),
(
2
,
16
,
4096
,
1024
,
128
,
64
,
128
,
"non-absorb"
,
2
,
True
,
128
,
True
),
(
16
,
16
,
556
,
1
,
128
,
64
,
128
,
"non-absorb"
,
2
,
True
,
128
,
True
),
]
return
x_names
,
x_vals_list
...
...
@@ -232,13 +259,17 @@ def get_prefill_benchmark_configs():
"qk_rope_head_dim"
,
"v_head_dim"
,
"attn_impl"
,
"kv_num_heads"
,
"is_causal"
,
"sliding_window_size"
,
"with_sinks"
,
]
x_vals_list
=
[
(
2
,
16
,
0
,
1024
,
256
,
0
,
128
,
"non-absorb"
),
(
2
,
16
,
0
,
4096
,
512
,
64
,
128
,
"non-absorb"
),
(
2
,
16
,
0
,
4096
,
512
,
64
,
128
,
"non-absorb"
),
(
2
,
16
,
0
,
4096
,
512
,
64
,
128
,
"absorb"
),
(
2
,
16
,
0
,
8192
,
512
,
64
,
128
,
"absorb"
),
(
2
,
16
,
0
,
1024
,
256
,
0
,
128
,
"non-absorb"
,
1
,
False
,
-
1
,
False
),
(
2
,
16
,
0
,
4096
,
512
,
64
,
128
,
"non-absorb"
,
1
,
False
,
-
1
,
False
),
(
2
,
16
,
0
,
4096
,
512
,
64
,
128
,
"non-absorb"
,
1
,
False
,
-
1
,
False
),
(
2
,
16
,
0
,
4096
,
512
,
64
,
128
,
"absorb"
,
1
,
False
,
-
1
,
False
),
(
2
,
16
,
0
,
8192
,
512
,
64
,
128
,
"absorb"
,
1
,
False
,
-
1
,
False
),
]
return
x_names
,
x_vals_list
...
...
@@ -266,6 +297,10 @@ def model_benchmark_configs(args):
"qk_rope_head_dim"
,
"v_head_dim"
,
"attn_impl"
,
"kv_num_heads"
,
"is_causal"
,
"sliding_window_size"
,
"with_sinks"
,
]
x_vals_list
=
[]
...
...
@@ -276,7 +311,21 @@ def model_benchmark_configs(args):
extend
=
args
.
extend
if
args
.
extend
else
8192
attn_impl
=
args
.
attn_impl
if
args
.
attn_impl
else
"non-absorb"
x_vals_list
.
append
(
(
model_name
,
batch_size
,
HQ
,
prefix
,
extend
,
512
,
64
,
128
,
attn_impl
)
(
model_name
,
batch_size
,
HQ
,
prefix
,
extend
,
512
,
64
,
128
,
attn_impl
,
1
,
False
,
-
1
,
False
,
)
)
return
x_names
,
x_vals_list
...
...
@@ -296,7 +345,22 @@ def benchmark(args):
elif
args
.
mode
==
"prefill"
:
x_names
,
x_vals_list
=
get_prefill_benchmark_configs
()
line_vals
=
[
"extend_attention_fwd"
]
if
args
.
mode
==
"prefill"
:
line_vals
=
[
"context_attention_fwd"
]
line_names
=
[
"prefill/context"
]
styles
=
[(
"blue"
,
"-"
)]
elif
args
.
extend_provider
==
"v1"
:
line_vals
=
[
"extend_v1"
]
line_names
=
[
"v1 (scale None)"
]
styles
=
[(
"red"
,
"-"
)]
elif
args
.
extend_provider
==
"v2"
:
line_vals
=
[
"extend_v2"
]
line_names
=
[
"v2 (k=v=1.0)"
]
styles
=
[(
"green"
,
"-"
)]
else
:
line_vals
=
[
"extend_v1"
,
"extend_v2"
]
line_names
=
[
"v1 (scale None)"
,
"v2 (k=v=1.0)"
]
styles
=
[(
"red"
,
"-"
),
(
"green"
,
"-"
)]
plot_name
=
(
args
.
plot_name
+
f
"-causal-
{
args
.
causal
}
-equal_seqlens-
{
args
.
equal_seqlens
}
"
...
...
@@ -308,8 +372,8 @@ def benchmark(args):
x_vals
=
x_vals_list
,
line_arg
=
"provider"
,
line_vals
=
line_vals
,
line_names
=
line_
val
s
,
styles
=
[(
"red"
,
"-"
),
(
"green"
,
"-"
)]
,
line_names
=
line_
name
s
,
styles
=
styles
,
ylabel
=
"ms"
,
plot_name
=
plot_name
,
args
=
{
"sm_scale"
:
1.0
,
"logit_cap"
:
0.0
,
"device"
:
args
.
device
},
...
...
@@ -317,23 +381,34 @@ def benchmark(args):
)
@
triton
.
testing
.
perf_report
(
configs
)
def
bench_MLA
(
B
,
H
,
prefix
,
extend
,
kv_lora_rank
,
qk_rope_head_dim
,
v_head_dim
,
attn_impl
,
sm_scale
,
logit_cap
,
device
,
provider
=
None
,
model
=
None
,
):
warmup
=
25
rep
=
100
def
bench_MLA
(
**
kwargs
):
# perf_report 调用形如 fn(**x_args, provider=..., **bench.args),全部为关键字参数
warmup
=
5
rep
=
30
provider
=
kwargs
.
pop
(
"provider"
)
sm_scale
=
kwargs
.
pop
(
"sm_scale"
)
logit_cap
=
kwargs
.
pop
(
"logit_cap"
)
device
=
kwargs
.
pop
(
"device"
)
kwargs
.
pop
(
"model"
,
None
)
kv_num_heads
=
int
(
kwargs
.
pop
(
"kv_num_heads"
,
1
))
B
=
kwargs
.
pop
(
"B"
)
H
=
kwargs
.
pop
(
"H"
)
prefix
=
kwargs
.
pop
(
"prefix"
)
extend
=
kwargs
.
pop
(
"extend"
)
kv_lora_rank
=
kwargs
.
pop
(
"kv_lora_rank"
)
qk_rope_head_dim
=
kwargs
.
pop
(
"qk_rope_head_dim"
)
v_head_dim
=
kwargs
.
pop
(
"v_head_dim"
)
attn_impl
=
kwargs
.
pop
(
"attn_impl"
)
row_causal
=
kwargs
.
pop
(
"is_causal"
)
sliding_window_size
=
int
(
kwargs
.
pop
(
"sliding_window_size"
))
with_sinks
=
bool
(
kwargs
.
pop
(
"with_sinks"
))
if
kwargs
:
raise
ValueError
(
f
"unexpected benchmark kwargs:
{
kwargs
}
"
)
sinks_tensor
=
None
if
with_sinks
:
sinks_tensor
=
torch
.
zeros
(
H
,
device
=
device
,
dtype
=
torch
.
float32
)
(
q_extend
,
...
...
@@ -360,11 +435,15 @@ def benchmark(args):
v_head_dim
,
dtype
,
device
,
attn_impl
=
attn_impl
,
equal_seqlens
=
args
.
equal_seqlens
,
kv_num_heads
=
kv_num_heads
,
)
if
provider
==
"extend_attention_fwd"
:
if
provider
in
(
"extend_v1"
,
"extend_v2"
):
use_v2
=
provider
==
"extend_v2"
def
extend_attentio
n
():
def
f
n
():
return
extend_forward
(
q_extend
,
k_extend
,
...
...
@@ -377,33 +456,35 @@ def benchmark(args):
custom_mask
,
mask_indptr
,
max_len_extend
,
args
.
causal
,
row_
causal
,
sm_scale
,
logit_cap
,
use_v2
=
use_v2
,
sliding_window_size
=
sliding_window_size
,
sinks
=
sinks_tensor
,
)
def
context_attention
():
return
extend_forward
(
elif
provider
==
"context_attention_fwd"
:
assert
(
prefix
==
0
),
"Prefix length must be 0 for context attention. Try setting -mode prefill."
def
fn
():
return
prefill_forward
(
q_extend
,
k_extend
,
v_extend
,
B_Start_Loc
,
B_Seqlen
,
max_len_extend
,
args
.
causal
,
row_
causal
,
)
if
provider
==
"extend_attention_fwd"
:
fn
=
extend_attention
elif
provider
==
"context_attention_fwd"
:
assert
(
prefix
==
0
),
"Prefix length must be 0 for context attention. Try setting -mode prefill."
fn
=
context_attention
else
:
raise
ValueError
(
f
"Unknown provider:
{
provider
}
"
)
ms
=
triton
.
testing
.
do_bench
(
fn
,
warmup
=
warmup
,
rep
=
rep
)
ms
=
triton
.
testing
.
do_bench_cudagraph
(
fn
)
# ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
return
ms
...
...
@@ -474,16 +555,16 @@ def parse_args():
default
=
"extend"
,
help
=
"Mode of the benchmark. Options: extend, prefill"
,
)
parser
.
add_argument
(
"-extend_provider"
,
type
=
str
,
default
=
"both"
,
choices
=
(
"both"
,
"v1"
,
"v2"
),
help
=
"Which extend_attention path to benchmark: v1 (k_scale=v_scale=None), v2 (1.0), or both. Ignored when -mode prefill."
,
)
return
parser
.
parse_args
()
arg_to_torch_dtype
=
{
"fp16"
:
torch
.
float16
,
"bf16"
:
torch
.
bfloat16
,
"fp32"
:
torch
.
float32
,
}
def
run_bench
(
args
):
torch
.
manual_seed
(
0
)
torch
.
set_default_device
(
args
.
device
)
...
...
op_tests/test_aiter_moe_with_config_w16a16_nogate.py
0 → 100644
View file @
f233de81
# Test for get_aiter_moe_config and aiter_moe with W16A16 non-gated ReLU²
# (Nemotron-style MOE: N1 = intermediate_size, activation = relu2)
import
torch
import
pandas
as
pd
from
typing
import
Optional
,
List
from
aiter.fused_moe
import
fused_topk
from
aiter
import
dtypes
from
aiter.test_common
import
checkAllclose
,
perftest
from
aiter.moe
import
(
get_aiter_moe_config
,
aiter_moe
,
MoeSolutionType
,
MoeQuantType
,
)
from
aiter.fused_moe_asm_wna16
import
fused_experts_asm_impl
from
aiter.ops.shuffle
import
asm_shuffle_weight_b8
import
aiter
torch
.
set_default_device
(
"cuda"
)
# ---------------------------------------------------------------------------
# Torch reference for non-gated ReLU² MOE
# ---------------------------------------------------------------------------
def
torch_moe_relu2
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
):
"""Reference implementation for non-gated ReLU² MOE.
w1: [E, inter_dim, model_dim] (NOT 2*inter_dim)
w2: [E, model_dim, inter_dim]
"""
computeType
=
torch
.
float32
dtype
=
hidden_states
.
dtype
hidden_states
=
hidden_states
.
to
(
computeType
)
w1
=
w1
.
to
(
computeType
)
w2
=
w2
.
to
(
computeType
)
B
,
D
=
hidden_states
.
shape
topk
=
topk_weights
.
shape
[
1
]
hidden_states
=
hidden_states
.
view
(
B
,
-
1
,
D
).
repeat
(
1
,
topk
,
1
)
out
=
torch
.
zeros
(
(
B
,
topk
,
D
),
dtype
=
computeType
,
device
=
hidden_states
.
device
,
)
for
E_id
in
range
(
w1
.
shape
[
0
]):
mask
=
topk_ids
==
E_id
if
mask
.
sum
():
sub_tokens
=
hidden_states
[
mask
]
# GEMM1
h
=
sub_tokens
@
w1
[
E_id
].
T
# ReLU²
h
=
torch
.
relu
(
h
)
**
2
# GEMM2
out
[
mask
]
=
h
@
w2
[
E_id
].
T
return
(
out
*
topk_weights
.
view
(
B
,
-
1
,
1
)).
sum
(
dim
=
1
).
to
(
dtype
)
# ---------------------------------------------------------------------------
# Weight preparation helpers (W16A16 non-gated)
# ---------------------------------------------------------------------------
def
prepare_w16a16_nogate_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
):
"""Build all tensors needed to run a non-gated w16a16 MOE test.
Key difference from gated: w1 shape is [E, n, k] instead of [E, 2*n, k].
"""
torch
.
manual_seed
(
0
)
input_tensor
=
torch
.
randn
((
m
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
10
w1
=
torch
.
randn
((
e
,
n
,
k
),
device
=
"cuda"
,
dtype
=
dtype
)
/
2
w2
=
torch
.
randn
((
e
,
k
,
n
),
device
=
"cuda"
,
dtype
=
dtype
)
/
2
score
=
torch
.
randn
((
m
,
e
),
device
=
"cuda"
,
dtype
=
dtype
)
w1_shuffle
=
asm_shuffle_weight_b8
(
w1
,
stage
=
1
)
w2_shuffle
=
asm_shuffle_weight_b8
(
w2
,
stage
=
2
)
topk_weights
,
topk_ids
=
fused_topk
(
input_tensor
,
score
,
topk
,
True
)
return
{
"input"
:
input_tensor
,
"w1"
:
w1
,
"w2"
:
w2
,
"w1_shuffle"
:
w1_shuffle
,
"w2_shuffle"
:
w2_shuffle
,
"topk_weights"
:
topk_weights
,
"topk_ids"
:
topk_ids
,
"score"
:
score
,
}
# ---------------------------------------------------------------------------
# Test: get_aiter_moe_config (w16a16 non-gated relu2)
# ---------------------------------------------------------------------------
def
test_get_config
(
m
,
k
,
n
,
e
,
topk
,
dtype
):
"""Test that get_aiter_moe_config returns a valid w16a16 config with
activation='relu2' or gracefully reports no-solution."""
N1
=
n
# non-gated: N1 = intermediate_size (NOT 2 * intermediate_size)
N2
=
k
# down / hidden_size
K
=
k
# model dimension
status
,
moe_cfg
=
get_aiter_moe_config
(
M
=
m
,
E
=
e
,
N1
=
N1
,
N2
=
N2
,
K
=
K
,
top_k
=
topk
,
block_size
=
0
,
dtype
=
dtype
,
quant_type
=
MoeQuantType
.
W16A16
,
activation
=
"relu2"
,
gated
=
False
,
)
if
status
:
assert
moe_cfg
.
solution_type
is
not
None
,
\
"status=True but solution_type is None"
assert
moe_cfg
.
config
is
not
None
,
\
"status=True but config is None"
assert
moe_cfg
.
solution_type
in
(
MoeSolutionType
.
ASM
,
MoeSolutionType
.
TRITON
,
),
f
"Unexpected solution_type:
{
moe_cfg
.
solution_type
}
"
assert
moe_cfg
.
quant_type
==
MoeQuantType
.
W16A16
aiter
.
logger
.
info
(
f
"[get_config_w16a16_nogate]
{
m
=
}
,
{
k
=
}
,
{
n
=
}
,
{
e
=
}
,
{
topk
=
}
, "
f
"solution=
{
moe_cfg
.
solution_type
}
, "
f
"config keys=
{
list
(
moe_cfg
.
config
.
keys
())
}
"
)
else
:
assert
moe_cfg
.
solution_type
is
None
,
\
"status=False but solution_type is not None"
assert
moe_cfg
.
config
is
None
,
\
"status=False but config is not None"
aiter
.
logger
.
info
(
f
"[get_config_w16a16_nogate]
{
m
=
}
,
{
k
=
}
,
{
n
=
}
,
{
e
=
}
,
{
topk
=
}
, "
f
"no solution found (expected on unsupported configs)"
)
return
status
,
moe_cfg
# ---------------------------------------------------------------------------
# Test: aiter_moe end-to-end for w16a16 non-gated relu2
# ---------------------------------------------------------------------------
@
perftest
(
num_warmup
=
1
,
num_iters
=
2
)
def
_run_torch_ref
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
):
return
torch_moe_relu2
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
)
@
perftest
(
num_warmup
=
10
,
num_iters
=
100
,
num_rotate_args
=
1
)
def
_run_aiter_moe_perf
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
moe_config
,
inplace
,
activation
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
global_num_experts
,
expert_map
,
routed_scaling_factor
,
):
return
aiter_moe
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
moe_config
,
inplace
,
activation
,
w1_scale
,
w2_scale
,
w1_zp
,
w2_zp
,
a1_scale
,
a2_scale
,
block_shape
,
global_num_experts
,
expert_map
,
routed_scaling_factor
)
def
test_aiter_moe_w16a16_nogate
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
inplace
,
routed_scaling_factor
):
"""End-to-end: get config -> run aiter_moe with relu2 -> compare with
torch reference."""
N1
=
n
# non-gated: N1 = intermediate_size
N2
=
k
K
=
k
status
,
moe_cfg
=
get_aiter_moe_config
(
M
=
m
,
E
=
e
,
N1
=
N1
,
N2
=
N2
,
K
=
K
,
top_k
=
topk
,
block_size
=
0
,
dtype
=
dtype
,
quant_type
=
MoeQuantType
.
W16A16
,
activation
=
"relu2"
,
gated
=
False
,
)
if
not
status
:
aiter
.
logger
.
info
(
f
"[aiter_moe_w16a16_nogate] SKIP
{
m
=
}
,
{
N1
=
}
,
{
N2
=
}
,
{
K
=
}
,
{
e
=
}
,
{
topk
=
}
: "
f
"no backend available"
)
return
None
backend
=
moe_cfg
.
solution_type
aiter
.
logger
.
info
(
f
"[aiter_moe_w16a16_nogate]
{
m
=
}
,
{
N1
=
}
,
{
N2
=
}
,
{
K
=
}
,
{
e
=
}
,
{
topk
=
}
, "
f
"backend=
{
backend
}
"
)
data
=
prepare_w16a16_nogate_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
# Torch reference
ref_out
,
_
=
_run_torch_ref
(
data
[
"input"
],
data
[
"w1"
],
data
[
"w2"
],
data
[
"topk_weights"
],
data
[
"topk_ids"
],
)
# aiter_moe dispatch with relu2 activation
aiter_us
=
1.0
aiter_out
,
aiter_us
=
_run_aiter_moe_perf
(
hidden_states
=
data
[
"input"
],
w1
=
data
[
"w1"
],
w2
=
data
[
"w2"
],
topk_weights
=
data
[
"topk_weights"
],
topk_ids
=
data
[
"topk_ids"
],
moe_config
=
moe_cfg
,
inplace
=
inplace
,
activation
=
"relu2"
,
w1_scale
=
None
,
w2_scale
=
None
,
w1_zp
=
None
,
w2_zp
=
None
,
a1_scale
=
None
,
a2_scale
=
None
,
block_shape
=
None
,
global_num_experts
=
e
,
expert_map
=
None
,
routed_scaling_factor
=
routed_scaling_factor
,
)
msg
=
(
f
"[aiter_moe_w16a16_nogate]
{
m
=
}
,
{
N1
=
}
,
{
N2
=
}
,
{
K
=
}
,
{
e
=
}
,
{
topk
=
}
, "
f
"backend=
{
backend
}
"
)
checkAllclose
(
ref_out
,
aiter_out
,
rtol
=
0.01
,
atol
=
0.5
,
msg
=
msg
)
return
{
"m"
:
m
,
"backend"
:
backend
,
"us"
:
aiter_us
}
# ---------------------------------------------------------------------------
# Test: aiter_moe w16a16 non-gated ASM shuffle vs non-shuffle
# ---------------------------------------------------------------------------
@
perftest
(
num_warmup
=
10
,
num_iters
=
100
,
num_rotate_args
=
1
)
def
_run_asm_perf
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
dtype
,
global_num_experts
,
expert_map
):
return
fused_experts_asm_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
dtype
,
activation
=
"relu2"
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
)
@
perftest
(
num_warmup
=
10
,
num_iters
=
100
,
num_rotate_args
=
1
)
def
_run_asm_shuffle_perf
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
dtype
,
global_num_experts
,
expert_map
):
return
fused_experts_asm_impl
(
hidden_states
,
w1
,
w2
,
topk_weights
,
topk_ids
,
dtype
,
activation
=
"relu2"
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
use_shuffle
=
1
)
def
test_aiter_moe_w16a16_nogate_shuffle
(
m
,
k
,
n
,
e
,
topk
,
dtype
):
"""Test w16a16 non-gated ASM with shuffled weights vs non-shuffled ASM."""
data
=
prepare_w16a16_nogate_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
try
:
asm_out
,
asm_us
=
_run_asm_perf
(
data
[
"input"
],
data
[
"w1"
],
data
[
"w2"
],
data
[
"topk_weights"
],
data
[
"topk_ids"
],
dtype
,
e
,
None
)
except
Exception
as
exc
:
aiter
.
logger
.
info
(
f
"[w16a16_nogate_shuffle] SKIP
{
m
=
}
: ASM not available (
{
exc
}
)"
)
return
None
shuffle_out
,
shuffle_us
=
_run_asm_shuffle_perf
(
data
[
"input"
],
data
[
"w1_shuffle"
],
data
[
"w2_shuffle"
],
data
[
"topk_weights"
],
data
[
"topk_ids"
],
dtype
,
e
,
None
)
msg
=
(
f
"[w16a16_nogate_shuffle]
{
m
=
}
,
{
k
=
}
,
{
n
=
}
,
{
e
=
}
,
{
topk
=
}
, "
f
"asm_us=
{
asm_us
:.
2
f
}
, shuffle_us=
{
shuffle_us
:.
2
f
}
"
)
checkAllclose
(
asm_out
,
shuffle_out
,
rtol
=
0.01
,
atol
=
0.01
,
msg
=
msg
)
uplift
=
asm_us
/
shuffle_us
-
1
if
shuffle_us
>
0
else
0
return
{
"m"
:
m
,
"asm_us"
:
asm_us
,
"shuffle_us"
:
shuffle_us
,
"shuffle_uplift"
:
f
"
{
uplift
:.
1
%
}
"
,
}
if
__name__
==
"__main__"
:
dtype
=
dtypes
.
bf16
# Nemotron-style MoE parameters (non-gated, ReLU²)
e
=
256
topk
=
8
k
=
3072
# model_dim / hidden_size
n
=
128
# intermediate_size (NOT multiplied by 2)
inplace
=
False
routed_scaling_factor
=
1.0
# --- Part 1: test get_aiter_moe_config (w16a16 non-gated relu2) ---
aiter
.
logger
.
info
(
"="
*
60
)
aiter
.
logger
.
info
(
"Part 1: Testing get_aiter_moe_config for w16a16 non-gated relu2"
)
aiter
.
logger
.
info
(
"="
*
60
)
test_tokens
=
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]
for
m
in
test_tokens
:
test_get_config
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
# --- Part 2: test aiter_moe end-to-end (w16a16 non-gated relu2) ---
aiter
.
logger
.
info
(
"="
*
60
)
aiter
.
logger
.
info
(
"Part 2: Testing aiter_moe end-to-end for w16a16 non-gated relu2"
)
aiter
.
logger
.
info
(
"="
*
60
)
df
=
[]
for
m
in
test_tokens
:
ret
=
test_aiter_moe_w16a16_nogate
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
inplace
,
routed_scaling_factor
)
if
ret
is
not
None
:
df
.
append
(
ret
)
if
df
:
df
=
pd
.
DataFrame
(
df
)
aiter
.
logger
.
info
(
f
"aiter_moe non-gated relu2 summary:
\n
{
df
}
"
)
# --- Part 3: test ASM shuffle vs non-shuffle (w16a16 non-gated relu2) ---
aiter
.
logger
.
info
(
"="
*
60
)
aiter
.
logger
.
info
(
"Part 3: Testing ASM shuffle vs non-shuffle for w16a16 non-gated relu2"
)
aiter
.
logger
.
info
(
"="
*
60
)
df_shuffle
=
[]
for
m
in
test_tokens
:
ret
=
test_aiter_moe_w16a16_nogate_shuffle
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
if
ret
is
not
None
:
df_shuffle
.
append
(
ret
)
if
df_shuffle
:
df_shuffle
=
pd
.
DataFrame
(
df_shuffle
)
aiter
.
logger
.
info
(
f
"shuffle summary (non-gated relu2):
\n
{
df_shuffle
}
"
)
op_tests/test_aiter_moe_with_config_w8a8_channelwise.py
View file @
f233de81
# Test for get_aiter_moe_config and aiter_moe with w8a8 channel-wise quantization
import
argparse
import
torch
import
pandas
as
pd
...
...
@@ -13,6 +14,7 @@ from aiter.moe import (
MoeQuantType
,
)
from
aiter.ops.shuffle
import
moe_layout_shuffle_gemm1
,
moe_layout_shuffle_gemm2
from
aiter.ops.quant
import
pertoken_quant
import
aiter
...
...
@@ -74,9 +76,11 @@ def _run_aiter_moe_perf(
)
def
prepare_w8a8_channelwise_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
):
def
prepare_w8a8_channelwise_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
quant_type
=
MoeQuantType
.
W8A8
):
"""Prepare channel-wise quantized w8a8 inputs.
For int8 (W8A8): weights quantized to torch.int8, scales = max_val / 127.
For fp8 (FP8_W8A8): weights quantized to float8 via pertoken_quant.
Scale shape: (e, out_dim, 1) — one scale per output channel.
block_shape is None for channel-wise.
"""
...
...
@@ -86,7 +90,12 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
w1_fp
=
torch
.
randn
((
e
,
2
*
n
,
k
),
dtype
=
dtype
,
device
=
"cuda"
)
w2_fp
=
torch
.
randn
((
e
,
k
,
n
),
dtype
=
dtype
,
device
=
"cuda"
)
# Channel-wise quantization: max per output channel (last dim of weight row)
if
quant_type
==
MoeQuantType
.
FP8_W8A8
:
# FP8 channel-wise quantization via pertoken_quant
w1_qweight
,
w1_scales
=
pertoken_quant
(
w1_fp
,
quant_dtype
=
dtypes
.
fp8
)
w2_qweight
,
w2_scales
=
pertoken_quant
(
w2_fp
,
quant_dtype
=
dtypes
.
fp8
)
else
:
# INT8 channel-wise quantization: max per output channel
max_vals_w1
=
torch
.
abs
(
w1_fp
.
to
(
torch
.
float32
)).
max
(
dim
=-
1
,
keepdim
=
True
)[
0
]
max_vals_w1
=
max_vals_w1
.
clamp
(
min
=
1e-5
)
w1_scales
=
max_vals_w1
/
127.0
# (e, 2*n, 1)
...
...
@@ -119,7 +128,7 @@ def prepare_w8a8_channelwise_inputs(m, k, n, e, topk, dtype):
}
def
test_get_config
(
m
,
k
,
n
,
e
,
topk
,
dtype
):
def
test_get_config
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
quant_type
=
MoeQuantType
.
W8A8
):
"""Test get_aiter_moe_config for channel-wise w8a8 (block_size=0)."""
status
,
moe_cfg
=
get_aiter_moe_config
(
M
=
m
,
...
...
@@ -130,11 +139,12 @@ def test_get_config(m, k, n, e, topk, dtype):
top_k
=
topk
,
block_size
=
0
,
dtype
=
dtype
,
quant_type
=
MoeQ
uant
T
ype
.
W8A8
,
quant_type
=
q
uant
_t
ype
,
)
tag
=
f
"get_config_
{
quant_type
}
_cw"
if
status
:
assert
moe_cfg
.
quant_type
==
MoeQ
uant
T
ype
.
W8A8
assert
moe_cfg
.
quant_type
==
q
uant
_t
ype
assert
moe_cfg
.
solution_type
in
(
MoeSolutionType
.
ASM
,
MoeSolutionType
.
MOE_C
,
...
...
@@ -143,19 +153,19 @@ def test_get_config(m, k, n, e, topk, dtype):
)
assert
moe_cfg
.
config
is
not
None
aiter
.
logger
.
info
(
f
"[
get_config_w8a8_cw
]
{
m
=
}
, solution=
{
moe_cfg
.
solution_type
}
, "
f
"[
{
tag
}
]
{
m
=
}
, solution=
{
moe_cfg
.
solution_type
}
, "
f
"config keys=
{
list
(
moe_cfg
.
config
.
keys
())
}
"
)
else
:
assert
moe_cfg
.
solution_type
is
None
assert
moe_cfg
.
config
is
None
aiter
.
logger
.
info
(
f
"[
get_config_w8a8_cw
]
{
m
=
}
, no solution found"
)
aiter
.
logger
.
info
(
f
"[
{
tag
}
]
{
m
=
}
, no solution found"
)
return
status
,
moe_cfg
def
test_aiter_moe_w8a8_channelwise
(
m
,
k
,
n
,
e
,
topk
,
dtype
):
"""End-to-end test of aiter_moe with channel-wise w8a8."""
def
test_aiter_moe_w8a8_channelwise
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
quant_type
=
MoeQuantType
.
W8A8
):
"""End-to-end test of aiter_moe with channel-wise w8a8
(int8 or fp8)
."""
status
,
moe_cfg
=
get_aiter_moe_config
(
M
=
m
,
E
=
e
,
...
...
@@ -165,14 +175,15 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
top_k
=
topk
,
block_size
=
0
,
dtype
=
dtype
,
quant_type
=
MoeQ
uant
T
ype
.
W8A8
,
quant_type
=
q
uant
_t
ype
,
)
tag
=
f
"aiter_moe_
{
quant_type
}
_cw"
if
not
status
:
aiter
.
logger
.
info
(
f
"[
aiter_moe_w8a8_cw
] SKIP
{
m
=
}
: no backend available"
)
aiter
.
logger
.
info
(
f
"[
{
tag
}
] SKIP
{
m
=
}
: no backend available"
)
return
None
data
=
prepare_w8a8_channelwise_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
data
=
prepare_w8a8_channelwise_inputs
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
quant_type
)
# Torch reference uses original fp weights directly (no scales needed)
ref_out
,
_
=
_run_torch_ref
(
...
...
@@ -216,31 +227,46 @@ def test_aiter_moe_w8a8_channelwise(m, k, n, e, topk, dtype):
print
(
"ref_out"
,
ref_out
)
msg
=
f
"[
aiter_moe_w8a8_cw
]
{
m
=
}
, backend=
{
moe_cfg
.
solution_type
}
"
msg
=
f
"[
{
tag
}
]
{
m
=
}
, backend=
{
moe_cfg
.
solution_type
}
"
checkAllclose
(
ref_out
,
aiter_out
,
rtol
=
0.01
,
atol
=
100
,
msg
=
msg
)
return
{
"m"
:
m
,
"backend"
:
moe_cfg
.
solution_type
,
"us"
:
aiter_us
}
return
{
"m"
:
m
,
"quant_type"
:
quant_type
,
"backend"
:
moe_cfg
.
solution_type
,
"us"
:
aiter_us
}
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Test aiter_moe with channel-wise w8a8 quantization"
,
)
parser
.
add_argument
(
"--quant"
,
choices
=
[
"int8"
,
"fp8"
],
default
=
"int8"
,
help
=
"Quantization type: int8 (MoeQuantType.W8A8) or fp8 (MoeQuantType.FP8_W8A8)"
,
)
args
=
parser
.
parse_args
()
quant_type
=
MoeQuantType
.
FP8_W8A8
if
args
.
quant
==
"fp8"
else
MoeQuantType
.
W8A8
# for moe_c backend, it does not support n=320 for now;
# for triton backend, it can run with n=320 in NMZ;
dtype
=
dtypes
.
bf16
e
=
256
topk
=
8
k
=
6144
n
=
256
n
=
320
aiter
.
logger
.
info
(
"="
*
60
)
aiter
.
logger
.
info
(
"Part 1: Testing get_aiter_moe_config for
w8a8
channel-wise"
)
aiter
.
logger
.
info
(
f
"Part 1: Testing get_aiter_moe_config for
{
quant_type
}
channel-wise"
)
aiter
.
logger
.
info
(
"="
*
60
)
test_tokens
=
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
6144
,
8192
,
16384
]
for
m
in
test_tokens
:
test_get_config
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
test_get_config
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
quant_type
)
aiter
.
logger
.
info
(
"="
*
60
)
aiter
.
logger
.
info
(
"Part 2: Testing aiter_moe end-to-end for
w8a8
channel-wise"
)
aiter
.
logger
.
info
(
f
"Part 2: Testing aiter_moe end-to-end for
{
quant_type
}
channel-wise"
)
aiter
.
logger
.
info
(
"="
*
60
)
df
=
[]
for
m
in
test_tokens
:
ret
=
test_aiter_moe_w8a8_channelwise
(
m
,
k
,
n
,
e
,
topk
,
dtype
)
ret
=
test_aiter_moe_w8a8_channelwise
(
m
,
k
,
n
,
e
,
topk
,
dtype
,
quant_type
)
if
ret
is
not
None
:
df
.
append
(
ret
)
if
df
:
...
...
op_tests/triton_autotune/fused_moe/autotune_patches.py
View file @
f233de81
...
...
@@ -514,6 +514,7 @@ def fused_moe(
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
offset_max
=
2
**
31
-
1
use_addr_offset_int64_a
=
A
.
numel
()
*
A
.
element_size
()
>=
offset_max
use_addr_offset_int64_b
=
B
.
numel
()
*
B
.
element_size
()
>=
offset_max
use_addr_offset_int64_c
=
C
.
numel
()
*
C
.
element_size
()
>=
offset_max
if
use_int4_w4a8
:
...
...
@@ -554,6 +555,7 @@ def fused_moe(
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
USE_ADDR_OFFSET_INT64_A
=
use_addr_offset_int64_a
,
USE_ADDR_OFFSET_INT64_B
=
use_addr_offset_int64_b
,
USE_ADDR_OFFSET_INT64_C
=
use_addr_offset_int64_c
,
top_k
=
top_k
,
compute_type
=
compute_type
,
...
...
@@ -599,6 +601,7 @@ def fused_moe(
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
USE_ADDR_OFFSET_INT64_A
=
use_addr_offset_int64_a
,
USE_ADDR_OFFSET_INT64_B
=
use_addr_offset_int64_b
,
USE_ADDR_OFFSET_INT64_C
=
use_addr_offset_int64_c
,
top_k
=
top_k
,
compute_type
=
compute_type
,
...
...
op_tests/triton_autotune/tune_extend_attention.py
View file @
f233de81
import
os
import
sys
os
.
environ
[
"AMDGCN_USE_BUFFER_OPS"
]
=
"1"
# :class:`Hcutuner` reads ``TRITON_HCUTUNE_PERF_MODE`` in ``__init__``. Module-level
# ``fn`` / ``fn_v2 = triton.utils.hcutune(...)`` runs at import time, *before*
# ``if __name__ == "__main__"``, so ``--perf`` must be applied here (or set in the shell).
if
"--perf"
in
sys
.
argv
:
os
.
environ
[
"TRITON_HCUTUNE_PERF_MODE"
]
=
"1"
# GPU / ROCm tuning: run **inside** the ``zww_tl_1`` container (not the bare host), e.g.:
# docker exec zww_tl_1 bash -lc 'cd /data/zhouweiwang/aiter/op_tests/triton_autotune && python tune_extend_attention.py --perf'
# ``do_bench`` timing during tuning (smaller => faster iteration; raise for stable prod numbers).
TUNE_DO_BENCH_WARMUP
=
5
TUNE_DO_BENCH_REP
=
20
import
json
import
torch
...
...
@@ -8,10 +22,27 @@ import random
import
itertools
import
argparse
from
aiter.ops.triton.extend_attention
import
_fwd_kernel
from
aiter.ops.triton.extend_attention
import
_fwd_kernel
,
_fwd_kernel_v2
_is_hip
=
True
# hcutune key for :func:`_fwd_kernel_v2`. JSON block-size lookup uses :func:`_get_config_v2`
# ``want7`` only; these names add kernel constexprs (``SKIP_PREFIX_CUSTOM_MASK``,
# ``xai_temperature_len``) used for autotune but not in the V2 JSON key.
# Log alignment (e.g. ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` ~2291): ``kv_group_num = q.size(-2)//k.size(-2)``,
# ``Lq``/``Lv`` last dims, ``USE_CUSTOM_MASK = custom_mask is not None``, ``HAS_SINK = sinks is not None``.
HCUTUNE_KEY_V2
=
[
"kv_group_num"
,
"Lq"
,
"Lv"
,
"USE_CUSTOM_MASK"
,
"IS_CAUSAL"
,
"SKIP_PREFIX_CUSTOM_MASK"
,
"HAS_SINK"
,
"SLIDING_WINDOW_SIZE"
,
"xai_temperature_len"
,
]
version
=
triton
.
__version__
.
split
(
"."
)
major_version
,
minor_version
=
eval
(
version
[
0
]),
eval
(
version
[
1
])
...
...
@@ -29,6 +60,7 @@ def input_helper(
attn_impl
=
"normal"
,
equal_seqlens
=
False
,
requires_grad
=
False
,
kv_num_heads
:
int
=
1
,
):
torch
.
manual_seed
(
0
)
...
...
@@ -82,9 +114,9 @@ def input_helper(
total_extend
,
H
,
Lq
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
# extend parts
# extend parts
(``kv_num_heads`` for GQA: e.g. 2 when q has 16 heads and kv_group_num is 8)
k_extend
=
torch
.
randn
(
total_extend
,
1
,
Lk
,
dtype
=
dtype
,
device
=
device
total_extend
,
kv_num_heads
,
Lk
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
v_extend
=
k_extend
[...,
:
Lv
]
...
...
@@ -93,7 +125,7 @@ def input_helper(
# prefix parts
k_buffer
=
torch
.
randn
(
total_prefix
,
1
,
Lk
,
dtype
=
dtype
,
device
=
device
total_prefix
,
kv_num_heads
,
Lk
,
dtype
=
dtype
,
device
=
device
).
requires_grad_
(
requires_grad
)
v_buffer
=
k_buffer
[...,
:
Lv
]
...
...
@@ -169,18 +201,42 @@ def generate_configs(config):
return
configs_list
# def get_triton_configs():
# config = {
# "BLOCK_M": [16, 32, 64],
# "BLOCK_N": [16, 32, 64],
# "waves_per_eu": [1],
# "num_warps": [4, 8, 16],
# # "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
# # "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
# "matrix_instr_nonkdim": [16],
# "num_stages": [1, 2, 3],
# "sched_latency": ["none", "mmac5-ds10"],
# "kpack": [1, 2],
# }
# tt_configs = []
# for c in generate_configs(config):
# num_warps = c['num_warps']
# num_stages = c['num_stages']
# del c['num_warps']
# del c['num_stages']
# tt_configs.append(triton.Config(c, num_warps=num_warps, num_stages=num_stages))
# return tt_configs
def
get_triton_configs
():
config
=
{
"BLOCK_M"
:
[
16
,
32
,
64
],
"BLOCK_N"
:
[
16
,
32
,
64
],
"waves_per_eu"
:
[
1
],
"num_warps"
:
[
4
,
8
,
16
],
"num_warps"
:
[
4
,
8
],
# "instruction_sched_variant": ["none", "llvm-iglp-1", "llvm-iglp-8", "local-prefetch"],
# "schedule_hint": ["none", "llvm-iglp-8", "local-prefetch"],
"matrix_instr_nonkdim"
:
[
16
],
"num_stages"
:
[
1
,
2
,
3
],
"num_stages"
:
[
1
,
2
],
"sched_latency"
:
[
"none"
,
"mmac5-ds10"
],
"kpack"
:
[
1
,
2
],
"kpack"
:
[
1
],
}
tt_configs
=
[]
...
...
@@ -193,7 +249,6 @@ def get_triton_configs():
return
tt_configs
def
prune_configs
(
configs
,
nargs
,
**
kwargs
):
def
_prune
(
config
):
c
=
config
.
all_kwargs
()
...
...
@@ -216,8 +271,23 @@ key = [
'SKIP_PREFIX_CUSTOM_MASK'
,
'STORE_TRANSPOSE'
,
]
fn
=
triton
.
utils
.
hcutune
(
configs
=
get_triton_configs
(),
key
=
key
,
perf_debug
=
True
,
prune_configs_by
=
{
"early_config_prune"
:
prune_configs
})(
_fwd_kernel
)
fn
=
triton
.
utils
.
hcutune
(
configs
=
get_triton_configs
(),
key
=
key
,
perf_debug
=
True
,
prune_configs_by
=
{
"early_config_prune"
:
prune_configs
},
warmup
=
TUNE_DO_BENCH_WARMUP
,
rep
=
TUNE_DO_BENCH_REP
,
)(
_fwd_kernel
)
fn_v2
=
triton
.
utils
.
hcutune
(
configs
=
get_triton_configs
(),
key
=
HCUTUNE_KEY_V2
,
perf_debug
=
True
,
prune_configs_by
=
{
"early_config_prune"
:
prune_configs
},
warmup
=
TUNE_DO_BENCH_WARMUP
,
rep
=
TUNE_DO_BENCH_REP
,
)(
_fwd_kernel_v2
)
def
extend_attention_fwd
(
...
...
@@ -329,6 +399,139 @@ def extend_attention_fwd(
)
def
extend_attention_fwd_v2
(
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
is_causal
,
mask_indptr
,
max_len_extend
,
sm_scale
,
k_scale
,
v_scale
,
sliding_window_size
,
sinks
,
window_kv_offsets
,
xai_temperature_len
,
skip_prefix_custom_mask
,
logit_cap
,
):
"""Launch :data:`fn_v2` (hcutune-wrapped :func:`_fwd_kernel_v2`); kwargs align with ``extend_attention_fwd`` v2 path."""
Lq
,
Lv
=
q_extend
.
shape
[
-
1
],
v_extend
.
shape
[
-
1
]
if
Lq
==
576
:
BLOCK_DMODEL
,
BLOCK_DPE
=
512
,
64
elif
Lq
==
288
:
BLOCK_DMODEL
,
BLOCK_DPE
=
256
,
32
elif
Lq
==
192
:
BLOCK_DMODEL
,
BLOCK_DPE
=
128
,
64
else
:
BLOCK_DMODEL
=
triton
.
next_power_of_2
(
Lq
)
BLOCK_DPE
=
0
BLOCK_DV
=
triton
.
next_power_of_2
(
Lv
)
sm_scale
=
sm_scale
or
1.0
/
(
Lq
**
0.5
)
batch_size
,
head_num
=
qo_indptr
.
shape
[
0
]
-
1
,
q_extend
.
shape
[
1
]
kv_group_num
=
q_extend
.
shape
[
1
]
//
k_extend
.
shape
[
1
]
USE_CUSTOM_MASK
=
custom_mask
is
not
None
SKIP_PREFIX_CUSTOM_MASK
=
skip_prefix_custom_mask
HAS_SINK
=
sinks
is
not
None
if
not
USE_CUSTOM_MASK
:
custom_mask
=
torch
.
tensor
([
0
],
dtype
=
torch
.
bool
,
device
=
q_extend
.
device
)
mask_indptr
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
q_extend
.
device
)
grid
=
lambda
META
:
(
batch_size
,
head_num
,
triton
.
cdiv
(
max_len_extend
,
META
[
"BLOCK_M"
]),
)
stride_args
=
(
q_extend
.
stride
(
0
),
q_extend
.
stride
(
1
),
k_extend
.
stride
(
0
),
k_extend
.
stride
(
1
),
v_extend
.
stride
(
0
),
v_extend
.
stride
(
1
),
o_extend
.
stride
(
0
),
o_extend
.
stride
(
1
),
k_buffer
.
stride
(
0
),
k_buffer
.
stride
(
1
),
v_buffer
.
stride
(
0
),
v_buffer
.
stride
(
1
),
)
fn_v2
[
grid
](
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sinks
,
window_kv_offsets
,
sm_scale
,
k_scale
,
v_scale
,
kv_group_num
,
*
stride_args
,
SLIDING_WINDOW_SIZE
=
sliding_window_size
,
logit_cap
=
logit_cap
,
xai_temperature_len
=
xai_temperature_len
,
HAS_SINK
=
HAS_SINK
,
BLOCK_DMODEL
=
BLOCK_DMODEL
,
BLOCK_DPE
=
BLOCK_DPE
,
BLOCK_DV
=
BLOCK_DV
,
Lq
=
Lq
,
Lv
=
Lv
,
USE_CUSTOM_MASK
=
USE_CUSTOM_MASK
,
IS_CAUSAL
=
is_causal
,
SKIP_PREFIX_CUSTOM_MASK
=
SKIP_PREFIX_CUSTOM_MASK
,
STORE_TRANSPOSE
=
True
,
)
def
get_bench_inputs_v2
():
"""Cases for :func:`_fwd_kernel_v2` / ``_get_config_v2`` keys.
对 ``fp8_dp2_tp8_415_triton_rocm_nomtp.log`` 逐条核对后,与 extend 相关的**键控组合**只有两类:
- **MHA**(``k_extend[...,1,...]``):``sliding_window_size=-1``,``sinks=None``(如 log 中 ``k_extend [1152,1,192]`` / ``[3952,1,192]`` 段)。
- **GQA**(``k_extend[...,2,...]``):``sliding_window_size=128``,``sinks shape [16]``(全文件未出现 GQA 与 ``-1``/无 sinks 的同框记录)。
该 log 中 **未出现** ``sliding_window_size=64``,也未出现「GQA + 无 SWA + 无 sinks」;若需覆盖其它模型再单独加行并注明来源。
"""
names
=
[
"B"
,
"H"
,
"prefix"
,
"extend"
,
"kv_lora_rank"
,
"qk_rope_head_dim"
,
"v_head_dim"
,
"causal"
,
"custom_mask"
,
"sliding_window_size"
,
"has_sink"
,
"kv_num_heads"
,
]
vals
=
[
# (prefix, extend) 只影响访存/grid;want7 由 head/dim/SWA/sinks 决定。prefix=8192、extend=1024 与长 KV bench 习惯一致。
# (1) GQA Q16/KV2:与 log 中 ``[3952,2,192]`` + ``sliding_window_size=128`` + ``sinks [16]`` 一致;want7 (8,192,128,F,T,T,128)
(
4
,
16
,
8192
,
1024
,
128
,
64
,
128
,
True
,
False
,
128
,
True
,
2
),
# (2) MHA Q16/KV1:与 log 中 ``[...,1,192]`` + ``sliding_window_size=-1`` + ``sinks=None`` 一致;want7 (16,192,128,F,T,F,-1)
(
4
,
16
,
8192
,
1024
,
128
,
64
,
128
,
True
,
False
,
-
1
,
False
,
1
),
]
return
names
,
vals
x_names
,
x_vals
=
get_bench_inputs
()
configs
=
[
triton
.
testing
.
Benchmark
(
...
...
@@ -401,14 +604,136 @@ def bench_extend_attention(B, H, prefix, extend, kv_lora_rank, qk_rope_head_dim,
sm_scale
=
sm_scale
,
logit_cap
=
logit_cap
,
)
return
triton
.
testing
.
do_bench
(
fn
)
return
triton
.
testing
.
do_bench_cudagraph
(
fn
)
x_names_v2
,
x_vals_v2
=
get_bench_inputs_v2
()
configs_v2
=
[
triton
.
testing
.
Benchmark
(
x_names
=
x_names_v2
,
x_vals
=
x_vals_v2
,
line_arg
=
"provider"
,
line_vals
=
[
"triton_v2"
],
line_names
=
[
"triton_v2"
],
styles
=
[(
"blue"
,
"-"
)],
ylabel
=
"ms"
,
plot_name
=
"extend_attention_v2_hcutune"
,
args
=
{
"dtype"
:
torch
.
bfloat16
},
)
]
@
triton
.
utils
.
dist_perf_report
(
configs_v2
)
def
bench_extend_attention_v2
(
B
,
H
,
prefix
,
extend
,
kv_lora_rank
,
qk_rope_head_dim
,
v_head_dim
,
causal
,
custom_mask
,
sliding_window_size
,
has_sink
,
kv_num_heads
,
provider
,
dtype
,
):
torch
.
manual_seed
(
42
)
device
=
"cpu"
if
os
.
getenv
(
"TRITON_HCUTUNE_COMPILE_ONLY"
,
""
)
==
"1"
else
"cuda"
ref_attn_impl
=
"normal"
logit_cap
=
0.0
k_scale
=
1.0
v_scale
=
1.0
xai_temperature_len
=
-
1
skip_prefix_custom_mask
=
True
(
q_extend
,
k_extend
,
v_extend
,
k_buffer
,
v_buffer
,
kv_indptr
,
kv_indices
,
qo_indptr
,
custom_mask_t
,
mask_indptr
,
max_len_extend
,
)
=
input_helper
(
B
,
H
,
prefix
,
extend
,
kv_lora_rank
,
qk_rope_head_dim
,
v_head_dim
,
dtype
,
device
,
ref_attn_impl
,
equal_seqlens
=
True
,
kv_num_heads
=
kv_num_heads
,
)
if
custom_mask
:
raise
NotImplementedError
(
"tune v2 with custom_mask requires mask tensors; use custom_mask=False for hcutune key matching"
)
sm_scale
=
float
(
1.0
/
(
q_extend
.
shape
[
-
1
]
**
0.5
))
sinks
=
(
torch
.
randn
(
H
,
dtype
=
q_extend
.
dtype
,
device
=
device
)
if
has_sink
else
None
)
window_kv_offsets
=
None
tri_out
=
torch
.
empty
(
(
*
q_extend
.
shape
[:
-
1
],
v_extend
.
shape
[
-
1
]),
dtype
=
q_extend
.
dtype
,
device
=
q_extend
.
device
,
)
def
run_once
():
extend_attention_fwd_v2
(
q_extend
,
k_extend
,
v_extend
,
tri_out
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask_t
,
causal
,
mask_indptr
,
max_len_extend
,
sm_scale
,
k_scale
,
v_scale
,
sliding_window_size
,
sinks
,
window_kv_offsets
,
xai_temperature_len
,
skip_prefix_custom_mask
,
logit_cap
,
)
return
triton
.
testing
.
do_bench_cudagraph
(
run_once
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--perf"
,
action
=
'store_true'
,
default
=
False
,
help
=
'benchmark with hcutuner perf mode'
)
parser
.
add_argument
(
"--perf"
,
action
=
"store_true"
,
default
=
False
,
help
=
"benchmark with hcutuner perf mode"
)
parser
.
add_argument
(
"--v1"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Tune v1 ``_fwd_kernel`` only. Default: v2 ``_fwd_kernel_v2`` (JSON: ``_get_config_v2`` / EXTEND_ATTENTION-V2-FP16)."
,
)
return
parser
.
parse_args
()
...
...
@@ -417,6 +742,11 @@ if __name__ == "__main__":
args
=
parse_args
()
if
args
.
perf
:
os
.
environ
[
"TRITON_HCUTUNE_PERF_MODE"
]
=
"1"
os
.
environ
[
"TRITON_HCUTUNE_PERF_MODE"
]
=
"1"
# idempotent; real set is at top for Hcutuner init
bench_extend_attention
.
run
(
print_data
=
True
,
save_path
=
'./tune_extend_attention_out'
)
if
args
.
v1
:
bench_extend_attention
.
run
(
print_data
=
True
,
save_path
=
"./tune_extend_attention_out"
)
else
:
bench_extend_attention_v2
.
run
(
print_data
=
True
,
save_path
=
"./tune_extend_attention_v2_out"
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment