Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
aiter
Commits
bb596f6e
Commit
bb596f6e
authored
Jun 04, 2026
by
xiaowei.zhang
Browse files
1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4 Add new
ops.
parent
d9ebb683
Changes
232
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5662 additions
and
333 deletions
+5662
-333
aiter/ops/triton/configs/moe/E=256,N=128,device_name=K100_AI,dtype=int8_w8a8,is_bottom=True,block_shape=[128,128].json
...dtype=int8_w8a8,is_bottom=True,block_shape=[128,128].json
+97
-36
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int4_w4a16,is_bottom=True.json
...,device_name=K100_AI,dtype=int4_w4a16,is_bottom=True.json
+129
-58
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int4_w4a16.json
...moe/E=256,N=256,device_name=K100_AI,dtype=int4_w4a16.json
+121
-50
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int8_w8a8,block_shape=[128,128].json
...e_name=K100_AI,dtype=int8_w8a8,block_shape=[128,128].json
+105
-44
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int8_w8a8,is_bottom=True,block_shape=[128,128].json
...dtype=int8_w8a8,is_bottom=True,block_shape=[128,128].json
+96
-35
aiter/ops/triton/configs/moe/E=288,N=160,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
...160,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
+223
-0
aiter/ops/triton/configs/moe/E=288,N=160,device_name=BW200B,dtype=fp8_w8a8.json
...gs/moe/E=288,N=160,device_name=BW200B,dtype=fp8_w8a8.json
+223
-0
aiter/ops/triton/configs/moe/E=288,N=320,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
...320,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
+223
-0
aiter/ops/triton/configs/moe/E=288,N=320,device_name=BW200B,dtype=fp8_w8a8.json
...gs/moe/E=288,N=320,device_name=BW200B,dtype=fp8_w8a8.json
+223
-0
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200,dtype=int4_w4a16,is_bottom=True.json
...24,device_name=BW200,dtype=int4_w4a16,is_bottom=True.json
+223
-0
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200,dtype=int4_w4a16.json
.../moe/E=384,N=1024,device_name=BW200,dtype=int4_w4a16.json
+223
-0
aiter/ops/triton/configs/moe/E=512,N=336,device_name=BW200B,is_bottom=True.json
...gs/moe/E=512,N=336,device_name=BW200B,is_bottom=True.json
+223
-0
aiter/ops/triton/configs/moe/E=512,N=336,device_name=BW200B.json
...ps/triton/configs/moe/E=512,N=336,device_name=BW200B.json
+223
-0
aiter/ops/triton/extend_attention.py
aiter/ops/triton/extend_attention.py
+1138
-110
aiter/ops/triton/fla/fused_recurrent.py
aiter/ops/triton/fla/fused_recurrent.py
+272
-0
aiter/ops/triton/fla/fused_sigmoid_gating.py
aiter/ops/triton/fla/fused_sigmoid_gating.py
+351
-0
aiter/ops/triton/fla/fused_sigmoid_gating_recurrent.py
aiter/ops/triton/fla/fused_sigmoid_gating_recurrent.py
+434
-0
aiter/ops/triton/fla/fused_sigmoid_gating_recurrent_ref.py
aiter/ops/triton/fla/fused_sigmoid_gating_recurrent_ref.py
+353
-0
aiter/ops/triton/fla/sglang/chunk_delta_h.py
aiter/ops/triton/fla/sglang/chunk_delta_h.py
+489
-0
aiter/ops/triton/fla/sglang/chunk_o.py
aiter/ops/triton/fla/sglang/chunk_o.py
+293
-0
No files found.
Too many changes to show.
To preserve performance only
232 of 232+
files are displayed.
Plain diff
Email patch
aiter/ops/triton/configs/moe/E=256,N=128,device_name=K100_AI,dtype=int8_w8a8,is_bottom=True,block_shape=[128,128].json
View file @
bb596f6e
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
2
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
...
@@ -35,8 +44,11 @@
...
@@ -35,8 +44,11 @@
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"16"
:
{
"16"
:
{
...
@@ -45,89 +57,116 @@
...
@@ -45,89 +57,116 @@
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
2
,
"sched_latency"
:
"none"
,
"num_stages"
:
1
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
...
@@ -135,28 +174,50 @@
...
@@ -135,28 +174,50 @@
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"8192"
:
{
"8192"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"16384"
:
{
"16384"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
2
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
}
}
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int4_w4a16,is_bottom=True.json
View file @
bb596f6e
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
2
,
"instruction_sched_variant"
:
"none"
,
"num_stages"
:
2
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
4
,
"sched_latency"
:
"none"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
4
,
"sched_latency"
:
"none"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
},
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
1
6
,
"BLOCK_SIZE_K"
:
6
4
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
1
6
,
"BLOCK_SIZE_K"
:
6
4
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"8192"
:
{
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
16
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
}
}
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int4_w4a16.json
View file @
bb596f6e
...
@@ -2,151 +2,222 @@
...
@@ -2,151 +2,222 @@
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
2
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
1
"num_stages"
:
2
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
1
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
6
4
,
"BLOCK_SIZE_K"
:
25
6
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"
1024
"
:
{
"
2048
"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"
2048
"
:
{
"
4096
"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
16
,
"sched_latency"
:
"mmac5-ds10"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
},
"
4096
"
:
{
"
8192
"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
16
,
"sched_latency"
:
"mmac5-ds10"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
},
"
8192
"
:
{
"
16384
"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
16
,
"sched_latency"
:
"none"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int8_w8a8,block_shape=[128,128].json
View file @
bb596f6e
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
true
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
2
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
"num_stages"
:
2
},
},
...
@@ -143,20 +185,39 @@
...
@@ -143,20 +185,39 @@
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"mmac5-ds10"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
},
"16384"
:
{
"16384"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
true
,
"COMBINE_SCALE_LOAD"
:
true
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=256,N=256,device_name=K100_AI,dtype=int8_w8a8,is_bottom=True,block_shape=[128,128].json
View file @
bb596f6e
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
4
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
4
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"8"
:
{
"8"
:
{
...
@@ -35,78 +44,102 @@
...
@@ -35,78 +44,102 @@
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
8
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"num_stages"
:
2
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
2
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"instruction_sched_variant"
:
"local-prefetch"
,
"num_warps"
:
8
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
4
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"1024"
:
{
"1024"
:
{
...
@@ -115,18 +148,24 @@
...
@@ -115,18 +148,24 @@
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
8
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"num_warps"
:
2
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
"num_stages"
:
2
},
},
"4096"
:
{
"4096"
:
{
...
@@ -135,7 +174,10 @@
...
@@ -135,7 +174,10 @@
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
"num_stages"
:
2
},
},
...
@@ -145,18 +187,37 @@
...
@@ -145,18 +187,37 @@
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
1
"num_stages"
:
1
},
},
"16384"
:
{
"16384"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"COMBINE_SCALE_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"USE_MLS_LOAD"
:
false
,
"num_warps"
:
2
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32768"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
1
}
}
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=288,N=160,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=288,N=160,device_name=BW200B,dtype=fp8_w8a8.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=288,N=320,device_name=BW200B,dtype=fp8_w8a8,is_bottom=True.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
512
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=288,N=320,device_name=BW200B,dtype=fp8_w8a8.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
true
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200,dtype=int4_w4a16,is_bottom=True.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"8192"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"16384"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"32768"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=384,N=1024,device_name=BW200,dtype=int4_w4a16.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"8192"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"16384"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
},
"32768"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"mmac5-ds10"
,
"kpack"
:
1
,
"num_warps"
:
16
,
"num_stages"
:
2
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=512,N=336,device_name=BW200B,is_bottom=True.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
8
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/configs/moe/E=512,N=336,device_name=BW200B.json
0 → 100644
View file @
bb596f6e
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"local-prefetch"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
2
,
"num_stages"
:
1
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
8
,
"num_stages"
:
1
},
"16384"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
},
"32768"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"COMBINE_SCALE_LOAD"
:
false
,
"USE_MLS_LOAD"
:
false
,
"instruction_sched_variant"
:
"none"
,
"sched_latency"
:
"none"
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_stages"
:
1
}
}
\ No newline at end of file
aiter/ops/triton/extend_attention.py
View file @
bb596f6e
...
@@ -38,6 +38,8 @@ from aiter.ops.triton.activation import _tanh
...
@@ -38,6 +38,8 @@ from aiter.ops.triton.activation import _tanh
import
aiter.ops.triton.utils.arch_info
as
arch_info
import
aiter.ops.triton.utils.arch_info
as
arch_info
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
from
triton
import
__version__
as
triton_version
triton_minor_version
=
int
(
triton_version
.
split
(
"."
)[
1
])
@
triton
.
jit
@
triton
.
jit
def
_fwd_kernel
(
def
_fwd_kernel
(
...
@@ -348,6 +350,10 @@ def _fwd_kernel_v2(
...
@@ -348,6 +350,10 @@ def _fwd_kernel_v2(
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
HAS_SINK
:
tl
.
constexpr
,
HAS_SINK
:
tl
.
constexpr
,
head_num
:
tl
.
constexpr
,
USE_MLS
:
tl
.
constexpr
,
batch_size
:
tl
.
constexpr
,
# max_len_extend: tl.constexpr,
):
):
cur_seq
=
tl
.
program_id
(
0
)
cur_seq
=
tl
.
program_id
(
0
)
cur_head
=
tl
.
program_id
(
1
)
cur_head
=
tl
.
program_id
(
1
)
...
@@ -357,6 +363,24 @@ def _fwd_kernel_v2(
...
@@ -357,6 +363,24 @@ def _fwd_kernel_v2(
tl
.
assume
(
K_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
K_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
V_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
V_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
kv_group_num
>=
0
)
tl
.
assume
(
stride_qbs
>=
0
)
tl
.
assume
(
stride_qh
>=
0
)
tl
.
assume
(
stride_kbs
>=
0
)
tl
.
assume
(
stride_kh
>=
0
)
tl
.
assume
(
stride_vbs
>=
0
)
tl
.
assume
(
stride_vh
>=
0
)
tl
.
assume
(
stride_obs
>=
0
)
tl
.
assume
(
stride_oh
>=
0
)
tl
.
assume
(
stride_buf_kbs
>=
0
)
tl
.
assume
(
stride_buf_kh
>=
0
)
tl
.
assume
(
stride_buf_vbs
>=
0
)
tl
.
assume
(
stride_buf_vh
>=
0
)
tl
.
assume
(
head_num
>=
0
)
tl
.
assume
(
batch_size
>=
0
)
# tl.assume(max_len_extend >= 0)
kv_head_num
=
head_num
//
kv_group_num
cur_kv_head
=
cur_head
//
kv_group_num
cur_kv_head
=
cur_head
//
kv_group_num
cur_seq_extend_start_idx
=
tl
.
load
(
qo_indptr
+
cur_seq
)
cur_seq_extend_start_idx
=
tl
.
load
(
qo_indptr
+
cur_seq
)
...
@@ -380,6 +404,10 @@ def _fwd_kernel_v2(
...
@@ -380,6 +404,10 @@ def _fwd_kernel_v2(
mask_d
=
offs_d
<
Lq
mask_d
=
offs_d
<
Lq
mask_dv
=
offs_dv
<
Lv
mask_dv
=
offs_dv
<
Lv
ALL_MASK_M
=
tl
.
min
(
mask_m
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
ALL_MASK_D
=
tl
.
min
(
mask_d
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
ALL_MASK_DV
=
tl
.
min
(
mask_dv
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
if
xai_temperature_len
>
0
:
if
xai_temperature_len
>
0
:
offs_qidx
=
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
offs_qidx
=
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
xai_temperature_scale
=
1.0
/
tl
.
log2
(
float
(
xai_temperature_len
))
xai_temperature_scale
=
1.0
/
tl
.
log2
(
float
(
xai_temperature_len
))
...
@@ -389,10 +417,534 @@ def _fwd_kernel_v2(
...
@@ -389,10 +417,534 @@ def _fwd_kernel_v2(
1.0
,
1.0
,
)
)
offs_q
=
(
if
USE_MLS
:
q
=
tl
.
matrix_load
(
Q_Extend
+
cur_head
*
stride_qh
,
shape
=
(
head_num
,
Lq
),
strides
=
(
stride_qbs
,
1
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
offsets
=
((
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
).
to
(
tl
.
int32
),
0
),
)
if
not
(
ALL_MASK_M
&
ALL_MASK_D
):
q
=
tl
.
where
((
mask_m
[:,
None
])
&
(
mask_d
[
None
,
:]),
q
,
0.0
)
else
:
offs_q
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_d
[
None
,
:]
)
q
=
tl
.
load
(
Q_Extend
+
offs_q
,
mask
=
(
mask_m
[:,
None
])
&
(
mask_d
[
None
,
:]),
other
=
0.0
)
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
if
USE_MLS
:
qpe
=
tl
.
matrix_load
(
Q_Extend
+
cur_head
*
stride_qh
,
shape
=
(
head_num
,
Lq
),
strides
=
(
stride_qbs
,
1
),
block_shape
=
(
BLOCK_M
,
BLOCK_DPE
),
offsets
=
((
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
).
to
(
tl
.
int32
),
BLOCK_DMODEL
),
)
if
not
ALL_MASK_M
:
qpe
=
tl
.
where
(
mask_m
[:,
None
],
qpe
,
0.0
)
else
:
offs_qpe
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
stride_qbs
+
cur_head
*
stride_qh
+
offs_dpe
[
None
,
:]
)
qpe
=
tl
.
load
(
Q_Extend
+
offs_qpe
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DV
],
dtype
=
tl
.
float32
)
deno
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
e_max
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
"inf"
)
for
start_n
in
range
(
0
,
cur_seq_len_prefix
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
ALL_MASK_N
=
tl
.
min
(
mask_n
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
if
USE_MLS
:
custom_mask
=
tl
.
matrix_load
(
mask_ptr
+
cur_seq_mask_start_idx
,
shape
=
(
cur_seq_len_extend
,
cur_seq_len
+
window_kv_offset
),
strides
=
(
cur_seq_len
+
window_kv_offset
,
1
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
offsets
=
((
cur_block_m
*
BLOCK_M
).
to
(
tl
.
int32
),
(
window_kv_offset
+
start_n
).
to
(
tl
.
int32
)),
)
if
not
(
ALL_MASK_M
&
ALL_MASK_N
):
custom_mask
=
tl
.
where
((
mask_m
[:,
None
])
&
(
mask_n
[
None
,
:]),
custom_mask
,
0
)
else
:
custom_mask
=
tl
.
load
(
mask_ptr
+
cur_seq_mask_start_idx
+
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
(
cur_seq_len
+
window_kv_offset
)
+
window_kv_offset
+
start_n
+
offs_n
[
None
,
:],
mask
=
(
mask_m
[:,
None
]
&
mask_n
[
None
,
:]),
other
=
0
,
)
final_mask
&=
custom_mask
if
SLIDING_WINDOW_SIZE
>
0
:
window_mask
=
(
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
]
)
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
final_mask
&=
window_mask
SKIP_TILE
=
False
if
(
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
)
or
SLIDING_WINDOW_SIZE
>
0
:
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
if
not
SKIP_TILE
:
offs_kv_loc
=
tl
.
load
(
kv_indices
+
cur_seq_kv_start_idx
+
start_n
+
offs_n
,
mask
=
mask_n
,
other
=
0
,
)
# offs_kv_next = offs_kv_loc[1:] # [1, 2, ..., N-1]
# offs_kv_curr = offs_kv_loc[:-1] # [0, 1, ..., N-2]
# diff = offs_kv_next - offs_kv_curr # 长度 N-1
# is_continuous = tl.all(diff == 1)
# if USE_MLS and is_continuous:
# offs_kv_start_idx = tl.load(
# kv_indices + cur_seq_kv_start_idx + start_n,
# )
# k = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DMODEL, BLOCK_N),
# offsets=(0, offs_kv_start_idx.to(tl.int32)),
# # mask=(mask_m[:, None] & mask_n[None, :]),
# # boundary_check=(0, 1),
# )
# if not (ALL_MASK_N & ALL_MASK_D):
# k = tl.where((mask_n[None, :]) & (mask_d[:, None]), k, 0.0)
# else:
# offs_buf_k = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_d[:, None]
# )
# k = tl.load(
# K_Buffer + offs_buf_k,
# mask=(mask_n[None, :]) & (mask_d[:, None]),
# other=0.0,
# )
offs_buf_k
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Buffer
+
offs_buf_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
,
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
)
if
BLOCK_DPE
>
0
:
# if USE_MLS:
# kpe = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DPE, BLOCK_N),
# offsets=(BLOCK_DMODEL, offs_kv_start_idx.to(tl.int32)),
# # mask=(mask_m[:, None] & mask_n[None, :]),
# # boundary_check=(0, 1),
# )
# if not ALL_MASK_N:
# kpe = tl.where((mask_n[None, :]), kpe, 0.0)
# else:
# offs_kpe = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_dpe[:, None]
# )
# kpe = tl.load(
# K_Buffer + offs_kpe,
# mask=mask_n[None, :],
# other=0.0,
# )
offs_kpe
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Buffer
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
*
k_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
_tanh
(
qk
/
logit_cap
)
if
xai_temperature_len
>
0
:
qk
*=
xai_temperature_reg
[:,
None
]
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
# if USE_MLS:
# v = tl.matrix_load(
# V_Buffer + cur_kv_head * stride_buf_vh,
# shape=(cur_seq_len_prefix.to(tl.int32), Lv),
# strides=(stride_buf_kbs, 1),
# block_shape=(BLOCK_N, BLOCK_DV),
# offsets=(offs_kv_start_idx.to(tl.int32), 0),
# # mask=(mask_m[:, None] & mask_n[None, :]),
# # boundary_check=(0, 1),
# )
# if not (ALL_MASK_N & ALL_MASK_DV):
# v = tl.where((mask_n[:, None] & mask_dv[None, :]), v, 0.0)
# else:
# offs_buf_v = (
# offs_kv_loc[:, None] * stride_buf_vbs
# + cur_kv_head * stride_buf_vh
# + offs_dv[None, :]
# )
# v = tl.load(
# V_Buffer + offs_buf_v,
# mask=mask_n[:, None] & mask_dv[None, :],
# other=0.0,
# )
offs_buf_v
=
(
offs_kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
,
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
*
v_scale
e_max
=
n_e_max
cur_block_m_end
=
(
cur_seq_len_extend
if
not
IS_CAUSAL
else
tl
.
minimum
(
cur_seq_len_extend
,
(
cur_block_m
+
1
)
*
BLOCK_M
)
)
for
start_n
in
range
(
0
,
cur_block_m_end
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
ALL_MASK_N
=
tl
.
min
(
mask_n
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
:
if
USE_MLS
:
custom_mask
=
tl
.
matrix_load
(
mask_ptr
+
cur_seq_mask_start_idx
,
shape
=
(
cur_block_m_end
,
cur_seq_len
+
window_kv_offset
),
strides
=
(
cur_seq_len
+
window_kv_offset
,
1
),
block_shape
=
(
BLOCK_M
,
BLOCK_N
),
offsets
=
((
cur_block_m
*
BLOCK_M
).
to
(
tl
.
int32
),
(
window_kv_offset
+
cur_seq_len_prefix
+
start_n
).
to
(
tl
.
int32
)),
)
if
not
(
ALL_MASK_M
&
ALL_MASK_N
):
custom_mask
=
tl
.
where
((
mask_m
[:,
None
])
&
(
mask_n
[
None
,
:]),
custom_mask
,
0
)
else
:
custom_mask
=
tl
.
load
(
mask_ptr
+
cur_seq_mask_start_idx
+
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
(
cur_seq_len
+
window_kv_offset
)
+
window_kv_offset
+
cur_seq_len_prefix
+
start_n
+
offs_n
[
None
,
:],
mask
=
(
mask_m
[:,
None
]
&
mask_n
[
None
,
:]),
other
=
0
,
)
custom_mask
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
custom_mask
elif
IS_CAUSAL
:
mask_causual
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
>=
(
start_n
+
offs_n
[
None
,
:]
)
mask_causual
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
mask_causual
else
:
mask_non_causal
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
mask_non_causal
if
SLIDING_WINDOW_SIZE
>
0
:
window_mask
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
final_mask
&=
window_mask
SKIP_TILE
=
False
if
USE_CUSTOM_MASK
or
SLIDING_WINDOW_SIZE
>
0
:
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
if
not
SKIP_TILE
:
if
USE_MLS
:
k
=
tl
.
matrix_load
(
K_Extend
+
cur_kv_head
*
stride_kh
,
shape
=
(
kv_head_num
,
Lq
),
strides
=
(
1
,
stride_kbs
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
offsets
=
(
0
,
(
cur_seq_extend_start_idx
+
start_n
).
to
(
tl
.
int32
)),
)
if
not
(
ALL_MASK_N
&
ALL_MASK_D
):
k
=
tl
.
where
((
mask_d
[:,
None
])
&
(
mask_n
[
None
,
:]),
k
,
0.0
)
else
:
offs_k
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_d
[:,
None
]
)
k
=
tl
.
load
(
K_Extend
+
offs_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
,
out_dtype
=
tl
.
float32
)
if
BLOCK_DPE
>
0
:
if
USE_MLS
:
kpe
=
tl
.
matrix_load
(
K_Extend
+
cur_kv_head
*
stride_kh
,
shape
=
(
kv_head_num
,
Lq
),
strides
=
(
1
,
stride_kbs
),
block_shape
=
(
BLOCK_DPE
,
BLOCK_N
),
offsets
=
(
BLOCK_DMODEL
,
(
cur_seq_extend_start_idx
+
start_n
).
to
(
tl
.
int32
)),
)
if
not
ALL_MASK_N
:
kpe
=
tl
.
where
(
mask_n
[
None
,
:],
kpe
,
0.0
)
else
:
offs_kpe
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Extend
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
if
logit_cap
>
0
:
qk
=
logit_cap
*
_tanh
(
qk
/
logit_cap
)
if
xai_temperature_len
>
0
:
qk
*=
xai_temperature_reg
[:,
None
]
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
if
USE_MLS
:
v
=
tl
.
matrix_load
(
V_Extend
+
cur_kv_head
*
stride_vh
,
shape
=
(
kv_head_num
,
Lv
),
strides
=
(
stride_vbs
,
1
),
block_shape
=
(
BLOCK_N
,
BLOCK_DV
),
offsets
=
((
cur_seq_extend_start_idx
+
start_n
).
to
(
tl
.
int32
),
0
),
)
if
not
(
ALL_MASK_N
&
ALL_MASK_DV
):
v
=
tl
.
where
((
mask_n
[:,
None
])
&
(
mask_dv
[
None
,
:]),
v
,
0.0
)
else
:
offs_v
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[:,
None
])
*
stride_vbs
+
cur_kv_head
*
stride_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Extend
+
offs_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
if
HAS_SINK
:
cur_sink
=
tl
.
load
(
sink_ptr
+
cur_head
)
deno
+=
tl
.
exp
(
cur_sink
-
e_max
)
offs_o
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
*
stride_qbs
*
stride_obs
+
cur_head
*
stride_qh
+
cur_head
*
stride_oh
+
offs_dv
[
None
,
:]
)
if
STORE_TRANSPOSE
:
tl
.
store
(
O_Extend
+
offs_o
.
T
,
(
acc
/
deno
[:,
None
]).
T
,
mask
=
(
mask_m
[:,
None
]
&
mask_dv
[
None
,
:]).
T
,
)
else
:
tl
.
store
(
O_Extend
+
offs_o
,
acc
/
deno
[:,
None
],
mask
=
mask_m
[:,
None
]
&
mask_dv
[
None
,
:],
)
@
triton
.
jit
def
_fwd_kernel_v2_decode
(
Q_Extend
,
K_Extend
,
V_Extend
,
O_Extend
,
K_Buffer
,
V_Buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
mask_ptr
,
mask_indptr
,
sink_ptr
,
window_kv_offset_ptr
,
sm_scale
,
k_scale
,
v_scale
,
stride_qbs
,
stride_qh
,
stride_kbs
,
stride_kh
,
stride_vbs
,
stride_vh
,
stride_obs
,
stride_oh
,
stride_buf_kbs
,
stride_buf_kh
,
stride_buf_vbs
,
stride_buf_vh
,
SLIDING_WINDOW_SIZE
:
tl
.
constexpr
,
logit_cap
:
tl
.
constexpr
,
xai_temperature_len
:
tl
.
constexpr
,
Lq
:
tl
.
constexpr
,
Lv
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_DPE
:
tl
.
constexpr
,
BLOCK_DV
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
USE_CUSTOM_MASK
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
SKIP_PREFIX_CUSTOM_MASK
:
tl
.
constexpr
,
STORE_TRANSPOSE
:
tl
.
constexpr
,
HAS_SINK
:
tl
.
constexpr
,
kv_group_num
:
tl
.
constexpr
,
num_query_heads
:
tl
.
constexpr
,
USE_MLS
:
tl
.
constexpr
,
batch_size
:
tl
.
constexpr
,
# max_len_extend: tl.constexpr,
):
"""
v2 decode: grid (batch, num_kv_heads, cdiv(max_len_extend, Q_SEQ)) with
Q_SEQ = BLOCK_M // kv_group_num (same as unified ``BLOCK_Q``; floor, not ceil). If BLOCK_M is not
a multiple of G, adjacent ``cur_block_m`` may overlap in query_pos like unified, but `mask_m`
and sequence bounds keep correctness. BLOCK_M is a power of 2 (host). Require BLOCK_M // G >= 1 to launch.
"""
# Per unified: BLOCK_Q = BLOCK_M // G; stride in query token index for each +1 of cur_block_m.
Q_SEQ
:
tl
.
constexpr
=
BLOCK_M
//
kv_group_num
cur_seq
=
tl
.
program_id
(
0
)
cur_kv_head
=
tl
.
program_id
(
1
)
cur_block_m
=
tl
.
program_id
(
2
)
tl
.
assume
(
Q_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
K_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
V_Extend
.
to
(
tl
.
int64
)
>=
0
)
tl
.
assume
(
stride_qbs
>=
0
)
tl
.
assume
(
stride_qh
>=
0
)
tl
.
assume
(
stride_kbs
>=
0
)
tl
.
assume
(
stride_kh
>=
0
)
tl
.
assume
(
stride_vbs
>=
0
)
tl
.
assume
(
stride_vh
>=
0
)
tl
.
assume
(
stride_obs
>=
0
)
tl
.
assume
(
stride_oh
>=
0
)
tl
.
assume
(
stride_buf_kbs
>=
0
)
tl
.
assume
(
stride_buf_kh
>=
0
)
tl
.
assume
(
stride_buf_vbs
>=
0
)
tl
.
assume
(
stride_buf_vh
>=
0
)
tl
.
assume
(
batch_size
>=
0
)
# tl.assume(max_len_extend >= 0)
kv_head_num
=
num_query_heads
//
kv_group_num
cur_seq_extend_start_idx
=
tl
.
load
(
qo_indptr
+
cur_seq
)
cur_seq_len_extend
=
tl
.
load
(
qo_indptr
+
cur_seq
+
1
)
-
cur_seq_extend_start_idx
cur_seq_kv_start_idx
=
tl
.
load
(
kv_indptr
+
cur_seq
)
cur_seq_len_prefix
=
tl
.
load
(
kv_indptr
+
cur_seq
+
1
)
-
cur_seq_kv_start_idx
cur_seq_len
=
cur_seq_len_prefix
+
cur_seq_len_extend
if
cur_block_m
*
Q_SEQ
>=
cur_seq_len_extend
:
return
if
USE_CUSTOM_MASK
:
cur_seq_mask_start_idx
=
tl
.
load
(
mask_indptr
+
cur_seq
)
window_kv_offset
=
0
if
USE_CUSTOM_MASK
and
SLIDING_WINDOW_SIZE
>
0
:
window_kv_offset
=
tl
.
load
(
window_kv_offset_ptr
+
cur_seq
)
offs_d
=
tl
.
arange
(
0
,
BLOCK_DMODEL
)
offs_dv
=
tl
.
arange
(
0
,
BLOCK_DV
)
offs_m
=
tl
.
arange
(
0
,
BLOCK_M
)
# unified_attention-style: per offs_m, row = offs_m // G, h = offs_m % G (G = kv_group_num)
query_pos
=
cur_block_m
*
Q_SEQ
+
(
offs_m
//
kv_group_num
)
q_head_in_group
=
offs_m
%
kv_group_num
query_offset_0
=
cur_seq_extend_start_idx
+
query_pos
query_offset_1
=
cur_kv_head
*
kv_group_num
+
q_head_in_group
mask_m
=
(
query_pos
<
cur_seq_len_extend
)
&
(
query_offset_1
<
num_query_heads
)
mask_d
=
offs_d
<
Lq
mask_dv
=
offs_dv
<
Lv
ALL_MASK_M
=
tl
.
min
(
mask_m
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
ALL_MASK_D
=
tl
.
min
(
mask_d
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
ALL_MASK_DV
=
tl
.
min
(
mask_dv
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
if
xai_temperature_len
>
0
:
offs_qidx
=
cur_seq_len_prefix
+
query_pos
xai_temperature_scale
=
1.0
/
tl
.
log2
(
float
(
xai_temperature_len
))
xai_temperature_reg
=
tl
.
where
(
offs_qidx
>
xai_temperature_len
,
tl
.
log2
(
offs_qidx
.
to
(
tl
.
float32
))
*
xai_temperature_scale
,
1.0
,
)
offs_q
=
(
query_offset_0
[:,
None
]
*
stride_qbs
+
query_offset_1
[:,
None
]
*
stride_qh
+
offs_d
[
None
,
:]
+
offs_d
[
None
,
:]
)
)
q
=
tl
.
load
(
q
=
tl
.
load
(
...
@@ -402,9 +954,8 @@ def _fwd_kernel_v2(
...
@@ -402,9 +954,8 @@ def _fwd_kernel_v2(
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
offs_dpe
=
BLOCK_DMODEL
+
tl
.
arange
(
0
,
BLOCK_DPE
)
offs_qpe
=
(
offs_qpe
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
query_offset_0
[:,
None
]
*
stride_qbs
*
stride_qbs
+
query_offset_1
[:,
None
]
*
stride_qh
+
cur_head
*
stride_qh
+
offs_dpe
[
None
,
:]
+
offs_dpe
[
None
,
:]
)
)
qpe
=
tl
.
load
(
Q_Extend
+
offs_qpe
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
qpe
=
tl
.
load
(
Q_Extend
+
offs_qpe
,
mask
=
mask_m
[:,
None
],
other
=
0.0
)
...
@@ -418,14 +969,38 @@ def _fwd_kernel_v2(
...
@@ -418,14 +969,38 @@ def _fwd_kernel_v2(
for
start_n
in
range
(
0
,
cur_seq_len_prefix
,
BLOCK_N
):
for
start_n
in
range
(
0
,
cur_seq_len_prefix
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
mask_n
=
(
start_n
+
offs_n
)
<
cur_seq_len_prefix
ALL_MASK_N
=
tl
.
min
(
mask_n
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
if
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
:
# if USE_MLS:
# group_id = offs_m // kv_group_num
# custom_mask_group = tl.matrix_load(
# mask_ptr + cur_seq_mask_start_idx,
# shape=(cur_seq_len_prefix, cur_seq_len + window_kv_offset),
# strides=(cur_seq_len + window_kv_offset, 1),
# block_shape=(BLOCK_M // kv_group_num, BLOCK_N),
# offsets=((cur_block_m * Q_SEQ).to(tl.int32),
# (window_kv_offset + start_n).to(tl.int32)),
# )
# custom_mask = custom_mask_group[group_id[:, None], offs_n[None, :]]
# if not (ALL_MASK_M & ALL_MASK_N):
# custom_mask = tl.where((mask_m[:, None] & mask_n[None, :]), custom_mask, 0)
# else:
# custom_mask = tl.load(
# mask_ptr
# + cur_seq_mask_start_idx
# + (query_pos[:, None]) * (cur_seq_len + window_kv_offset)
# + window_kv_offset
# + start_n
# + offs_n[None, :],
# mask=(mask_m[:, None] & mask_n[None, :]),
# other=0,
# )
custom_mask
=
tl
.
load
(
custom_mask
=
tl
.
load
(
mask_ptr
mask_ptr
+
cur_seq_mask_start_idx
+
cur_seq_mask_start_idx
+
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
+
(
query_pos
[:,
None
])
*
(
cur_seq_len
+
window_kv_offset
)
*
(
cur_seq_len
+
window_kv_offset
)
+
window_kv_offset
+
window_kv_offset
+
start_n
+
start_n
+
offs_n
[
None
,
:],
+
offs_n
[
None
,
:],
...
@@ -435,7 +1010,7 @@ def _fwd_kernel_v2(
...
@@ -435,7 +1010,7 @@ def _fwd_kernel_v2(
final_mask
&=
custom_mask
final_mask
&=
custom_mask
if
SLIDING_WINDOW_SIZE
>
0
:
if
SLIDING_WINDOW_SIZE
>
0
:
window_mask
=
(
window_mask
=
(
cur_seq_len_prefix
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
]
cur_seq_len_prefix
+
query_pos
[:,
None
]
)
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
)
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
final_mask
&=
window_mask
final_mask
&=
window_mask
...
@@ -450,6 +1025,31 @@ def _fwd_kernel_v2(
...
@@ -450,6 +1025,31 @@ def _fwd_kernel_v2(
other
=
0
,
other
=
0
,
)
)
# if USE_MLS:
# offs_kv_start_idx = tl.load(
# kv_indices + cur_seq_kv_start_idx + start_n,
# )
# k = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DMODEL, BLOCK_N),
# offsets=(0, offs_kv_start_idx.to(tl.int32)),
# )
# if not (ALL_MASK_N & ALL_MASK_D):
# k = tl.where((mask_n[None, :]) & (mask_d[:, None]), k, 0.0)
# else:
# offs_buf_k = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_d[:, None]
# )
# k = tl.load(
# K_Buffer + offs_buf_k,
# mask=(mask_n[None, :]) & (mask_d[:, None]),
# other=0.0,
# )
offs_buf_k
=
(
offs_buf_k
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
...
@@ -460,9 +1060,29 @@ def _fwd_kernel_v2(
...
@@ -460,9 +1060,29 @@ def _fwd_kernel_v2(
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
,
other
=
0.0
,
)
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
)
if
BLOCK_DPE
>
0
:
if
BLOCK_DPE
>
0
:
# if USE_MLS:
# kpe = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DPE, BLOCK_N),
# offsets=(BLOCK_DMODEL, offs_kv_start_idx.to(tl.int32)),
# )
# if not ALL_MASK_N:
# kpe = tl.where((mask_n[None, :]), kpe, 0.0)
# else:
# offs_kpe = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_dpe[:, None]
# )
# kpe = tl.load(
# K_Buffer + offs_kpe,
# mask=mask_n[None, :],
# other=0.0,
# )
offs_kpe
=
(
offs_kpe
=
(
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
offs_kv_loc
[
None
,
:]
*
stride_buf_kbs
+
cur_kv_head
*
stride_buf_kh
+
cur_kv_head
*
stride_buf_kh
...
@@ -484,30 +1104,40 @@ def _fwd_kernel_v2(
...
@@ -484,30 +1104,40 @@ def _fwd_kernel_v2(
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
# row_max_fixed avoids exp(-inf - (-inf)) when a row is all -inf in this tile;
row_max
=
tl
.
max
(
qk
,
1
)
# only needed under sliding window or custom mask (plain causal matches v1).
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
if
SLIDING_WINDOW_SIZE
>
0
or
(
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
USE_CUSTOM_MASK
and
not
SKIP_PREFIX_CUSTOM_MASK
):
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
else
:
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
# if USE_MLS:
# v = tl.matrix_load(
# V_Buffer + cur_kv_head * stride_buf_vh,
# shape=(cur_seq_len_prefix.to(tl.int32), Lv),
# strides=(stride_buf_kbs, 1),
# block_shape=(BLOCK_N, BLOCK_DV),
# offsets=(offs_kv_start_idx.to(tl.int32), 0),
# )
# if not (ALL_MASK_N & ALL_MASK_DV):
# v = tl.where((mask_n[:, None] & mask_dv[None, :]), v, 0.0)
# else:
# offs_buf_v = (
# offs_kv_loc[:, None] * stride_buf_vbs
# + cur_kv_head * stride_buf_vh
# + offs_dv[None, :]
# )
# v = tl.load(
# V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
# )
offs_buf_v
=
(
offs_buf_v
=
(
offs_kv_loc
[:,
None
]
*
stride_buf_vbs
offs_kv_loc
[:,
None
]
*
stride_buf_vbs
+
cur_kv_head
*
stride_buf_vh
+
cur_kv_head
*
stride_buf_vh
+
offs_dv
[
None
,
:]
+
offs_dv
[
None
,
:]
)
)
v
=
tl
.
load
(
v
=
tl
.
load
(
V_Buffer
+
offs_buf_v
,
V_Buffer
+
offs_buf_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
,
)
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
*
v_scale
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
*
v_scale
...
@@ -517,19 +1147,44 @@ def _fwd_kernel_v2(
...
@@ -517,19 +1147,44 @@ def _fwd_kernel_v2(
cur_block_m_end
=
(
cur_block_m_end
=
(
cur_seq_len_extend
cur_seq_len_extend
if
not
IS_CAUSAL
if
not
IS_CAUSAL
else
tl
.
minimum
(
cur_seq_len_extend
,
(
cur_block_m
+
1
)
*
BLOCK_M
)
else
tl
.
minimum
(
cur_seq_len_extend
,
(
cur_block_m
+
1
)
*
Q_SEQ
)
)
)
for
start_n
in
range
(
0
,
cur_block_m_end
,
BLOCK_N
):
for
start_n
in
range
(
0
,
cur_block_m_end
,
BLOCK_N
):
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
start_n
=
tl
.
multiple_of
(
start_n
,
BLOCK_N
)
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
mask_n
=
(
start_n
+
offs_n
)
<
cur_block_m_end
ALL_MASK_N
=
tl
.
min
(
mask_n
.
to
(
tl
.
int32
),
axis
=
0
)
==
1
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
if
USE_CUSTOM_MASK
:
if
USE_CUSTOM_MASK
:
# if USE_MLS:
# group_id = offs_m // kv_group_num
# custom_mask_group = tl.matrix_load(
# mask_ptr + cur_seq_mask_start_idx,
# shape=(cur_block_m_end, cur_seq_len + window_kv_offset),
# strides=(cur_seq_len + window_kv_offset, 1),
# block_shape=(BLOCK_M // kv_group_num, BLOCK_N),
# offsets=((cur_block_m * Q_SEQ).to(tl.int32),
# (window_kv_offset + cur_seq_len_prefix + start_n).to(tl.int32)),
# )
# custom_mask = custom_mask_group[group_id[:, None], offs_n[None, :]]
# if not (ALL_MASK_M & ALL_MASK_N):
# custom_mask = tl.where((mask_m[:, None] & mask_n[None, :]), custom_mask, 0)
# else:
# custom_mask = tl.load(
# mask_ptr
# + cur_seq_mask_start_idx
# + (query_pos[:, None]) * (cur_seq_len + window_kv_offset)
# + window_kv_offset
# + cur_seq_len_prefix
# + start_n
# + offs_n[None, :],
# mask=(mask_m[:, None] & mask_n[None, :]),
# other=0,
# )
custom_mask
=
tl
.
load
(
custom_mask
=
tl
.
load
(
mask_ptr
mask_ptr
+
cur_seq_mask_start_idx
+
cur_seq_mask_start_idx
+
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
+
(
query_pos
[:,
None
])
*
(
cur_seq_len
+
window_kv_offset
)
*
(
cur_seq_len
+
window_kv_offset
)
+
window_kv_offset
+
window_kv_offset
+
cur_seq_len_prefix
+
cur_seq_len_prefix
+
start_n
+
start_n
...
@@ -540,9 +1195,7 @@ def _fwd_kernel_v2(
...
@@ -540,9 +1195,7 @@ def _fwd_kernel_v2(
custom_mask
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
custom_mask
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
custom_mask
final_mask
&=
custom_mask
elif
IS_CAUSAL
:
elif
IS_CAUSAL
:
mask_causual
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
>=
(
mask_causual
=
query_pos
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:])
start_n
+
offs_n
[
None
,
:]
)
mask_causual
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
mask_causual
&=
mask_m
[:,
None
]
&
mask_n
[
None
,
:]
final_mask
&=
mask_causual
final_mask
&=
mask_causual
else
:
else
:
...
@@ -550,7 +1203,7 @@ def _fwd_kernel_v2(
...
@@ -550,7 +1203,7 @@ def _fwd_kernel_v2(
final_mask
&=
mask_non_causal
final_mask
&=
mask_non_causal
if
SLIDING_WINDOW_SIZE
>
0
:
if
SLIDING_WINDOW_SIZE
>
0
:
window_mask
=
(
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
]
)
<=
(
window_mask
=
query_pos
[:,
None
]
<=
(
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
start_n
+
offs_n
[
None
,
:]
+
SLIDING_WINDOW_SIZE
)
)
final_mask
&=
window_mask
final_mask
&=
window_mask
...
@@ -560,27 +1213,49 @@ def _fwd_kernel_v2(
...
@@ -560,27 +1213,49 @@ def _fwd_kernel_v2(
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
SKIP_TILE
=
tl
.
max
(
tl
.
max
(
final_mask
.
to
(
tl
.
int32
),
axis
=
1
),
axis
=
0
)
==
0
if
not
SKIP_TILE
:
if
not
SKIP_TILE
:
offs_k
=
(
if
USE_MLS
:
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
k
=
tl
.
matrix_load
(
+
cur_kv_head
*
stride_kh
K_Extend
+
cur_kv_head
*
stride_kh
,
+
offs_d
[:,
None
]
shape
=
(
kv_head_num
,
Lq
),
)
strides
=
(
1
,
stride_kbs
),
k
=
tl
.
load
(
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
K_Extend
+
offs_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
offsets
=
(
0
,
(
cur_seq_extend_start_idx
+
start_n
).
to
(
tl
.
int32
)),
)
)
if
not
(
ALL_MASK_D
&
ALL_MASK_N
):
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
,
out_dtype
=
tl
.
float32
)
k
=
tl
.
where
((
mask_d
[:,
None
]
&
(
mask_n
[
None
,
:])),
k
,
0.0
)
if
BLOCK_DPE
>
0
:
else
:
offs_k
pe
=
(
offs_k
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
cur_kv_head
*
stride_kh
+
offs_d
pe
[:,
None
]
+
offs_d
[:,
None
]
)
)
kpe
=
tl
.
load
(
k
=
tl
.
load
(
K_Extend
+
offs_kpe
,
K_Extend
+
offs_k
,
mask
=
(
mask_n
[
None
,
:])
&
(
mask_d
[:,
None
]),
other
=
0.0
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
)
qk
=
tl
.
dot
(
q
.
to
(
k
.
dtype
),
k
,
out_dtype
=
tl
.
float32
)
if
BLOCK_DPE
>
0
:
if
USE_MLS
:
kpe
=
tl
.
matrix_load
(
K_Extend
+
cur_kv_head
*
stride_kh
,
shape
=
(
kv_head_num
,
Lq
),
strides
=
(
1
,
stride_kbs
),
block_shape
=
(
BLOCK_DPE
,
BLOCK_N
),
offsets
=
(
BLOCK_DMODEL
,
(
cur_seq_extend_start_idx
+
start_n
).
to
(
tl
.
int32
)),
)
if
not
ALL_MASK_N
:
kpe
=
tl
.
where
(
mask_n
[
None
,
:],
kpe
,
0.0
)
else
:
offs_kpe
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[
None
,
:])
*
stride_kbs
+
cur_kv_head
*
stride_kh
+
offs_dpe
[:,
None
]
)
kpe
=
tl
.
load
(
K_Extend
+
offs_kpe
,
mask
=
mask_n
[
None
,
:],
other
=
0.0
,
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
+=
tl
.
dot
(
qpe
.
to
(
kpe
.
dtype
),
kpe
)
qk
*=
sm_scale
qk
*=
sm_scale
...
@@ -593,38 +1268,49 @@ def _fwd_kernel_v2(
...
@@ -593,38 +1268,49 @@ def _fwd_kernel_v2(
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
qk
=
tl
.
where
(
final_mask
,
qk
,
float
(
"-inf"
))
if
SLIDING_WINDOW_SIZE
>
0
or
USE_CUSTOM_MASK
:
row_max
=
tl
.
max
(
qk
,
1
)
row_max
=
tl
.
max
(
qk
,
1
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
row_max_fixed
=
tl
.
where
(
row_max
==
float
(
"-inf"
),
-
1e20
,
row_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
n_e_max
=
tl
.
maximum
(
row_max_fixed
,
e_max
)
else
:
n_e_max
=
tl
.
maximum
(
tl
.
max
(
qk
,
1
),
e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
re_scale
=
tl
.
exp
(
e_max
-
n_e_max
)
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
p
=
tl
.
exp
(
qk
-
n_e_max
[:,
None
])
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
deno
=
deno
*
re_scale
+
tl
.
sum
(
p
,
1
)
offs_v
=
(
if
USE_MLS
:
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[:,
None
])
*
stride_vbs
v
=
tl
.
matrix_load
(
+
cur_kv_head
*
stride_vh
V_Extend
+
cur_kv_head
*
stride_vh
,
+
offs_dv
[
None
,
:]
shape
=
(
kv_head_num
,
Lv
),
)
strides
=
(
stride_vbs
,
1
),
v
=
tl
.
load
(
block_shape
=
(
BLOCK_N
,
BLOCK_DV
),
V_Extend
+
offs_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
offsets
=
((
cur_seq_extend_start_idx
+
start_n
).
to
(
tl
.
int32
),
0
),
)
)
if
not
(
ALL_MASK_N
&
ALL_MASK_DV
):
v
=
tl
.
where
((
mask_n
[:,
None
]
&
mask_dv
[
None
,
:]),
v
,
0.0
)
else
:
offs_v
=
(
(
cur_seq_extend_start_idx
+
start_n
+
offs_n
[:,
None
])
*
stride_vbs
+
cur_kv_head
*
stride_vh
+
offs_dv
[
None
,
:]
)
v
=
tl
.
load
(
V_Extend
+
offs_v
,
mask
=
mask_n
[:,
None
]
&
mask_dv
[
None
,
:],
other
=
0.0
)
p
=
p
.
to
(
v
.
dtype
)
p
=
p
.
to
(
v
.
dtype
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
acc
=
acc
*
re_scale
[:,
None
]
+
tl
.
dot
(
p
,
v
)
e_max
=
n_e_max
e_max
=
n_e_max
if
HAS_SINK
:
if
HAS_SINK
:
cur_sink
=
tl
.
load
(
sink_ptr
+
cur_head
)
cur_sink
=
tl
.
load
(
sink_ptr
+
cur_kv_head
*
kv_group_num
+
q_head_in_group
,
mask
=
mask_m
,
other
=
0.0
,
)
deno
+=
tl
.
exp
(
cur_sink
-
e_max
)
deno
+=
tl
.
exp
(
cur_sink
-
e_max
)
offs_o
=
(
offs_o
=
(
(
cur_seq_extend_start_idx
+
cur_block_m
*
BLOCK_M
+
offs_m
[:,
None
])
query_offset_0
[:,
None
]
*
stride_obs
*
stride_obs
+
query_offset_1
[:,
None
]
*
stride_oh
+
cur_head
*
stride_oh
+
offs_dv
[
None
,
:]
+
offs_dv
[
None
,
:]
)
)
if
STORE_TRANSPOSE
:
if
STORE_TRANSPOSE
:
...
@@ -700,13 +1386,118 @@ def _load_config_v2():
...
@@ -700,13 +1386,118 @@ def _load_config_v2():
raise
ValueError
(
raise
ValueError
(
f
"
{
dev
}
-EXTEND_ATTENTION-V2-FP16.json keys must be 7-tuples matching runtime "
f
"
{
dev
}
-EXTEND_ATTENTION-V2-FP16.json keys must be 7-tuples matching runtime "
f
"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
f
"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
f
"SLIDING_WINDOW
_SIZE
); got length
{
len
(
tup
)
}
for
{
k
!
r
}
"
f
"
USE_
SLIDING_WINDOW); got length
{
len
(
tup
)
}
for
{
k
!
r
}
"
)
)
res
[
"keys"
].
append
(
tup
)
res
[
"keys"
].
append
(
tup
)
return
res
return
res
def _load_config_v2_decode():
    """Load the tuned block-size table for :func:`_fwd_kernel_v2_decode`.

    The table is read from
    ``{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION-V2-DECODE-FP16.json``
    where ``dev`` comes from ``arch_info.get_device()``.

    Returns:
        A dict with four entries:

        * ``"config"``: the raw ``config`` mapping from the JSON file.
        * ``"path"``: the optional ``path`` mapping (``{}`` when absent).
        * ``"key"``: the ``config`` keys as they appear in the file.
        * ``"keys"``: the same keys parsed by :func:`create_tuple`, in the
          same order as ``"key"``.

        When the file does not exist, every entry is empty.

    Raises:
        ValueError: if a parsed key does not have exactly 7 elements.
    """
    dev = arch_info.get_device()
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION-V2-DECODE-FP16.json"

    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        # No tuned table shipped for this device: hand back an empty table.
        return {"config": {}, "path": {}, "key": [], "keys": []}

    table = {
        "config": data["config"],
        "path": data.get("path", {}),
        "key": list(data["config"].keys()),
        "keys": [],
    }

    # "key" and "keys" stay index-aligned so a match in "keys" can be mapped
    # back to the original string key.
    for raw_key in table["key"]:
        parsed = create_tuple(raw_key)
        if len(parsed) != 7:
            raise ValueError(
                f"{dev}-EXTEND_ATTENTION-V2-DECODE-FP16.json keys must be 7-tuples matching runtime "
                f"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
                f"USE_SLIDING_WINDOW); got length {len(parsed)} for {raw_key!r}"
            )
        table["keys"].append(parsed)

    return table
TORCH_DTYPE_TO_DTYPE
=
{
torch
.
float32
:
"f32"
,
torch
.
float
:
"f32"
,
torch
.
float16
:
"f16"
,
torch
.
half
:
"f16"
,
torch
.
bfloat16
:
"bf16"
,
torch
.
float64
:
"f64"
,
torch
.
double
:
"f64"
,
torch
.
float8_e4m3fn
:
"f8_e4m3fn"
,
torch
.
float8_e5m2
:
"f8_e5m2"
,
torch
.
int8
:
"i8"
,
torch
.
int16
:
"i16"
,
torch
.
int32
:
"i32"
,
torch
.
int64
:
"i64"
,
torch
.
long
:
"i64"
,
torch
.
uint8
:
"u8"
,
torch
.
bool
:
"bool"
,
}
@functools.lru_cache
def get_gpu_label():
    """Return a ``"<arch>_cu<count>"`` label for the current GPU.

    ``<arch>`` is the ``arch`` attribute of Triton's active driver target and
    ``<count>`` is ``multi_processor_count`` of the current torch CUDA device.
    The result is memoized, so the device is only queried on the first call.
    """
    arch = triton.runtime.driver.active.get_current_target().arch
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return f"{arch}_cu{props.multi_processor_count}"
def _load_config_v3():
    """Load the tuned-config table for :func:`_fwd_kernel_v2` (fp8 / sglang-style scale path).

    The table is read from
    ``{AITER_TRITON_CONFIGS_PATH}/extend_attn/_fwd_kernel_v2-device={get_gpu_label()}.json``.

    Returns:
        A dict with the entries:

        * ``"key_name"``: the file's top-level ``key`` entry.
        * ``"fpath"``: the path that was looked up (always a string, also when
          the file is missing, so log messages can name it).
        * ``"config"``: the raw ``config`` mapping from the JSON file.
        * ``"path"``: the optional ``path`` mapping (``{}`` when absent).
        * ``"key"``: the ``config`` keys as they appear in the file.
        * ``"keys"``: the same keys parsed by :func:`create_tuple`, index-aligned
          with ``"key"``.

        When the file does not exist, everything except ``"fpath"`` is empty.

    Note:
        Unlike :func:`_load_config_v2`, the length of the parsed keys is not
        validated here.
    """
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/extend_attn/_fwd_kernel_v2-device={get_gpu_label()}.json"
    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        # Report the real path rather than an empty placeholder: the lookup
        # helpers interpolate "fpath" into their "config not found" warnings.
        return {
            "key_name": {},
            "fpath": fpath,
            "config": {},
            "path": {},
            "key": [],
            "keys": [],
        }

    res = {}
    res["key_name"] = data["key"]
    res["fpath"] = fpath
    res["config"] = data["config"]
    res["path"] = data.get("path", {})
    res["key"] = list(data["config"].keys())
    # Parsed form of every key, in the same order as res["key"].
    res["keys"] = [create_tuple(k) for k in res["key"]]
    return res
def _load_config_v3_decode():
    """Load the tuned-config table for :func:`_fwd_kernel_v2_decode` (short extend path).

    The table is read from
    ``{AITER_TRITON_CONFIGS_PATH}/extend_attn/_fwd_kernel_v2_decode-device={get_gpu_label()}.json``.

    Returns:
        A dict with the entries:

        * ``"key_name"``: the file's top-level ``key`` entry.
        * ``"fpath"``: the path that was looked up (always a string, also when
          the file is missing, so log messages can name it).
        * ``"config"``: the raw ``config`` mapping from the JSON file.
        * ``"path"``: the optional ``path`` mapping (``{}`` when absent).
        * ``"key"``: the ``config`` keys as they appear in the file.
        * ``"keys"``: the same keys parsed by :func:`create_tuple`, index-aligned
          with ``"key"``.

        When the file does not exist, everything except ``"fpath"`` is empty.
    """
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/extend_attn/_fwd_kernel_v2_decode-device={get_gpu_label()}.json"
    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        # Report the real path rather than an empty placeholder: the lookup
        # helpers interpolate "fpath" into their "config not found" warnings.
        return {
            "key_name": {},
            "fpath": fpath,
            "config": {},
            "path": {},
            "key": [],
            "keys": [],
        }

    res = {}
    res["key_name"] = data["key"]
    res["fpath"] = fpath
    res["config"] = data["config"]
    res["path"] = data.get("path", {})
    res["key"] = list(data["config"].keys())
    # Parsed form of every key, in the same order as res["key"].
    res["keys"] = [create_tuple(k) for k in res["key"]]
    return res
global_config_v2
=
_load_config_v2
()
global_config_v2
=
_load_config_v2
()
global_config_v2_decode
=
_load_config_v2_decode
()
global_config_v3
=
_load_config_v3
()
global_config_v3_decode
=
_load_config_v3_decode
()
default_config
=
{
default_config
=
{
"BLOCK_M"
:
32
,
"BLOCK_M"
:
32
,
...
@@ -715,7 +1506,8 @@ default_config = {
...
@@ -715,7 +1506,8 @@ default_config = {
"matrix_instr_nonkdim"
:
16
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
,
"kpack"
:
2
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
1
"num_stages"
:
2
,
"USE_MLS"
:
False
,
}
}
...
@@ -744,20 +1536,21 @@ def _get_config_v2(
...
@@ -744,20 +1536,21 @@ def _get_config_v2(
use_custom_mask
,
use_custom_mask
,
is_causal
,
is_causal
,
has_sink
:
bool
,
has_sink
:
bool
,
sliding_window
_size
:
int
,
use_
sliding_window
:
bool
,
):
):
"""
"""
Lookup order for ``_fwd_kernel_v2`` block sizes:
Lookup order for ``_fwd_kernel_v2`` block sizes:
1. ``want7 = (kv_group_num, Lq, Lv, use_custom_mask, is_causal, has_sink,
sliding_window_size
)``
1. ``want7 = (kv_group_num, Lq, Lv, use_custom_mask, is_causal, has_sink,
USE_SLIDING_WINDOW
)``
against ``{arch}-EXTEND_ATTENTION-V2-FP16.json``. JSON keys must be **7-tuple** strings,
against ``{arch}-EXTEND_ATTENTION-V2-FP16.json``. JSON keys must be **7-tuple** strings,
same shape as ``want7`` (see :func:`_load_config_v2`).
same shape as ``want7`` (see :func:`_load_config_v2`). The last element is a bool:
same tuning bucket for any ``sliding_window_size > 0``; use ``False`` when disabled (``<= 0``).
2. If no V2 entry matches, :data:`default_config` (no fallback to v1 JSON).
2. If no V2 entry matches, :data:`default_config` (no fallback to v1 JSON).
Log field mapping (typical): ``kv_group_num = q_extend.size(-2) // k_extend.size(-2)``,
Log field mapping (typical): ``kv_group_num = q_extend.size(-2) // k_extend.size(-2)``,
``Lq = q_extend.size(-1)``, ``Lv = v_extend.size(-1)``,
``Lq = q_extend.size(-1)``, ``Lv = v_extend.size(-1)``,
``use_custom_mask = custom_mask is not None``, ``is_causal`` as passed,
``use_custom_mask = custom_mask is not None``, ``is_causal`` as passed,
``has_sink = sinks is not None``, ``
sliding_window_size`` as passed (use ``-1`` if disabled)
.
``has_sink = sinks is not None``, ``
USE_SLIDING_WINDOW = (sliding_window_size > 0)``
.
"""
"""
want7
=
(
want7
=
(
kv_group_num
,
kv_group_num
,
...
@@ -766,7 +1559,7 @@ def _get_config_v2(
...
@@ -766,7 +1559,7 @@ def _get_config_v2(
use_custom_mask
,
use_custom_mask
,
is_causal
,
is_causal
,
has_sink
,
has_sink
,
sliding_window
_size
,
use_
sliding_window
,
)
)
for
i
,
keys
in
enumerate
(
global_config_v2
[
"keys"
]):
for
i
,
keys
in
enumerate
(
global_config_v2
[
"keys"
]):
if
keys
==
want7
:
if
keys
==
want7
:
...
@@ -777,10 +1570,137 @@ def _get_config_v2(
...
@@ -777,10 +1570,137 @@ def _get_config_v2(
return
default_config
,
None
return
default_config
,
None
@functools.lru_cache(maxsize=1024)
def _get_config_v2_decode(
    kv_group_num,
    Lq,
    Lv,
    use_custom_mask,
    is_causal,
    has_sink: bool,
    use_sliding_window: bool,
):
    """Look up the tuned config for :func:`_fwd_kernel_v2_decode`.

    The seven arguments form the lookup tuple, in the order
    ``(kv_group_num, Lq, Lv, use_custom_mask, is_causal, has_sink,
    use_sliding_window)``, which is compared against the parsed keys of
    ``global_config_v2_decode``.

    Returns:
        ``(config, path)``: the matching config entry and its optional
        ``path`` entry, or ``(default_config, None)`` when nothing matches
        (a warning is printed in that case).
    """
    want7 = (
        kv_group_num,
        Lq,
        Lv,
        use_custom_mask,
        is_causal,
        has_sink,
        use_sliding_window,
    )

    # Position of the first parsed key equal to the requested tuple, if any.
    hit = next(
        (
            idx
            for idx, candidate in enumerate(global_config_v2_decode["keys"])
            if candidate == want7
        ),
        None,
    )
    if hit is None:
        print("WARNING: optimal V2 decode config not found, just use default config")
        return default_config, None

    # "key" is index-aligned with "keys": recover the original string key.
    key = global_config_v2_decode["key"][hit]
    return (
        global_config_v2_decode["config"][key],
        global_config_v2_decode["path"].get(key),
    )
def find_closest_index(target, lst):
    """Pick the index whose value best covers ``target``.

    Args:
        target: the value to match (a batch size at the call sites here).
        lst: ``[(index, value), ...]`` candidates.

    Returns:
        The ``index`` of the chosen pair. Values greater than or equal to
        ``target`` are always preferred, smallest first; only when no value
        reaches ``target`` is the largest smaller value chosen. Ties go to the
        earliest pair in ``lst``.

    Raises:
        ValueError: if ``lst`` is empty (raised by ``min``).
    """

    def rank(entry):
        # First element: 0 for "at or above target", 1 for "below target".
        # Second element: absolute distance to target.
        distance = entry[1] - target
        return (0, distance) if distance >= 0 else (1, -distance)

    return min(lst, key=rank)[0]
@functools.lru_cache(maxsize=1024)
def _get_config_v3(key):
    """Look up the tuned config for ``key`` in ``global_config_v3``.

    Lookup order:

    1. An entry whose string key equals ``str(key)``.
    2. Otherwise, among entries whose parsed key matches ``key`` on every
       element but the first (the batch size), the one selected by
       :func:`find_closest_index` on the first element; a warning is printed.
    3. Otherwise ``default_config``; a warning is printed.

    Args:
        key: hashable tuple laid out as ``(batch_size, *other)``.
    """
    key_str = str(key)
    table = global_config_v3["config"]

    if key_str in table:
        return table[key_str]

    # Fall back to the nearest tuned batch size.
    # key = (batch_size, *other)
    candidates = [
        (idx, parsed[0])
        for idx, parsed in enumerate(global_config_v3["keys"])
        if parsed[1:] == key[1:]
    ]

    if not candidates:
        print(f'WARNING: Not found optimal config from {global_config_v3["fpath"]} with key {key_str}, just use default config')
        return default_config

    mapped_key = global_config_v3["key"][find_closest_index(key[0], candidates)]
    print(f'WARNING: Not found key {key_str} from {global_config_v3["fpath"]}, mapping to key {mapped_key}')
    return table[mapped_key]
@functools.lru_cache(maxsize=1024)
def _get_config_v3_decode(key):
    """Look up the tuned config for ``key`` in ``global_config_v3_decode``.

    Lookup order:

    1. An entry whose string key equals ``str(key)``.
    2. Otherwise, among entries whose parsed key matches ``key`` on every
       element but the first (the batch size), the one selected by
       :func:`find_closest_index` on the first element; a warning is printed.
    3. Otherwise ``default_config``; a warning is printed.

    Args:
        key: hashable tuple laid out as ``(batch_size, *other)``.
    """
    _key = str(key)
    _configs = global_config_v3_decode["config"]
    if _key in _configs:
        return _configs[_key]

    # find the nearest batch size
    # key = (batch_size, *other)
    bs = [
        (i, k[0])
        for i, k in enumerate(global_config_v3_decode["keys"])
        if k[1:] == key[1:]
    ]

    if bs:
        __key = global_config_v3_decode["key"][find_closest_index(key[0], bs)]
        # Name the decode table's own file here; the prefill table
        # (global_config_v3) lives in a different JSON file.
        print(f'WARNING: Not found key {_key} from {global_config_v3_decode["fpath"]}, mapping to key {__key}')
        return _configs[__key]
    else:
        print(f'WARNING: Not found optimal config from {global_config_v3_decode["fpath"]} with key {str(key)}, just use default config')
        return default_config
def
has_kernel_cache
(
path
):
def
has_kernel_cache
(
path
):
return
False
if
not
path
or
not
os
.
path
.
isdir
(
f
'
{
cache_knob
.
dir
}
/
{
path
}
'
)
else
True
return
False
if
not
path
or
not
os
.
path
.
isdir
(
f
'
{
cache_knob
.
dir
}
/
{
path
}
'
)
else
True
@functools.lru_cache(maxsize=1024)
def get_v2_decode_final_grid(
    max_len_extend: int,
    kv_group_num: int,
    block_m_cfg: int,
    batch_size: int,
    kv_head_num: int,
) -> tuple[int, tuple[int, int, int]]:
    """Choose ``BLOCK_M`` and the launch grid for the v2 decode kernel.

    Args:
        max_len_extend: longest extend length in the batch.
        kv_group_num: query heads per KV head.
        block_m_cfg: ``BLOCK_M`` suggested by the tuned config.
        batch_size: number of sequences.
        kv_head_num: number of KV heads.

    Returns:
        ``(block_m_decode, grid)`` where ``grid`` is
        ``(batch_size, kv_head_num, cdiv(max_len_extend, block_m_decode // kv_group_num))``.
    """
    rows = max_len_extend * kv_group_num
    rows_pow2 = triton.next_power_of_2(rows)
    group_pow2 = triton.next_power_of_2(kv_group_num)

    if rows < 16:
        # Very little work: use a fixed block of 16.
        block_m = 16
    elif block_m_cfg > rows_pow2:
        # The tuned block is larger than all the rows there are: clamp it.
        block_m = rows_pow2
    else:
        # Never go below the (power-of-2 padded) group size.
        block_m = max(block_m_cfg, group_pow2)

    programs = batch_size * kv_head_num * triton.cdiv(
        max_len_extend, block_m // kv_group_num
    )
    if programs <= 32:
        # Few programs would be launched: halve the block, keeping it at
        # least 16 and at least the padded group size.
        block_m = max(block_m // 2, 16, group_pow2)

    tokens_per_block = block_m // kv_group_num
    launch_grid = (
        batch_size,
        kv_head_num,
        triton.cdiv(max_len_extend, tokens_per_block),
    )
    return block_m, launch_grid
def
to_dtype
(
torch_dtype
):
def
to_dtype
(
torch_dtype
):
if
torch_dtype
==
torch
.
float32
:
if
torch_dtype
==
torch
.
float32
:
return
'fp32'
return
'fp32'
...
@@ -793,7 +1713,6 @@ def to_dtype(torch_dtype):
...
@@ -793,7 +1713,6 @@ def to_dtype(torch_dtype):
else
:
else
:
return
str
(
torch_dtype
)
return
str
(
torch_dtype
)
def
extend_attention_fwd
(
def
extend_attention_fwd
(
q_extend
,
q_extend
,
k_extend
,
k_extend
,
...
@@ -818,6 +1737,7 @@ def extend_attention_fwd(
...
@@ -818,6 +1737,7 @@ def extend_attention_fwd(
sinks
=
None
,
sinks
=
None
,
window_kv_offsets
=
None
,
window_kv_offsets
=
None
,
xai_temperature_len
=-
1
,
xai_temperature_len
=-
1
,
force_v2_prefill
:
bool
=
False
,
):
):
"""
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
q_extend, k_extend, v_extend, o_extend: contiguous tensors
...
@@ -828,7 +1748,11 @@ def extend_attention_fwd(
...
@@ -828,7 +1748,11 @@ def extend_attention_fwd(
extensions follow with defaults. ``k_scale`` / ``v_scale`` must both be
extensions follow with defaults. ``k_scale`` / ``v_scale`` must both be
``None`` or both set (``float`` / ``int`` like sglang, or 1-element
``None`` or both set (``float`` / ``int`` like sglang, or 1-element
``torch.Tensor`` on device); if both are set, :func:`_fwd_kernel_v2` is used.
``torch.Tensor`` on device); if both are set, :func:`_fwd_kernel_v2` is used.
If ``force_v2_prefill`` is True and v2 is active, always use :func:`_fwd_kernel_v2`
even when ``max_len_extend < 32`` (for tests / parity vs :func:`_fwd_kernel_v2_decode`).
"""
"""
# force_v2_prefill = True
Lq
,
Lv
=
(
Lq
,
Lv
=
(
q_extend
.
shape
[
-
1
],
q_extend
.
shape
[
-
1
],
v_extend
.
shape
[
-
1
],
v_extend
.
shape
[
-
1
],
...
@@ -853,30 +1777,87 @@ def extend_attention_fwd(
...
@@ -853,30 +1777,87 @@ def extend_attention_fwd(
sm_scale
=
sm_scale
or
1.0
/
(
Lq
**
0.5
)
sm_scale
=
sm_scale
or
1.0
/
(
Lq
**
0.5
)
batch_size
,
head_num
=
qo_indptr
.
shape
[
0
]
-
1
,
q_extend
.
shape
[
1
]
batch_size
,
head_num
=
qo_indptr
.
shape
[
0
]
-
1
,
q_extend
.
shape
[
1
]
kv_group_num
=
q_extend
.
shape
[
1
]
//
k_extend
.
shape
[
1
]
kv_head_num
=
k_extend
.
shape
[
1
]
kv_group_num
=
head_num
//
kv_head_num
USE_CUSTOM_MASK
=
custom_mask
is
not
None
USE_CUSTOM_MASK
=
custom_mask
is
not
None
# Skip custom mask for prefix part
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK
=
skip_prefix_custom_mask
SKIP_PREFIX_CUSTOM_MASK
=
skip_prefix_custom_mask
use_v2
=
k_scale
is
not
None
or
v_scale
is
not
None
use_v2
=
k_scale
is
not
None
or
v_scale
is
not
None
use_v2_decode
=
(
use_v2
and
max_len_extend
<
32
and
not
force_v2_prefill
)
USE_SLIDING_WINDOW
=
sliding_window_size
>
0
if
not
USE_CUSTOM_MASK
:
if
not
USE_CUSTOM_MASK
:
custom_mask
=
torch
.
tensor
([
0
],
dtype
=
torch
.
bool
,
device
=
q_extend
.
device
)
# custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device)
mask_indptr
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
q_extend
.
device
)
# mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device)
# set to None to avoid capture cudagraph err
custom_mask
=
None
mask_indptr
=
None
if
config
is
None
:
if
config
is
None
:
if
q_extend
.
dtype
==
torch
.
float16
or
q_extend
.
dtype
==
torch
.
bfloat16
:
if
q_extend
.
dtype
==
torch
.
float16
or
q_extend
.
dtype
==
torch
.
bfloat16
:
if
use_v2
:
if
use_v2
:
config
,
path
=
_get_config_v2
(
if
triton_minor_version
>=
5
:
# >= 3.5
kv_group_num
,
key
=
[
Lq
,
batch_size
,
Lv
,
kv_group_num
,
USE_CUSTOM_MASK
,
Lq
,
is_causal
,
Lv
,
sinks
is
not
None
,
USE_CUSTOM_MASK
,
sliding_window_size
,
is_causal
,
)
skip_prefix_custom_mask
,
sinks
is
not
None
,
sliding_window_size
,
xai_temperature_len
,
str
(
q_extend
.
dtype
),
str
(
k_extend
.
dtype
),
str
(
v_extend
.
dtype
),
str
(
o_extend
.
dtype
),
str
(
k_buffer
.
dtype
),
str
(
v_buffer
.
dtype
),
str
(
qo_indptr
.
dtype
),
str
(
kv_indptr
.
dtype
),
str
(
kv_indices
.
dtype
),
]
for
o
in
[
custom_mask
,
mask_indptr
,
sinks
,
window_kv_offsets
]:
if
o
is
not
None
:
if
hasattr
(
o
,
'dtype'
):
key
.
append
(
str
(
o
.
dtype
))
key
=
tuple
(
key
)
if
use_v2_decode
:
config
=
_get_config_v3_decode
(
key
)
else
:
config
=
_get_config_v3
(
key
)
else
:
if
use_v2_decode
:
config
,
path
=
_get_config_v2_decode
(
kv_group_num
,
Lq
,
Lv
,
USE_CUSTOM_MASK
,
is_causal
,
sinks
is
not
None
,
USE_SLIDING_WINDOW
,
)
else
:
config
,
path
=
_get_config_v2
(
kv_group_num
,
Lq
,
Lv
,
USE_CUSTOM_MASK
,
is_causal
,
sinks
is
not
None
,
USE_SLIDING_WINDOW
,
)
else
:
else
:
keys
=
[
kv_group_num
,
Lq
,
Lv
,
USE_CUSTOM_MASK
,
is_causal
]
keys
=
[
kv_group_num
,
Lq
,
Lv
,
USE_CUSTOM_MASK
,
is_causal
]
config
,
path
=
_get_config
(
*
keys
)
config
,
path
=
_get_config
(
*
keys
)
...
@@ -884,7 +1865,17 @@ def extend_attention_fwd(
...
@@ -884,7 +1865,17 @@ def extend_attention_fwd(
config
,
path
=
default_config
,
None
config
,
path
=
default_config
,
None
assert
config
is
not
None
,
"ERROR: optimal config not found"
assert
config
is
not
None
,
"ERROR: optimal config not found"
grid
=
(
batch_size
,
head_num
,
triton
.
cdiv
(
max_len_extend
,
config
[
"BLOCK_M"
]))
block_m_cfg
=
config
[
"BLOCK_M"
]
# Decode: block_m_decode is power of 2; Q_SEQ = block_m // G (floor), same as unified BLOCK_Q;
# grid-3 cdiv(max_len_extend, q_seq) matches kernel cur_block_m stride.
if
use_v2_decode
:
block_m_decode
,
grid
=
get_v2_decode_final_grid
(
max_len_extend
,
kv_group_num
,
block_m_cfg
,
batch_size
,
kv_head_num
)
# print(f"{max_len_extend=}, {use_v2_decode=}, {block_m_decode=}, {grid=}")
else
:
grid
=
(
batch_size
,
head_num
,
triton
.
cdiv
(
max_len_extend
,
block_m_cfg
))
# num_stages = 1
# num_stages = 1
# extra_kargs = {}
# extra_kargs = {}
...
@@ -922,32 +1913,69 @@ def extend_attention_fwd(
...
@@ -922,32 +1913,69 @@ def extend_attention_fwd(
HAS_SINK
=
sinks
is
not
None
HAS_SINK
=
sinks
is
not
None
assert
k_scale
is
not
None
and
v_scale
is
not
None
,
"k_scale and v_scale must both be set"
assert
k_scale
is
not
None
and
v_scale
is
not
None
,
"k_scale and v_scale must both be set"
# k_scale / v_scale kept in Python API; v2 kernel TEMP omits them for perf vs v1.
# k_scale / v_scale kept in Python API; v2 kernel TEMP omits them for perf vs v1.
_fwd_kernel_v2
[
grid
](
block_const_v2
=
{
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sinks
,
window_kv_offsets
,
sm_scale
,
k_scale
,
v_scale
,
kv_group_num
,
*
stride_args
,
SLIDING_WINDOW_SIZE
=
sliding_window_size
,
logit_cap
=
logit_cap
,
xai_temperature_len
=
xai_temperature_len
,
HAS_SINK
=
HAS_SINK
,
**
block_const
,
**
block_const
,
**
config
,
**
config
,
)
}
if
use_v2_decode
:
block_const_v2
=
{
**
block_const_v2
,
"BLOCK_M"
:
block_m_decode
}
_fwd_kernel_v2_decode
[
grid
](
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sinks
,
window_kv_offsets
,
sm_scale
,
k_scale
,
v_scale
,
*
stride_args
,
SLIDING_WINDOW_SIZE
=
sliding_window_size
,
logit_cap
=
logit_cap
,
xai_temperature_len
=
xai_temperature_len
,
HAS_SINK
=
HAS_SINK
,
kv_group_num
=
kv_group_num
,
num_query_heads
=
head_num
,
**
block_const_v2
,
batch_size
=
batch_size
,
# USE_MLS=True if os.getenv("TRITON_USE_MLS", "0") == "1" and kv_group_num == head_num else False,
)
else
:
_fwd_kernel_v2
[
grid
](
q_extend
,
k_extend
,
v_extend
,
o_extend
,
k_buffer
,
v_buffer
,
qo_indptr
,
kv_indptr
,
kv_indices
,
custom_mask
,
mask_indptr
,
sinks
,
window_kv_offsets
,
sm_scale
,
k_scale
,
v_scale
,
kv_group_num
,
*
stride_args
,
SLIDING_WINDOW_SIZE
=
sliding_window_size
,
logit_cap
=
logit_cap
,
xai_temperature_len
=
xai_temperature_len
,
HAS_SINK
=
HAS_SINK
,
**
block_const_v2
,
head_num
=
head_num
,
batch_size
=
batch_size
,
# USE_MLS=True if os.getenv("TRITON_USE_MLS", "0") == "1" and kv_group_num == head_num else False,
)
return
return
fn
=
(
fn
=
(
...
...
aiter/ops/triton/fla/fused_recurrent.py
0 → 100644
View file @
bb596f6e
# SPDX-License-Identifier: MIT
import
functools
import
json
import
os
from
typing
import
Tuple
import
torch
import
triton
import
triton.language
as
tl
import
aiter.ops.triton.utils.arch_info
as
arch_info
from
aiter
import
logger
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
# Flag for the (currently disabled) one-time kernel metadata dump; kept
# commented out together with the dump code.
# HAS_DUMPED_PACKED_DECODE_KERNEL_METADATA = False

# Opt-in switch: set the environment variable TRITON_CONFIG_CHECK=1 to get a
# warning whenever the config lookup has to fall back to a non-exact entry.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"

# Launch parameters used when no tuned entry exists; looked-up entries are
# merged on top of this dict, so every key here is always present.
_DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG = {
    "BV": 32,
    "num_warps": 1,
    "num_stages": 1,
}
@functools.lru_cache(maxsize=1)
def _load_fused_recurrent_packed_decode_configs() -> dict:
    """Load the tuned packed-decode configs for the current architecture.

    The JSON file is looked up under ``AITER_TRITON_CONFIGS_PATH`` using the
    name returned by ``arch_info.get_arch()``. The result is cached, so the
    file is read at most once per process.

    Returns:
        The mapping stored under the top-level ``"config"`` key, or an empty
        dict when the file is missing or does not have the expected shape.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "fused_recurrent_gated_delta_rule_packed_decode",
        f"fused_recurrent_gated_delta_rule_packed_decode-{device_name}.json",
    )
    if not os.path.exists(path):
        logger.warning(
            f"fused_recurrent_gated_delta_rule_packed_decode config not found at {path}, "
            f"using default {_DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG}."
        )
        return {}
    # JSON is UTF-8 by specification; do not depend on the locale's default.
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    if not isinstance(payload, dict):
        return {}
    configs = payload.get("config", {})
    # A non-mapping "config" section cannot be searched by key, so treat it
    # the same way as a missing section.
    return configs if isinstance(configs, dict) else {}
@functools.lru_cache
def _get_fused_recurrent_packed_decode_config(B: int, H: int, HV: int) -> dict:
    """Return the launch config for a ``(B, H, HV)`` problem size.

    Lookup order:
      1. the exact key ``"B=<B>,H=<H>,HV=<HV>"``;
      2. the entry with matching ``H`` and ``HV`` whose ``B`` is closest to
         the requested one (the first such entry wins on ties);
      3. the ``"default"`` entry of the config file, or the built-in default.

    The chosen entry is merged on top of the built-in default so that every
    expected key is present. The returned dict is cached by ``lru_cache`` and
    therefore shared between calls; callers must not mutate it.
    """
    cfgs = _load_fused_recurrent_packed_decode_configs()
    key = f"B={B},H={H},HV={HV}"
    cfg = cfgs.get(key)
    if cfg is None:
        candidates = []
        for name, value in cfgs.items():
            if name == "default":
                continue
            try:
                parts = {}
                for item in name.split(","):
                    fields = item.split("=")
                    parts[fields[0]] = int(fields[1])
            except (IndexError, ValueError):
                # Key is not of the "NAME=int,NAME=int,..." form; ignore it.
                continue
            if parts.get("H") == H and parts.get("HV") == HV and "B" in parts:
                candidates.append((abs(parts["B"] - B), parts["B"], value))
        if candidates:
            # min() returns the first minimal element, matching what a stable
            # sort followed by [0] would pick.
            _, nearest_b, cfg = min(candidates, key=lambda c: c[0])
            if TRITON_CONFIG_CHECK:
                logger.warning(
                    f"fused_recurrent_packed_decode config key '{key}' not found, "
                    f"using nearest-B config with B={nearest_b}: {cfg}."
                )
    if cfg is None:
        default_cfg = cfgs.get("default", _DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                f"fused_recurrent_packed_decode config key '{key}' not found, "
                f"using default config: {default_cfg}."
            )
        cfg = default_cfg
    merged = dict(_DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG)
    merged.update(cfg)
    return merged
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
    mixed_qkv,
    a,
    b,
    A_log,
    dt_bias,
    o,
    h0,
    ht,
    ssm_state_indices,
    scale,
    stride_mixed_qkv_tok: tl.constexpr,
    stride_a_tok: tl.constexpr,
    stride_b_tok: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    SOFTPLUS_THRESHOLD: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    """Single-step gated delta-rule update reading q/k/v from one packed row.

    Program id 0 selects a BV-wide slice of the value dimension; program id 1
    encodes ``i_n * HV + i_hv`` (row index and value head). For its slice each
    program loads the state tile, decays it, applies the delta-rule update,
    writes the output slice to ``o`` and stores the updated state to ``ht``.

    Within a row of ``mixed_qkv`` the kernel reads q at offset 0, k at offset
    ``H * K`` and v at offset ``2 * H * K``.
    NOTE(review): the addressing of ``o``, ``h0``/``ht`` and the inner
    dimensions of ``mixed_qkv``/``a``/``b`` uses no inner strides, so it
    presumably expects those dimensions to be contiguous — confirm at callers.
    """
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Several value heads share one q/k head (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)

    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # BK/BV may exceed K/V; mask the out-of-range lanes.
    mask_k = o_k < K
    mask_v = o_v < V

    state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
    p_o = o + (i_n * HV + i_hv) * V + o_v

    # A negative state index marks a row without a state slot: emit zeros
    # and leave the state untouched.
    if state_idx < 0:
        zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
        tl.store(p_o, zero, mask=mask_v)
        return

    p_h0 = h0 + state_idx * stride_init_state_token
    p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]  # [BV, BK]
    b_h = tl.load(p_h0, mask=(mask_v[:, None] & mask_k[None, :]), other=0).to(tl.float32)

    p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
    k_off = (H * K) + i_h * K + o_k
    v_off = (2 * H * K) + i_hv * V + o_v
    b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
    b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)

    if USE_QK_L2NORM_IN_KERNEL:
        # The 1e-6 term keeps rsqrt finite for an all-zero k.
        k_norm_inv = tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_k = b_k * k_norm_inv

    # Gate: g = -exp(A_log) * softplus(a + dt_bias); above the threshold
    # softplus(x) is replaced by x itself.
    x = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
    x += tl.load(dt_bias + i_hv).to(tl.float32)
    softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
    g_val = -tl.exp(tl.load(A_log + i_hv).to(tl.float32)) * softplus_x
    beta_val = tl.sigmoid(tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32))

    # Decay the state, then apply the delta rule:
    #   v' = beta * (v - h @ k);  h += v' k^T
    b_h *= tl.exp(g_val)
    b_v -= tl.sum(b_h * b_k[None, :], 1)
    b_v *= beta_val
    b_h += b_v[:, None] * b_k[None, :]

    q_off = i_h * K + o_k
    b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
    if USE_QK_L2NORM_IN_KERNEL:
        q_norm_inv = tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
        b_q = b_q * q_norm_inv
    # Output uses the already-updated state: o = scale * (h @ q).
    b_o = tl.sum(b_h * b_q[None, :], 1)
    b_o = b_o * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # Write the updated state tile back to the slot selected by state_idx.
    p_ht = ht + state_idx * stride_final_state_token
    p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=(mask_v[:, None] & mask_k[None, :]))
def fused_recurrent_gated_delta_rule_packed_decode(
    mixed_qkv: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    out: torch.Tensor,
    ssm_state_indices: torch.Tensor,
    use_qk_l2norm_in_kernel: bool = False,
    kernel_cfg: dict | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Validate the inputs and launch the packed-decode gated delta-rule kernel.

    Sizes are inferred from the arguments: ``B = mixed_qkv.shape[0]``,
    ``(HV, V, K) = initial_state.shape[-3:]``, and ``H`` from the width of
    ``mixed_qkv``, which must be ``2 * H * K + HV * V`` columns (q | k | v).

    Args:
        mixed_qkv: 2D tensor with the packed q, k and v of every row.
        a: 2D tensor of shape ``(B, HV)``.
        b: 2D tensor of shape ``(B, HV)``.
        A_log: 1D tensor with ``HV`` elements.
        dt_bias: 1D tensor with ``HV`` elements.
        scale: Multiplier applied to the kernel output.
        initial_state: 4D state tensor. It is handed to the kernel as both
            initial and final state, so it is updated in place.
        out: Output tensor of shape ``(B, 1, HV, V)``, written in place.
        ssm_state_indices: 1D tensor with ``B`` entries selecting the state
            slot of each row; for negative entries the kernel writes zeros.
        use_qk_l2norm_in_kernel: Forwarded as ``USE_QK_L2NORM_IN_KERNEL``.
        kernel_cfg: Optional launch overrides (``BV``, ``num_warps``,
            ``num_stages``). Keys that are left out fall back to the module
            defaults. When ``None`` the tuned config for ``(B, H, HV)`` is used.

    Returns:
        The tuple ``(out, initial_state)``.

    Raises:
        ValueError: If tensor ranks, devices or shapes are inconsistent.
    """
    if mixed_qkv.ndim != 2:
        raise ValueError(f"`mixed_qkv` must be 2D, got ndim={mixed_qkv.ndim}.")
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(f"`a` and `b` must be 2D, got a.ndim={a.ndim}, b.ndim={b.ndim}.")
    if A_log.ndim != 1 or dt_bias.ndim != 1:
        raise ValueError("`A_log` and `dt_bias` must be 1D.")
    if ssm_state_indices.ndim != 1:
        raise ValueError("`ssm_state_indices` must be 1D.")
    if initial_state.ndim != 4:
        raise ValueError(f"`initial_state` must be 4D, got ndim={initial_state.ndim}.")

    dev = mixed_qkv.device
    if any(t.device != dev for t in (a, b, A_log, dt_bias, initial_state, out, ssm_state_indices)):
        raise ValueError("All tensors must be on the same device.")

    B = mixed_qkv.shape[0]
    if a.shape[0] != B or b.shape[0] != B or ssm_state_indices.shape[0] != B:
        raise ValueError("Batch dimensions of mixed_qkv/a/b/ssm_state_indices must match.")

    HV, V, K = initial_state.shape[-3:]
    if a.shape[1] != HV or b.shape[1] != HV:
        raise ValueError("`a` and `b` second dim must match HV from initial_state.")
    if A_log.numel() != HV or dt_bias.numel() != HV:
        raise ValueError("`A_log` and `dt_bias` numel must equal HV.")
    if out.shape != (B, 1, HV, V):
        raise ValueError(f"`out` must have shape {(B, 1, HV, V)}, got {tuple(out.shape)}.")

    # Whatever is left after the v part must split evenly into q and k.
    qkv_dim = mixed_qkv.shape[1]
    qk_dim = qkv_dim - HV * V
    if qk_dim <= 0 or qk_dim % 2 != 0:
        raise ValueError("Invalid mixed_qkv layout for packed decode.")
    q_dim = qk_dim // 2
    if q_dim % K != 0:
        raise ValueError("Inferred q_dim must be divisible by K.")
    H = q_dim // K
    if H <= 0 or HV % H != 0:
        raise ValueError(f"Invalid inferred heads: H={H}, HV={HV}.")

    if kernel_cfg is None:
        cfg = _get_fused_recurrent_packed_decode_config(B, H, HV)
    else:
        # Fill in any launch parameter the caller left out instead of
        # failing with a KeyError below.
        cfg = {**_DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG, **kernel_cfg}

    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), int(cfg["BV"]))
    NV = triton.cdiv(V, BV)
    grid = (NV, B * HV)

    # The state is updated in place: the same tensor (and stride) is used for
    # both the initial and the final state.
    state_stride = initial_state.stride(0)

    fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
        mixed_qkv=mixed_qkv,
        a=a,
        b=b,
        A_log=A_log,
        dt_bias=dt_bias,
        o=out,
        h0=initial_state,
        ht=initial_state,
        ssm_state_indices=ssm_state_indices,
        scale=scale,
        stride_mixed_qkv_tok=mixed_qkv.stride(0),
        stride_a_tok=a.stride(0),
        stride_b_tok=b.stride(0),
        stride_init_state_token=state_stride,
        stride_final_state_token=state_stride,
        stride_indices_seq=ssm_state_indices.stride(0),
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        SOFTPLUS_THRESHOLD=20.0,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        num_warps=cfg["num_warps"],
        num_stages=cfg["num_stages"],
    )
    return out, initial_state
aiter/ops/triton/fla/fused_sigmoid_gating.py
0 → 100644
View file @
bb596f6e
# SPDX-License-Identifier: MIT
from
typing
import
Tuple
import
functools
import
json
import
os
import
torch
import
triton
import
triton.language
as
tl
import
aiter.ops.triton.utils.arch_info
as
arch_info
from
aiter
import
logger
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
# Flag for the (currently disabled) one-time kernel metadata dump; kept
# commented out together with the dump code.
# HAS_DUMPED_SIGMOID_GATING_KERNEL_METADATA = False

# Opt-in switch: set the environment variable TRITON_CONFIG_CHECK=1 to get a
# warning whenever the config lookup has to fall back to a non-exact entry.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"

# Launch parameters used when no tuned entry exists; looked-up entries are
# merged on top of this dict, so every key here is always present.
_DEFAULT_FUSED_SIGMOID_GATING_CONFIG = {
    "BV": 32,
    "num_warps": 1,
}
@functools.lru_cache(maxsize=1)
def _load_fused_sigmoid_gating_configs() -> dict:
    """Load the tuned sigmoid-gating configs for the current architecture.

    The JSON file is looked up under ``AITER_TRITON_CONFIGS_PATH`` using the
    name returned by ``arch_info.get_arch()``. The result is cached, so the
    file is read at most once per process.

    Returns:
        The mapping stored under the top-level ``"config"`` key, or an empty
        dict when the file is missing or does not have the expected shape.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "fused_sigmoid_gating_delta_rule_update",
        f"fused_sigmoid_gating_delta_rule_update-{device_name}.json",
    )
    if not os.path.exists(path):
        logger.warning(
            f"fused_sigmoid_gating_delta_rule_update config not found at {path}, "
            f"using default {_DEFAULT_FUSED_SIGMOID_GATING_CONFIG}."
        )
        return {}
    # JSON is UTF-8 by specification; do not depend on the locale's default.
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    if not isinstance(payload, dict):
        return {}
    configs = payload.get("config", {})
    # A non-mapping "config" section cannot be searched by key, so treat it
    # the same way as a missing section.
    return configs if isinstance(configs, dict) else {}
@functools.lru_cache
def _get_fused_sigmoid_gating_config(T: int, H: int, HV: int) -> dict:
    """Return the launch config for a ``(T, H, HV)`` problem size.

    Lookup order:
      1. the exact key ``"T=<T>,H=<H>,HV=<HV>"``;
      2. the entry with matching ``H`` and ``HV`` whose ``T`` is closest to
         the requested one (the first such entry wins on ties);
      3. the ``"default"`` entry of the config file, or the built-in default.

    The chosen entry is merged on top of the built-in default so that every
    expected key is present. The returned dict is cached by ``lru_cache`` and
    therefore shared between calls; callers must not mutate it.
    """
    cfgs = _load_fused_sigmoid_gating_configs()
    key = f"T={T},H={H},HV={HV}"
    cfg = cfgs.get(key)
    if cfg is None:
        candidates = []
        for name, value in cfgs.items():
            if name == "default":
                continue
            try:
                parts = {}
                for item in name.split(","):
                    fields = item.split("=")
                    parts[fields[0]] = int(fields[1])
            except (IndexError, ValueError):
                # Key is not of the "NAME=int,NAME=int,..." form; ignore it.
                continue
            if parts.get("H") == H and parts.get("HV") == HV and "T" in parts:
                candidates.append((abs(parts["T"] - T), parts["T"], value))
        if candidates:
            # min() returns the first minimal element, matching what a stable
            # sort followed by [0] would pick.
            _, nearest_t, cfg = min(candidates, key=lambda c: c[0])
            if TRITON_CONFIG_CHECK:
                logger.warning(
                    f"fused_sigmoid_gating config key '{key}' not found, "
                    f"using nearest-T config with T={nearest_t}: {cfg}."
                )
    if cfg is None:
        default_cfg = cfgs.get("default", _DEFAULT_FUSED_SIGMOID_GATING_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                f"fused_sigmoid_gating config key '{key}' not found, "
                f"using default config: {default_cfg}."
            )
        cfg = default_cfg
    merged = dict(_DEFAULT_FUSED_SIGMOID_GATING_CONFIG)
    merged.update(cfg)
    return merged
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    b,
    dt_bias,
    beta,
    threshold,
    q,
    k,
    v,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,
    T: tl.int64,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    """Recurrent gated delta-rule update with the gate computed in-kernel.

    Program ids are (K block, V block, ``i_n * HV + i_hv``). Each program
    walks the tokens of one sequence for one value head and, per token,
    computes the gate ``g = -exp(A_log) * softplus(a + dt_bias)`` and
    ``beta = sigmoid(b)``, decays the state, applies the delta-rule update,
    writes the output slice and stores the state for that token.

    The four flags set by ``triton.heuristics`` are derived from which of
    ``h0``, ``cu_seqlens``, ``ssm_state_indices`` and ``num_accepted_tokens``
    are not ``None``.

    NOTE(review): ``N`` and ``stride_indices_tok`` are accepted but not read
    in this body, and the per-token state index below is addressed with a
    stride of 1 — confirm that is intended for 2D ``ssm_state_indices``.
    NOTE(review): in KDA mode ``a`` and ``dt_bias`` are loaded without a
    mask, which presumably requires ``BK == K`` — confirm.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Several value heads share one q/k head (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence boundaries come from cu_seqlens; T becomes this
        # sequence's length and `all` keeps the total token count.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    # Pointers to the first token of this sequence; advanced once per token
    # at the bottom of the loop.
    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_A_log = A_log + i_hv
    if not IS_KDA:
        # One gate value per (token, value head).
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    else:
        # K gate values per (token, value head).
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    # Invariant across tokens.
    b_A_log = tl.exp(tl.load(p_A_log).to(tl.float32))
    if not IS_KDA:
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                # Start from the state of the last accepted token.
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64)
            # A negative index marks a sequence without a state slot.
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        if not IS_KDA:
            x = tl.load(p_a).to(tl.float32) + b_dt_bias
        else:
            x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
        # softplus with slope `beta`; above the threshold it is replaced by x.
        softplus_x = tl.where(
            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
        )
        b_g = -b_A_log * softplus_x
        b_beta = tl.sigmoid(b_b)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            # The 1e-6 term keeps rsqrt finite for all-zero vectors.
            b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale

        # Decay the state (per head, or per K column in KDA mode), then
        # apply the delta rule: v' = beta * (v - h @ k); h += v' k^T.
        if not IS_KDA:
            b_h *= tl.exp(b_g)
        else:
            b_h *= tl.exp(b_g[None, :])
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        b_v *= b_beta
        b_h += b_v[:, None] * b_k[None, :]
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Store the state after this token.
        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance every pointer to the next token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_b += HV
        # In KDA mode `a` holds K values per (token, head) (see the p_a
        # setup above), so the per-token stride has to include K.
        if IS_KDA:
            p_a += HV * K
        else:
            p_a += HV
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    inplace_final_state: bool = True,
    cu_seqlens: torch.Tensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
    is_kda: bool = False,
    kernel_cfg: dict | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused sigmoid-gating delta-rule update kernel.

    Sizes are taken from the inputs: ``(B, T, H, K) = k.shape``,
    ``V = v.shape[-1]`` and ``HV = v.shape[2]``. The number of sequences is
    ``B``, or ``len(cu_seqlens) - 1`` when ``cu_seqlens`` is given.

    Args:
        A_log, a, b, dt_bias: Gate inputs, forwarded to the kernel (``a`` and
            ``b`` are made contiguous first).
        q, k, v: Query/key/value tensors, made contiguous before the launch.
        beta: Softplus slope used by the kernel.
        threshold: Softplus threshold used by the kernel.
        scale: Query scale; defaults to ``K ** -0.5``. Must be positive.
        initial_state: Required state tensor.
        inplace_final_state: When true the state is written back into
            ``initial_state``; otherwise a new ``(T, HV, V, K)`` tensor is
            allocated for the per-token states.
        cu_seqlens: Optional sequence boundaries; requires ``q.shape[0] == 1``.
        ssm_state_indices: Optional 1D or 2D state-slot indices.
            NOTE(review): the kernel reads this tensor whenever
            ``inplace_final_state`` is true, so it presumably must be given
            in that mode — confirm.
        num_accepted_tokens: Optional tensor, forwarded to the kernel.
        use_qk_l2norm_in_kernel: Forwarded as ``USE_QK_L2NORM_IN_KERNEL``.
        is_kda: Forwarded as ``IS_KDA``.
        kernel_cfg: Optional launch overrides (``BV``, ``num_warps``). Keys
            that are left out fall back to the module defaults. When ``None``
            the tuned config for ``(T, H, HV)`` is used.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape of ``v``.

    Raises:
        ValueError: If ``K`` needs more than one block, ``scale`` is not
            positive, ``initial_state`` is missing, ``cu_seqlens`` is used
            with a batch size other than 1, or ``ssm_state_indices`` has an
            unsupported rank.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    if kernel_cfg is None:
        cfg = _get_fused_sigmoid_gating_config(T, H, HV)
    else:
        # Fill in any launch parameter the caller left out instead of
        # failing with a KeyError below.
        cfg = {**_DEFAULT_FUSED_SIGMOID_GATING_CONFIG, **kernel_cfg}

    BK = triton.next_power_of_2(K)
    BV = min(triton.next_power_of_2(V), int(cfg["BV"]))
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    num_warps = int(cfg["num_warps"])

    if NK != 1:
        raise ValueError(f"NK > 1 is not supported (K={K}, BK={BK}, NK={NK}).")
    if scale is None:
        scale = K**-0.5
    elif scale <= 0:
        raise ValueError("scale must be positive.")
    if initial_state is None:
        raise ValueError("initial_state must not be None.")
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(f"q.shape[0] must be 1 when using cu_seqlens, got {q.shape[0]}.")

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    elif ssm_state_indices.ndim == 2:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
    else:
        raise ValueError(
            f"ssm_state_indices must be 1D/2D when provided, got ndim={ssm_state_indices.ndim}."
        )

    grid = (NK, NV, N * HV)
    fused_sigmoid_gating_delta_rule_update_kernel[grid](
        A_log=A_log,
        a=a.contiguous(),
        b=b.contiguous(),
        dt_bias=dt_bias,
        beta=beta,
        threshold=threshold,
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        INPLACE_FINAL_STATE=inplace_final_state,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        IS_KDA=is_kda,
        num_warps=num_warps,
        num_stages=1,
    )
    # NK is always 1 here, so dropping the leading axis restores v's shape.
    return o.squeeze(0), final_state
aiter/ops/triton/fla/fused_sigmoid_gating_recurrent.py
0 → 100644
View file @
bb596f6e
from
typing
import
Optional
import
functools
import
json
import
os
import
torch
import
triton
import
triton.language
as
tl
import
aiter.ops.triton.utils.arch_info
as
arch_info
from
aiter
import
logger
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
# Flag for a one-time kernel metadata dump; currently commented out.
# HAS_DUMPED_SIGMOID_GATING_REC_KERNEL_METADATA = False

# Opt-in switch: set the environment variable TRITON_CONFIG_CHECK=1 to get a
# warning whenever the config lookup has to fall back to a non-exact entry.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"

# Launch parameters used when no tuned entry exists; looked-up entries are
# merged on top of this dict, so every key here is always present.
_DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG = {
    "BV": 32,
    "num_warps": 1,
}
@functools.lru_cache(maxsize=1)
def _load_fused_sigmoid_gating_recurrent_configs() -> dict:
    """Load the tuned recurrent sigmoid-gating configs for the current architecture.

    The JSON file is looked up under ``AITER_TRITON_CONFIGS_PATH`` using the
    name returned by ``arch_info.get_arch()``. The result is cached, so the
    file is read at most once per process.

    Returns:
        The mapping stored under the top-level ``"config"`` key, or an empty
        dict when the file is missing or does not have the expected shape.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "fused_sigmoid_gating_delta_rule_update_recurrent",
        f"fused_sigmoid_gating_delta_rule_update_recurrent-{device_name}.json",
    )
    if not os.path.exists(path):
        logger.warning(
            f"fused_sigmoid_gating_delta_rule_update_recurrent config not found at {path}, "
            f"using default {_DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG}."
        )
        return {}
    # JSON is UTF-8 by specification; do not depend on the locale's default.
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)
    if not isinstance(payload, dict):
        return {}
    configs = payload.get("config", {})
    # A non-mapping "config" section cannot be searched by key, so treat it
    # the same way as a missing section.
    return configs if isinstance(configs, dict) else {}
@functools.lru_cache
def _get_fused_sigmoid_gating_recurrent_config(T: int, H: int, HV: int) -> dict:
    """Return the launch config for a ``(T, H, HV)`` problem size.

    Lookup order:
      1. the exact key ``"T=<T>,H=<H>,HV=<HV>"``;
      2. the entry with matching ``H`` and ``HV`` whose ``T`` is closest to
         the requested one (the first such entry wins on ties);
      3. the ``"default"`` entry of the config file, or the built-in default.

    The chosen entry is merged on top of the built-in default so that every
    expected key is present. The returned dict is cached by ``lru_cache`` and
    therefore shared between calls; callers must not mutate it.
    """
    cfgs = _load_fused_sigmoid_gating_recurrent_configs()
    key = f"T={T},H={H},HV={HV}"
    cfg = cfgs.get(key)
    if cfg is None:
        candidates = []
        for name, value in cfgs.items():
            if name == "default":
                continue
            try:
                parts = {}
                for item in name.split(","):
                    fields = item.split("=")
                    parts[fields[0]] = int(fields[1])
            except (IndexError, ValueError):
                # Key is not of the "NAME=int,NAME=int,..." form; ignore it.
                continue
            if parts.get("H") == H and parts.get("HV") == HV and "T" in parts:
                candidates.append((abs(parts["T"] - T), parts["T"], value))
        if candidates:
            # min() returns the first minimal element, matching what a stable
            # sort followed by [0] would pick.
            _, nearest_t, cfg = min(candidates, key=lambda c: c[0])
            if TRITON_CONFIG_CHECK:
                logger.warning(
                    f"fused_sigmoid_gating_recurrent config key '{key}' not found, "
                    f"using nearest-T config with T={nearest_t}: {cfg}."
                )
    if cfg is None:
        default_cfg = cfgs.get("default", _DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                f"fused_sigmoid_gating_recurrent config key '{key}' not found, "
                f"using default config: {default_cfg}."
            )
        cfg = default_cfg
    merged = dict(_DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG)
    merged.update(cfg)
    return merged
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    dt_bias,
    softplus_beta,
    softplus_threshold,
    q,
    k,
    v,
    b,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    # Parameters for target_verify support (unused for decode)
    intermediate_states_buffer,
    intermediate_state_indices,
    cache_steps,
    retrieve_parent_token_ptr,
    stride_retrieve_parent_token_seq: tl.constexpr,
    stride_retrieve_parent_token_token: tl.constexpr,
    # ================================================
    scale,
    T,
    stride_q,
    stride_k,
    stride_v,
    stride_b,
    NP2_T: tl.constexpr,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_KDA: tl.constexpr,
    # Optional flags for target_verify support (default False for decode)
    DISABLE_STATE_UPDATE: tl.constexpr = False,
    CACHE_INTERMEDIATE_STATES: tl.constexpr = False,
    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr = False,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.

    Program ids: axis 0 picks the key tile (i_k), axis 1 the value tile (i_v),
    and axis 2 packs the sequence index and the value-head index
    (i_nh = i_n * HV + i_hv). Each program iterates over the T timesteps of
    one sequence while carrying a [BK, BV] float32 state tile ``b_h``.

    Per timestep the kernel computes:
        g = -exp(A_log) * softplus(a + dt_bias)   (softplus falls back to the
                                                    identity above the threshold)
        h = h * exp(g)
        v = (v - sum_k(h * k)) * sigmoid(b)
        h = h + outer(k, v)
        o = sum_k(h * (q * scale))

    Optional behaviour selected by constexpr flags:
      * USE_INITIAL_STATE: seed ``b_h`` from ``h0_source[h0_indices[i_n]]`` when
        the index is non-negative; unless DISABLE_STATE_UPDATE, the final state
        is written back to the same slot.
      * CACHE_INTERMEDIATE_STATES: store ``b_h`` after every step into
        ``intermediate_states_buffer``.
      * HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: for steps > 0, reload ``b_h`` from the
        cached state of the parent token given by ``retrieve_parent_token_ptr``.
      * IS_KDA: ``a`` / ``dt_bias`` are per-key-channel vectors and the decay
        is applied per row of ``b_h`` instead of as a scalar.

    NOTE(review): tensor layouts are implied only by the pointer arithmetic
    below (e.g. state stored with K as the fastest-varying axis) -- confirm
    against the caller before relying on them.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Several value heads share one q/k head (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        # `all` (shadows the Python builtin within this kernel) keeps the
        # incoming T for the output offset; T becomes this sequence's length.
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    p_q = q + bos * stride_q + i_h * K + o_k
    p_k = k + bos * stride_k + i_h * K + o_k
    p_v = v + bos * stride_v + i_hv * V + o_v
    p_b = b + bos * stride_b + i_hv
    # Output carries a leading key-tile axis, hence the i_k * all term.
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
    # Gating computation pointers
    p_A_log = A_log + i_hv
    if IS_KDA:
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    else:
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        # A negative index leaves the state at zero.
        if idx >= 0:
            # Tile element (k, v) lives at offset v * K + k inside the head's state.
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_v[None, :] * K
                + o_k[:, None]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
    # Preload tree attention data if needed
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        token_indices = tl.arange(0, NP2_T)
        mask_retrieve = token_indices < T
        retrieve_parent_token_base = (
            retrieve_parent_token_ptr
            + (i_n * stride_retrieve_parent_token_seq)
            + token_indices * stride_retrieve_parent_token_token
        )
        parent_idx_tokens = tl.load(retrieve_parent_token_base, mask=mask_retrieve, other=0)
    # Prepare intermediate state cache index if enabled
    cache_idx = -1
    if CACHE_INTERMEDIATE_STATES:
        cache_idx = tl.load(intermediate_state_indices + i_n)
    # Invariant across timesteps.
    b_A = tl.exp(tl.load(p_A_log).to(tl.float32))
    if not IS_KDA:
        # Scalar bias per value head; the KDA variant reloads a vector each step.
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
    step_idx = 0
    for _ in range(0, T):
        # Tree attention: load parent's cached state
        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
            # step_idx == 0 uses b_h from USE_INITIAL_STATE
            if step_idx != 0 and cache_idx >= 0:
                # Select parent_idx_tokens[step_idx] via a masked sum.
                parent_step_idx = tl.sum(
                    tl.where(token_indices == step_idx, parent_idx_tokens, 0)
                )
                step_offset = parent_step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)
        # Load k first; q is loaded later right before output to reduce register live range.
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        # Compute sigmoid gating
        # Load gating parameters
        if IS_KDA:
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias, mask=mask_k, other=0).to(tl.float32)
        else:
            b_a = tl.load(p_a).to(tl.float32)
        # Compute g with tighter live ranges for intermediates.
        x = b_a + b_dt_bias
        x_scaled = softplus_beta * x
        # softplus(x) = log(1 + exp(beta * x)) / beta, or x itself once
        # beta * x exceeds the threshold.
        x = tl.where(
            x_scaled <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(x_scaled)),
            x,
        )
        b_g = -b_A * x
        # Apply L2 normalization to k early; q normalization is deferred until q is loaded.
        if USE_QK_L2NORM_IN_KERNEL:
            b_k = b_k * tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        # Apply gating to hidden state: h *= exp(g)
        if IS_KDA:
            b_h *= tl.exp(b_g[:, None])
        else:
            b_h *= tl.exp(b_g)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        # Apply beta gating: v *= beta
        b_v *= tl.sigmoid(tl.load(p_b).to(tl.float32))
        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # Load q late to shorten q live range and lower peak register pressure.
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q * tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
        b_q = b_q * scale
        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Cache intermediate states if enabled
        if CACHE_INTERMEDIATE_STATES:
            if cache_idx >= 0:
                step_offset = step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)
        step_idx += 1
        # Update pointers for next timestep
        p_q += stride_q
        p_k += stride_k
        p_v += stride_v
        p_b += stride_b
        p_o += HV * V
        if IS_KDA:
            p_a += HV * K
        else:
            p_a += HV
    # Store final state back to h0_source with bounds checking
    if not DISABLE_STATE_UPDATE:
        if USE_INITIAL_STATE:
            idx = tl.load(h0_indices + i_n)
            if idx >= 0:
                p_h0 = (
                    h0_source
                    + idx * HV * K * V
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
def
fused_sigmoid_gating_delta_rule_update
(
A_log
:
torch
.
Tensor
,
a
:
torch
.
Tensor
,
dt_bias
:
torch
.
Tensor
,
softplus_beta
:
float
,
softplus_threshold
:
float
,
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
initial_state_source
:
torch
.
Tensor
,
initial_state_indices
:
torch
.
Tensor
,
scale
:
Optional
[
float
]
=
None
,
use_qk_l2norm_in_kernel
:
bool
=
False
,
cu_seqlens
:
Optional
[
torch
.
Tensor
]
=
None
,
is_kda
:
bool
=
False
,
# Optional parameters for target_verify support
disable_state_update
:
bool
=
False
,
intermediate_states_buffer
:
Optional
[
torch
.
Tensor
]
=
None
,
intermediate_state_indices
:
Optional
[
torch
.
Tensor
]
=
None
,
cache_steps
:
Optional
[
int
]
=
None
,
retrieve_parent_token
:
Optional
[
torch
.
Tensor
]
=
None
,
kernel_cfg
:
dict
|
None
=
None
,
):
global
HAS_DUMPED_SIGMOID_GATING_REC_KERNEL_METADATA
"""
Fused triton implementation of sigmoid gating delta rule update.
This function uses a single fused kernel that combines both sigmoid gating computation
and the recurrent delta rule update for better performance.
Supports both decode and target_verify modes:
- decode: standard single-step update with state write-back
- target_verify: multi-step with intermediate state caching, optional tree attention,
and optional state update disable
"""
B
,
T
,
H
,
K
,
V
=
*
k
.
shape
,
v
.
shape
[
-
1
]
stride_q
=
q
.
stride
()[
1
]
stride_k
=
k
.
stride
()[
1
]
stride_v
=
v
.
stride
()[
1
]
stride_b
=
b
.
stride
()[
-
2
]
HV
=
v
.
shape
[
2
]
N
=
B
if
cu_seqlens
is
None
else
len
(
cu_seqlens
)
-
1
BK
=
triton
.
next_power_of_2
(
K
)
cfg
=
kernel_cfg
if
kernel_cfg
is
not
None
else
_get_fused_sigmoid_gating_recurrent_config
(
T
,
H
,
HV
)
BV
=
min
(
triton
.
next_power_of_2
(
V
),
int
(
cfg
[
"BV"
]))
NK
,
NV
=
triton
.
cdiv
(
K
,
BK
),
triton
.
cdiv
(
V
,
BV
)
assert
NK
==
1
,
"NK > 1 is not supported yet"
num_warps
=
int
(
cfg
[
"num_warps"
])
if
scale
is
None
:
scale
=
k
.
shape
[
-
1
]
**
-
0.5
else
:
assert
scale
>
0
,
"scale must be positive"
o
=
q
.
new_empty
(
NK
,
*
v
.
shape
)
# Prepare retrieve_parent_token strides
if
retrieve_parent_token
is
not
None
:
stride_retrieve_parent_token_seq
=
retrieve_parent_token
.
stride
(
0
)
stride_retrieve_parent_token_token
=
retrieve_parent_token
.
stride
(
1
)
else
:
stride_retrieve_parent_token_seq
=
0
stride_retrieve_parent_token_token
=
0
NP2_T
=
triton
.
next_power_of_2
(
T
)
grid
=
(
NK
,
NV
,
N
*
HV
)
compiled_kernel
=
fused_sigmoid_gating_delta_rule_update_kernel
[
grid
](
A_log
=
A_log
,
a
=
a
,
dt_bias
=
dt_bias
,
softplus_beta
=
softplus_beta
,
softplus_threshold
=
softplus_threshold
,
q
=
q
,
k
=
k
,
v
=
v
,
b
=
b
,
o
=
o
,
h0_source
=
initial_state_source
,
h0_indices
=
initial_state_indices
,
cu_seqlens
=
cu_seqlens
,
intermediate_states_buffer
=
intermediate_states_buffer
,
intermediate_state_indices
=
intermediate_state_indices
,
cache_steps
=
0
if
cache_steps
is
None
else
cache_steps
,
retrieve_parent_token_ptr
=
retrieve_parent_token
,
stride_retrieve_parent_token_seq
=
stride_retrieve_parent_token_seq
,
stride_retrieve_parent_token_token
=
stride_retrieve_parent_token_token
,
scale
=
scale
,
T
=
T
,
stride_q
=
stride_q
,
stride_k
=
stride_k
,
stride_v
=
stride_v
,
stride_b
=
stride_b
,
NP2_T
=
NP2_T
,
B
=
B
,
H
=
H
,
HV
=
HV
,
K
=
K
,
V
=
V
,
BK
=
BK
,
BV
=
BV
,
USE_INITIAL_STATE
=
initial_state_source
is
not
None
,
USE_QK_L2NORM_IN_KERNEL
=
use_qk_l2norm_in_kernel
,
IS_VARLEN
=
cu_seqlens
is
not
None
,
IS_KDA
=
is_kda
,
DISABLE_STATE_UPDATE
=
disable_state_update
,
CACHE_INTERMEDIATE_STATES
=
intermediate_states_buffer
is
not
None
,
HAS_EAGLE_TREE_CUSTOM_ATTN_MASK
=
retrieve_parent_token
is
not
None
,
num_warps
=
num_warps
,
num_stages
=
1
,
)
'''
if not HAS_DUMPED_SIGMOID_GATING_REC_KERNEL_METADATA and compiled_kernel is not None:
print("sigmoid gating recurrent kernel metadata")
print(f" grid: {grid}")
print(f" registers: {compiled_kernel.n_regs}")
print(f" spills: {compiled_kernel.n_spills}")
print(f" shared memory: {compiled_kernel.metadata.shared} bytes")
HAS_DUMPED_SIGMOID_GATING_REC_KERNEL_METADATA = True
'''
o
=
o
.
squeeze
(
0
)
return
o
aiter/ops/triton/fla/fused_sigmoid_gating_recurrent_ref.py
0 → 100644
View file @
bb596f6e
from
typing
import
Optional
import
torch
import
triton
import
triton.language
as
tl
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel_ref(
    A_log,
    a,
    dt_bias,
    softplus_beta,
    softplus_threshold,
    q,
    k,
    v,
    b,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    # Parameters for target_verify support (unused for decode)
    intermediate_states_buffer,
    intermediate_state_indices,
    cache_steps,
    retrieve_parent_token_ptr,
    stride_retrieve_parent_token_seq: tl.constexpr,
    stride_retrieve_parent_token_token: tl.constexpr,
    # ================================================
    scale,
    T,
    stride_q,
    stride_k,
    stride_v,
    stride_b,
    NP2_T: tl.constexpr,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_KDA: tl.constexpr,
    # Optional flags for target_verify support (default False for decode)
    DISABLE_STATE_UPDATE: tl.constexpr = False,
    CACHE_INTERMEDIATE_STATES: tl.constexpr = False,
    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr = False,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.

    Program ids: axis 0 is the key tile (i_k), axis 1 the value tile (i_v),
    axis 2 packs sequence index and value-head index (i_nh = i_n * HV + i_hv).
    Each program walks the T timesteps of one sequence, carrying a [BK, BV]
    float32 state tile ``b_h``. Every step loads q, k, v, b and the gating
    parameters up front, then applies:

        g = -exp(A_log) * softplus(a + dt_bias)
        h = h * exp(g)
        v = (v - sum_k(h * k)) * sigmoid(b)
        h = h + outer(k, v)
        o = sum_k(h * (q * scale))

    The constexpr flags select initial-state loading / write-back,
    per-step state caching, tree-attention parent-state reloads and the
    per-key-channel (IS_KDA) gating variant.

    NOTE(review): tensor layouts are implied only by the pointer arithmetic
    below -- confirm against the caller.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # HV // H value heads share each q/k head.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        # `all` (shadows the Python builtin within this kernel) keeps the
        # incoming T for the output offset; T becomes this sequence's length.
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    p_q = q + bos * stride_q + i_h * K + o_k
    p_k = k + bos * stride_k + i_h * K + o_k
    p_v = v + bos * stride_v + i_hv * V + o_v
    p_b = b + bos * stride_b + i_hv
    # Output carries a leading key-tile axis, hence the i_k * all term.
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
    # Gating computation pointers
    p_A_log = A_log + i_hv
    if IS_KDA:
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    else:
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        # A negative index leaves the state at zero.
        if idx >= 0:
            # Tile element (k, v) lives at offset v * K + k inside the head's state.
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_v[None, :] * K
                + o_k[:, None]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
    # Preload tree attention data if needed
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        token_indices = tl.arange(0, NP2_T)
        mask_retrieve = token_indices < T
        retrieve_parent_token_base = (
            retrieve_parent_token_ptr
            + (i_n * stride_retrieve_parent_token_seq)
            + token_indices * stride_retrieve_parent_token_token
        )
        parent_idx_tokens = tl.load(retrieve_parent_token_base, mask=mask_retrieve, other=0)
    # Prepare intermediate state cache index if enabled
    cache_idx = -1
    if CACHE_INTERMEDIATE_STATES:
        cache_idx = tl.load(intermediate_state_indices + i_n)
    step_idx = 0
    for _ in range(0, T):
        # Tree attention: load parent's cached state
        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
            # step_idx == 0 uses b_h from USE_INITIAL_STATE
            if step_idx != 0 and cache_idx >= 0:
                # Select parent_idx_tokens[step_idx] via a masked sum.
                parent_step_idx = tl.sum(
                    tl.where(token_indices == step_idx, parent_idx_tokens, 0)
                )
                step_offset = parent_step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)
        # Load inputs
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)
        # Compute sigmoid gating
        # Load gating parameters
        b_A_log = tl.load(p_A_log).to(tl.float32)
        if IS_KDA:
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias, mask=mask_k, other=0).to(tl.float32)
        else:
            b_a = tl.load(p_a).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
        # Compute g = -exp(A_log) * softplus(a + dt_bias)
        x = b_a + b_dt_bias
        beta_x = softplus_beta * x
        # Apply softplus with numerical stability
        softplus_x = tl.where(
            beta_x <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
            x,
        )
        b_g = -tl.exp(b_A_log) * softplus_x
        # Compute beta = sigmoid(b)
        b_beta = 1.0 / (1.0 + tl.exp(-b_b))
        # Apply L2 normalization if enabled
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale
        # Apply gating to hidden state: h *= exp(g)
        if IS_KDA:
            b_h *= tl.exp(b_g[:, None])
        else:
            b_h *= tl.exp(b_g)
        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        # Apply beta gating: v *= beta
        b_v *= b_beta
        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Cache intermediate states if enabled
        if CACHE_INTERMEDIATE_STATES:
            if cache_idx >= 0:
                step_offset = step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)
        step_idx += 1
        # Update pointers for next timestep
        p_q += stride_q
        p_k += stride_k
        p_v += stride_v
        p_b += stride_b
        p_o += HV * V
        if IS_KDA:
            p_a += HV * K
        else:
            p_a += HV
    # Store final state back to h0_source with bounds checking
    if not DISABLE_STATE_UPDATE:
        if USE_INITIAL_STATE:
            idx = tl.load(h0_indices + i_n)
            if idx >= 0:
                p_h0 = (
                    h0_source
                    + idx * HV * K * V
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: Optional[float] = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    is_kda: bool = False,
    # Optional parameters for target_verify support
    disable_state_update: bool = False,
    intermediate_states_buffer: Optional[torch.Tensor] = None,
    intermediate_state_indices: Optional[torch.Tensor] = None,
    cache_steps: Optional[int] = None,
    retrieve_parent_token: Optional[torch.Tensor] = None,
):
    """
    Reference launcher for the fused sigmoid-gating delta-rule kernel.

    Runs ``fused_sigmoid_gating_delta_rule_update_kernel_ref`` with a fixed
    launch configuration (value tile capped at 32, one warp, three stages).

    Modes:
    - decode: single-step update with state write-back
    - target_verify: multi-step with per-step state caching, optional tree
      attention (``retrieve_parent_token``) and optional suppression of the
      state write-back (``disable_state_update``)

    ``k.shape`` is unpacked as ``(B, T, H, K)``; ``v`` supplies ``HV``
    (``v.shape[2]``) and ``V`` (``v.shape[-1]``). ``scale`` defaults to
    ``K ** -0.5``. Returns a tensor shaped like ``v`` with the dtype of ``q``.
    """
    batch, seq_len, num_qk_heads, key_dim = k.shape
    value_dim = v.shape[-1]
    num_v_heads = v.shape[2]

    if cu_seqlens is None:
        num_seqs = batch
    else:
        num_seqs = len(cu_seqlens) - 1

    block_k = triton.next_power_of_2(key_dim)
    block_v = min(triton.next_power_of_2(value_dim), 32)
    num_k_tiles = triton.cdiv(key_dim, block_k)
    num_v_tiles = triton.cdiv(value_dim, block_v)
    assert num_k_tiles == 1, "NK > 1 is not supported yet"

    if scale is not None:
        assert scale > 0, "scale must be positive"
    else:
        scale = k.shape[-1] ** -0.5

    # Leading tile axis matches the kernel's output addressing; dropped on return.
    out = q.new_empty(num_k_tiles, *v.shape)

    # Tree-attention table strides; zeros when the table is absent.
    has_parent_table = retrieve_parent_token is not None
    if has_parent_table:
        parent_stride_seq = retrieve_parent_token.stride(0)
        parent_stride_tok = retrieve_parent_token.stride(1)
    else:
        parent_stride_seq = 0
        parent_stride_tok = 0

    launch_grid = (num_k_tiles, num_v_tiles, num_seqs * num_v_heads)
    fused_sigmoid_gating_delta_rule_update_kernel_ref[launch_grid](
        # gating inputs
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        # recurrence inputs / output
        q=q,
        k=k,
        v=v,
        b=b,
        o=out,
        # state pool and sequence boundaries
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        # target_verify extras
        intermediate_states_buffer=intermediate_states_buffer,
        intermediate_state_indices=intermediate_state_indices,
        cache_steps=cache_steps if cache_steps is not None else 0,
        retrieve_parent_token_ptr=retrieve_parent_token,
        stride_retrieve_parent_token_seq=parent_stride_seq,
        stride_retrieve_parent_token_token=parent_stride_tok,
        # scalars and strides
        scale=scale,
        T=seq_len,
        stride_q=q.stride()[1],
        stride_k=k.stride()[1],
        stride_v=v.stride()[1],
        stride_b=b.stride()[-2],
        # compile-time shapes
        NP2_T=triton.next_power_of_2(seq_len),
        B=batch,
        H=num_qk_heads,
        HV=num_v_heads,
        K=key_dim,
        V=value_dim,
        BK=block_k,
        BV=block_v,
        # compile-time feature flags
        USE_INITIAL_STATE=initial_state_source is not None,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        IS_VARLEN=cu_seqlens is not None,
        IS_KDA=is_kda,
        DISABLE_STATE_UPDATE=disable_state_update,
        CACHE_INTERMEDIATE_STATES=intermediate_states_buffer is not None,
        HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=has_parent_table,
        num_warps=1,
        num_stages=3,
    )
    return out.squeeze(0)
aiter/ops/triton/fla/sglang/chunk_delta_h.py
0 → 100644
View file @
bb596f6e
# SPDX-License-Identifier: MIT
import
functools
import
json
import
os
import
torch
import
triton
import
triton.language
as
tl
import
aiter.ops.triton.utils.arch_info
as
arch_info
from
aiter
import
logger
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
# Opt-in diagnostics: set the environment variable TRITON_CONFIG_CHECK=1 to log
# config-fallback warnings and print compiled-kernel metadata in this module.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Flipped to True after the kernel metadata has been printed once, so the
# dump is not repeated on later launches.
HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA = False
@triton.jit
def safe_exp(x):
    # Exponentiate only non-positive entries: values with x > 0 are replaced
    # by -inf before the exp, so they evaluate to 0 instead of growing large.
    return exp(tl.where(x <= 0, x, float("-inf")))
@triton.jit
def exp(x):
    # Thin JIT wrapper around tl.exp (natural exponential).
    return tl.exp(x)
@triton.jit
def exp2(x):
    # Thin JIT wrapper around tl.math.exp2 (base-2 exponential).
    return tl.math.exp2(x)
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Enumerate every (sequence, chunk) pair of a variable-length batch.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries; sequence
            ``i`` spans ``cu_seqlens[i]`` to ``cu_seqlens[i + 1]``.
        chunk_size: number of tokens per chunk.

    Returns:
        A ``(total_chunks, 2)`` ``torch.long`` tensor on ``cu_seqlens.device``
        whose rows are ``[sequence_index, chunk_index_within_sequence]``, in
        sequence order. Empty sequences contribute no rows; the result has
        shape ``(0, 2)`` when there are no chunks at all.
    """
    # Fetch all boundaries in one host transfer instead of issuing a blocking
    # .item() call (one device sync) per sequence.
    boundaries = cu_seqlens.tolist()
    chunk_rows = []
    for seq_idx, (start, end) in enumerate(zip(boundaries, boundaries[1:])):
        # Ceiling division: a partially filled trailing chunk still counts.
        n_chunks = (int(end - start) + chunk_size - 1) // chunk_size
        chunk_rows.extend([seq_idx, chunk_idx] for chunk_idx in range(n_chunks))
    if not chunk_rows:
        # torch.tensor([]) would be 1-D; keep the documented (0, 2) shape.
        return torch.empty((0, 2), dtype=torch.long, device=cu_seqlens.device)
    return torch.tensor(chunk_rows, dtype=torch.long, device=cu_seqlens.device)
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Return, for each sequence, the index of its first chunk in the batch.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries.
        chunk_size: number of tokens per chunk.

    Returns:
        A tensor with one entry per sequence holding the exclusive running
        total of per-sequence chunk counts (entry 0 is always 0).
    """
    starts, ends = cu_seqlens[:-1], cu_seqlens[1:]
    # Ceiling division of each sequence length by the chunk size.
    chunks_per_seq = (ends - starts + chunk_size - 1) // chunk_size
    first_chunk = torch.zeros_like(chunks_per_seq)
    if len(first_chunk) <= 1:
        return first_chunk
    # Shift the inclusive running total right by one to make it exclusive.
    running_total = torch.cumsum(chunks_per_seq, dim=0)
    first_chunk[1:] = running_total[:-1]
    return first_chunk
# Fallback launch parameters for chunk_gated_delta_rule_fwd_h: value-tile
# width ("BV") plus Triton's num_warps / num_stages. Used when no tuned config
# file or matching entry exists, and as the base that tuned entries override.
_DEFAULT_CHUNK_DELTA_H_CONFIG = {
    "BV": 32,
    "num_warps": 8,
    "num_stages": 2,
}
@functools.lru_cache(maxsize=1)
def _load_chunk_delta_h_configs() -> dict:
    """Read the tuned chunk_gated_delta_rule_fwd_h config table for this device.

    The table is expected in
    ``<AITER_TRITON_CONFIGS_PATH>/chunk_gated_delta_rule_fwd_h/chunk_gated_delta_rule_fwd_h-<arch>.json``
    under the top-level ``"config"`` key. Returns an empty dict (after logging
    a warning) when the file is missing, and an empty dict when the JSON
    payload is not an object. The result is cached for the process lifetime.
    """
    arch = arch_info.get_arch()
    config_dir = os.path.join(
        AITER_TRITON_CONFIGS_PATH, "chunk_gated_delta_rule_fwd_h"
    )
    config_file = os.path.join(
        config_dir, f"chunk_gated_delta_rule_fwd_h-{arch}.json"
    )

    if os.path.exists(config_file):
        with open(config_file) as handle:
            data = json.load(handle)
        if isinstance(data, dict):
            return data.get("config", {})
        return {}

    logger.warning(
        f"chunk_gated_delta_rule_fwd_h config not found at {config_file}, using default {_DEFAULT_CHUNK_DELTA_H_CONFIG}."
    )
    return {}
@functools.lru_cache
def _get_chunk_delta_h_config(K: int, V: int, BT: int, H: int) -> dict:
    """Resolve launch parameters for one (K, V, BT, H) problem shape.

    Looks up the entry keyed ``"K=..,V=..,BT=..,H=.."`` in the tuned table,
    falling back to the table's ``"default"`` entry and finally to
    ``_DEFAULT_CHUNK_DELTA_H_CONFIG``. The chosen entry is layered on top of
    the built-in defaults, so missing fields are filled in. Results are
    memoised per shape; the returned dict is the cached object itself.
    """
    table = _load_chunk_delta_h_configs()
    key = f"K={K},V={V},BT={BT},H={H}"

    chosen = table.get(key)
    if chosen is None:
        chosen = table.get("default", _DEFAULT_CHUNK_DELTA_H_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                "chunk_gated_delta_rule_fwd_h config missing for "
                f"{key}, using default config {chosen}."
            )

    resolved = dict(_DEFAULT_CHUNK_DELTA_H_CONFIG)
    resolved.update(chosen)
    return resolved
def launch_chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    *,
    k: torch.Tensor,
    u: torch.Tensor,
    w: torch.Tensor,
    v_new: torch.Tensor | None,
    g: torch.Tensor | None,
    gk: torch.Tensor | None,
    h: torch.Tensor,
    initial_state: torch.Tensor | None,
    initial_state_indices: torch.Tensor | None,
    cu_seqlens: torch.LongTensor | None,
    chunk_offsets: torch.LongTensor | None,
    N: int,
    T: int,
    H: int,
    Hg: int,
    K: int,
    V: int,
    BT: int,
    kernel_cfg: dict | None,
):
    """Launch ``chunk_gated_delta_rule_fwd_kernel_h_blockdim64``.

    ``kernel_cfg`` supplies ``"BV"``, ``"num_warps"`` and ``"num_stages"``;
    when it is ``None`` the tuned/default config for ``(K, V, BT, H)`` is used.
    The grid is ``(cdiv(V, BV), N * H)`` and the kernel always runs with
    ``INPLACE_UPDATE=True``. ``u`` is passed to the kernel as its ``v``
    argument.

    When ``TRITON_CONFIG_CHECK`` is set, launch/compile metadata is printed
    once per process. Returns ``None``.
    """
    global HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA

    if kernel_cfg is None:
        cfg = _get_chunk_delta_h_config(K, V, BT, H)
    else:
        cfg = kernel_cfg
    block_v = cfg["BV"]
    warps = cfg["num_warps"]
    stages = cfg["num_stages"]

    # Concrete grid, kept only for the diagnostic dump below; the launch itself
    # goes through the meta-parameter callable.
    resolved_grid = (triton.cdiv(V, block_v), N * H)
    handle = chunk_gated_delta_rule_fwd_kernel_h_blockdim64[
        lambda meta: (triton.cdiv(V, meta["BV"]), N * H)
    ](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        initial_state=initial_state,
        initial_state_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=block_v,
        INPLACE_UPDATE=True,
        num_warps=warps,
        num_stages=stages,
    )

    if not TRITON_CONFIG_CHECK:
        return
    if HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA or handle is None:
        return

    print("chunk_gated_delta_rule_fwd_kernel_h_blockdim64 metadata")
    print(f" grid: {resolved_grid}")
    print(
        f" meta: BT={BT}, BV={block_v}, K={K}, V={V}, H={H}, Hg={Hg}, N={N}, T={T}, "
        f"num_warps={warps}, num_stages={stages}"
    )
    print(f" registers: {handle.n_regs}")
    print(f" spills: {handle.n_spills}")
    print(f" shared memory: {handle.metadata.shared} bytes")
    HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA = True
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["initial_state"] is not None,
        # "USE_INITIAL_STATE_INDICES": lambda args: args["initial_state_indices"] is not None,
        # "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    initial_state,
    initial_state_indices,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_UPDATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Chunk-wise forward pass producing the per-chunk hidden states of the
    # gated delta rule.
    #
    # Program ids: axis 0 selects the value tile (i_v), axis 1 packs sequence
    # and head index (i_nh = i_n * H + i_h). The key dimension is processed in
    # up to four fixed 64-wide slices (b_h1..b_h4, enabled by K > 64/128/192),
    # each a [BV, 64] float32 accumulator.
    #
    # For every chunk i_t the kernel:
    #   1. stores the state at the start of the chunk into h[i_t];
    #   2. computes v_new = v - w @ h^T (optionally saved to `v_new`);
    #   3. applies the scalar gate g and/or per-key gate gk taken at the last
    #      valid position of the chunk;
    #   4. accumulates h += (k @ v_new)^T.
    # With INPLACE_UPDATE the final state is written back into the
    # `initial_state` slot selected by `initial_state_indices`.
    #
    # NOTE(review): `initial_state_indices` is loaded and `initial_state` is
    # offset unconditionally below, even when USE_INITIAL_STATE is False --
    # presumably both tensors must always be provided; confirm with callers.
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        # First chunk slot of this sequence within the packed chunk axis of h.
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
    # [BV, BK]
    b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
    # calculate offset
    h += ((boh * H + i_h) * V * K).to(tl.int64)
    v += ((bos * H + i_h) * V).to(tl.int64)
    # k has Hg heads; H // Hg value heads share each key head.
    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
    w += ((bos * H + i_h) * K).to(tl.int64)
    if SAVE_NEW_VALUE:
        v_new += ((bos * H + i_h) * V).to(tl.int64)
    stride_v = H * V
    stride_h = H * V * K
    stride_k = Hg * K
    stride_w = H * K
    # Slot of this sequence in the state pool; read regardless of
    # USE_INITIAL_STATE (see the review note at the top of the kernel).
    index = tl.load(initial_state_indices + i_n).to(tl.int32)
    h0 = initial_state + index * stride_h
    ht = initial_state + index * stride_h
    if USE_INITIAL_STATE:
        h0 = h0 + i_h * V * K
    if INPLACE_UPDATE:
        ht = ht + i_h * V * K
    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
    # main recurrence
    for i_t in range(NT):
        # Step 1: snapshot the state entering chunk i_t.
        p_h1 = tl.make_block_ptr(
            h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
        )
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
        # Step 2: v_new = v - w @ h^T, accumulated over the key slices.
        p_w1 = tl.make_block_ptr(
            w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
        )
        b_w1 = tl.load(p_w1, boundary_check=(0, 1))
        if K > 64:
            p_w2 = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
            )
            b_w2 = tl.load(p_w2, boundary_check=(0, 1))
        if K > 128:
            p_w3 = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
            )
            b_w3 = tl.load(p_w3, boundary_check=(0, 1))
        if K > 192:
            p_w4 = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
            )
            b_w4 = tl.load(p_w4, boundary_check=(0, 1))
        b_v = tl.dot(b_w1, tl.trans(b_h1).to(b_w1.dtype))
        if K > 64:
            b_v += tl.dot(b_w2, tl.trans(b_h2).to(b_w2.dtype))
        if K > 128:
            b_v += tl.dot(b_w3, tl.trans(b_h3).to(b_w3.dtype))
        if K > 192:
            b_v += tl.dot(b_w4, tl.trans(b_h4).to(b_w4.dtype))
        p_v = tl.make_block_ptr(
            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
        if SAVE_NEW_VALUE:
            p_v = tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
        # Last valid token of this chunk (the final chunk may be partial).
        last_idx = min((i_t + 1) * BT, T) - 1
        # Step 3a: scalar gate per (token, head).
        if USE_G:
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(
                g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last
        # Step 3b: per-key-channel gate read at the chunk's last position.
        if USE_GK:
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(
                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
                mask=(o_k1 < K),
                other=0.0,
            )
            b_h1 *= exp(b_gk_last1)[None, :]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
                    mask=(o_k2 < K),
                    other=0.0,
                )
                b_h2 *= exp(b_gk_last2)[None, :]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
                    mask=(o_k3 < K),
                    other=0.0,
                )
                b_h3 *= exp(b_gk_last3)[None, :]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
                    mask=(o_k4 < K),
                    other=0.0,
                )
                b_h4 *= exp(b_gk_last4)[None, :]
        # Step 4: h += (k @ v_new)^T; cast v_new to k's dtype for the dot.
        b_v = b_v.to(k.dtype.element_ty)
        p_k1 = tl.make_block_ptr(
            k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
        )
        b_k1 = tl.load(p_k1, boundary_check=(0, 1))
        if K > 64:
            p_k2 = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
            )
            b_k2 = tl.load(p_k2, boundary_check=(0, 1))
        if K > 128:
            p_k3 = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
            )
            b_k3 = tl.load(p_k3, boundary_check=(0, 1))
        if K > 192:
            p_k4 = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
            )
            b_k4 = tl.load(p_k4, boundary_check=(0, 1))
        b_h1 += tl.trans(tl.dot(b_k1, b_v))
        if K > 64:
            b_h2 += tl.trans(tl.dot(b_k2, b_v))
        if K > 128:
            b_h3 += tl.trans(tl.dot(b_k3, b_v))
        if K > 192:
            b_h4 += tl.trans(tl.dot(b_k4, b_v))
    # epilogue
    # Write the final state back into the selected initial_state slot.
    if INPLACE_UPDATE:
        p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    initial_state_indices: torch.Tensor | None = None,
    output_final_state: bool = True,
    chunk_size: int = 64,
    save_new_value: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = False,
    transpose_state_layout: bool = True,
    kernel_cfg: dict | None = None,
):
    """Compute per-chunk hidden states for the chunked gated delta rule.

    Args:
        k: 4-D key tensor; its shape is unpacked as ``(B, T, Hg, K)``.
        w: tensor forwarded to the kernel as ``w``.
        u: value-like tensor; ``u.shape[-2]`` is ``H`` and ``u.shape[-1]`` is
            ``V``.
        g: optional scalar gate (enables ``USE_G`` in the kernel).
        gk: optional per-key gate (enables ``USE_GK`` in the kernel).
        initial_state: state pool read at start and updated in place by the
            kernel (it is launched with ``INPLACE_UPDATE=True``).
        initial_state_indices: per-sequence slot index into ``initial_state``.
            NOTE(review): the kernel loads this unconditionally, so it
            presumably must always be provided -- confirm with callers.
        output_final_state: accepted for API compatibility; not used here.
        chunk_size: tokens per chunk (``BT``).
        save_new_value: allocate and return ``v_new``.
        cu_seqlens: cumulative sequence boundaries for variable-length input.
        chunk_indices: optional precomputed result of
            ``prepare_chunk_indices(cu_seqlens, chunk_size)``; computed on
            demand when omitted. Only consulted when ``cu_seqlens`` is given.
        use_exp2: accepted for API compatibility; not used here.
        transpose_state_layout: accepted for API compatibility; not used here.
        kernel_cfg: optional launch config overriding the tuned table.

    Returns:
        ``(h, v_new)`` where ``h`` has shape ``(B, NT, H, V, K)`` and
        ``v_new`` is shaped like ``u`` (or ``None`` when
        ``save_new_value`` is False).
    """
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size

    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        # Reuse caller-supplied chunk indices; rebuilding them needs a
        # host-side pass over every sequence.
        if chunk_indices is None:
            chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        N = len(cu_seqlens) - 1
        NT = len(chunk_indices)
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)

    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, V, K)
    v_new = torch.empty_like(u) if save_new_value else None

    launch_chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
        k=k,
        u=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        initial_state=initial_state,
        initial_state_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        N=N,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        kernel_cfg=kernel_cfg,
    )
    return h, v_new
aiter/ops/triton/fla/sglang/chunk_o.py
0 → 100644
View file @
bb596f6e
# SPDX-License-Identifier: MIT
import
functools
import
json
import
os
import
torch
import
triton
import
triton.language
as
tl
import
aiter.ops.triton.utils.arch_info
as
arch_info
from
aiter
import
logger
from
aiter.ops.triton.utils.core
import
AITER_TRITON_CONFIGS_PATH
# Opt-in diagnostics: set the environment variable TRITON_CONFIG_CHECK=1 to log
# a warning whenever a tuned chunk_fwd_o config entry is missing.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Presumably flipped to True once the kernel metadata has been printed, so the
# dump happens only once -- the code that sets it is not shown here; confirm.
HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA = False
@triton.jit
def safe_exp(x):
    """Return ``exp(x)`` where ``x <= 0`` and ``exp(-inf)`` everywhere else.

    Positive inputs are replaced by ``-inf`` before exponentiation, so they
    never reach ``exp`` with a positive argument.
    """
    non_positive = tl.where(x <= 0, x, float("-inf"))
    return exp(non_positive)
@triton.jit
def exp(x):
    """Jitted wrapper that forwards to ``tl.exp``."""
    result = tl.exp(x)
    return result
@triton.jit
def exp2(x):
    """Jitted wrapper that forwards to ``tl.math.exp2``."""
    result = tl.math.exp2(x)
    return result
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Enumerate every chunk of every sequence described by ``cu_seqlens``.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries; sequence
            ``i`` spans ``cu_seqlens[i]`` to ``cu_seqlens[i + 1]``.
        chunk_size: Number of positions per chunk.

    Returns:
        A ``torch.long`` tensor of shape ``(num_chunks, 2)`` on
        ``cu_seqlens.device``. Each row is ``[sequence index, chunk index
        within that sequence]``, ordered by sequence and then by chunk.
        Zero-length sequences contribute no rows; if there are no chunks at
        all the result has shape ``(0, 2)``.
    """
    # Pull all boundaries to the host in one transfer. Calling `.item()` per
    # sequence would force one device synchronisation per sequence.
    bounds = cu_seqlens.tolist()

    chunk_rows = []
    for seq_idx, (bos, eos) in enumerate(zip(bounds, bounds[1:])):
        # Ceiling division: number of chunks needed to cover the sequence.
        n_chunks = -(-int(eos - bos) // chunk_size)
        chunk_rows.extend([seq_idx, chunk_idx] for chunk_idx in range(n_chunks))

    if not chunk_rows:
        # torch.tensor([]) would be 1-D; keep the (N, 2) contract explicitly.
        return torch.empty((0, 2), dtype=torch.long, device=cu_seqlens.device)
    return torch.tensor(chunk_rows, dtype=torch.long, device=cu_seqlens.device)
# Built-in launch parameters for `chunk_fwd_kernel_o`. They are used when no
# tuned entry is available and also serve as the base that tuned entries are
# merged over in `_get_chunk_o_config`.
#   BK / BV               -> passed to the kernel as its BK / BV constexprs
#   num_warps / num_stages -> passed to the Triton launch
_DEFAULT_CHUNK_O_CONFIG = {
    "BK": 128,
    "BV": 64,
    "num_warps": 4,
    "num_stages": 2,
}
@functools.lru_cache(maxsize=1)
def _load_chunk_o_configs() -> dict:
    """Load the tuned ``chunk_fwd_o`` configs for the current architecture.

    The file is looked up at
    ``<AITER_TRITON_CONFIGS_PATH>/chunk_fwd_o/chunk_fwd_o-<arch>.json`` where
    ``<arch>`` is the value returned by ``arch_info.get_arch()``. The result
    is cached, so the file is read at most once per process.

    Returns:
        The mapping stored under the top-level ``"config"`` key of the JSON
        file. An empty dict is returned when the file does not exist, when
        the payload is not a JSON object, or when ``"config"`` is missing or
        is not itself an object.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "chunk_fwd_o",
        f"chunk_fwd_o-{device_name}.json",
    )
    if not os.path.exists(path):
        logger.warning(
            f"chunk_fwd_o config not found at {path}, using default {_DEFAULT_CHUNK_O_CONFIG}."
        )
        return {}

    # Pin the encoding so the read does not depend on the process locale.
    with open(path, encoding="utf-8") as f:
        payload = json.load(f)

    configs = payload.get("config", {}) if isinstance(payload, dict) else {}
    # Callers use `.get(...)` on the result, so never hand back a non-dict
    # (e.g. `"config": null` or a list in a malformed file).
    return configs if isinstance(configs, dict) else {}
@functools.lru_cache
def _get_chunk_o_config(K: int, V: int, BT: int) -> dict:
    """Resolve the launch config for one ``(K, V, BT)`` problem shape.

    Lookup order in the table from ``_load_chunk_o_configs``:
      1. the exact ``"K=<K>,V=<V>,BT=<BT>"`` entry,
      2. the table's ``"default"`` entry,
      3. ``_DEFAULT_CHUNK_O_CONFIG``.

    Whatever is selected is layered on top of a copy of
    ``_DEFAULT_CHUNK_O_CONFIG``, so keys absent from the selected entry fall
    back to the built-in values. Results are memoised per argument tuple.
    """
    table = _load_chunk_o_configs()
    lookup_key = f"K={K},V={V},BT={BT}"

    selected = table.get(lookup_key)
    if selected is None:
        fallback = table.get("default", _DEFAULT_CHUNK_O_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                "chunk_fwd_o config missing for "
                f"{lookup_key}, using default config {fallback}."
            )
        selected = fallback

    resolved = _DEFAULT_CHUNK_O_CONFIG.copy()
    resolved.update(selected)
    return resolved
def launch_chunk_fwd_kernel_o(
    *,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None,
    o: torch.Tensor,
    cu_seqlens: torch.LongTensor | None,
    chunk_indices: torch.LongTensor | None,
    scale: float,
    T: int,
    H: int,
    Hg: int,
    K: int,
    V: int,
    BT: int,
    NT: int,
    B: int,
    kernel_cfg: dict | None,
):
    """Launch ``chunk_fwd_kernel_o`` and optionally dump its compile metadata.

    All arguments are keyword-only. The tensors and the scalars ``scale``,
    ``T``, ``H``, ``Hg``, ``K``, ``V`` and ``BT`` are forwarded to the kernel
    unchanged; ``NT`` and ``B`` only determine the launch grid
    ``(cdiv(V, BV), NT, B * H)``.

    Args:
        kernel_cfg: Dict providing ``"BK"``, ``"BV"``, ``"num_warps"`` and
            ``"num_stages"``. When ``None`` the config is resolved with
            ``_get_chunk_o_config(K, V, BT)``.

    Side effects:
        When ``TRITON_CONFIG_CHECK`` is set, the first launch that yields a
        compiled-kernel handle prints the grid, meta-parameters, register
        count, spill count and shared-memory size, then latches
        ``HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA`` so it is not repeated.
    """
    global HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA

    if kernel_cfg is None:
        cfg = _get_chunk_o_config(K, V, BT)
    else:
        cfg = kernel_cfg

    # BV is passed explicitly below, so the grid can be computed up front
    # instead of through a meta-parameter callback.
    launch_grid = (triton.cdiv(V, cfg["BV"]), NT, B * H)

    compiled_kernel = chunk_fwd_kernel_o[launch_grid](
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=cfg["BK"],
        BV=cfg["BV"],
        num_warps=cfg["num_warps"],
        num_stages=cfg["num_stages"],
    )

    # One-shot diagnostics: skip unless enabled, not yet dumped, and a
    # compiled-kernel handle is available.
    if (
        not TRITON_CONFIG_CHECK
        or HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA
        or compiled_kernel is None
    ):
        return

    print("chunk_fwd_kernel_o metadata")
    print(f" grid: {launch_grid}")
    print(
        f" meta: BT={BT}, BK={cfg['BK']}, BV={cfg['BV']}, K={K}, V={V}, H={H}, Hg={Hg}, "
        f"NT={NT}, B={B}, T={T}, num_warps={cfg['num_warps']}, num_stages={cfg['num_stages']}"
    )
    print(f" registers: {compiled_kernel.n_regs}")
    print(f" spills: {compiled_kernel.n_spills}")
    print(f" shared memory: {compiled_kernel.metadata.shared} bytes")
    HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA = True
# USE_G and IS_VARLEN are derived from the runtime arguments: they are true
# when `g` / `cu_seqlens` are not None. `T` is excluded from Triton's
# argument specialization via `do_not_specialize`.
@triton.heuristics({
    "USE_G": lambda args: args["g"] is not None,
    "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Compute one (BT, BV) output tile of the chunked forward pass.

    Each program handles one V tile (axis 0), one chunk (axis 1) and one
    fused batch*head index (axis 2). It accumulates ``q @ h^T`` and the
    lower-triangular part of ``q @ k`` applied to ``v``, optionally rescales
    both with the gate ``g``, multiplies by ``scale`` and stores to ``o``.
    """
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    # Split the fused axis-2 index into batch and head.
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        # In this branch axis 1 is a global chunk index; i_b is not used for
        # addressing. Keep the global index for locating the state in `h`.
        i_tg = i_t
        # chunk_indices holds (sequence index, chunk index in sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        # Begin / end offsets of this sequence.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        # From here on T is the length of this sequence only.
        T = eos - bos
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Advance the base pointers to this sequence and head.
    # q/k have Hg heads: head i_h reads q/k head i_h // (H // Hg).
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    # One V*K state block per (global chunk, head); offset computed in int64.
    h += (i_tg * H + i_h).to(tl.int64) * V * K

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    # Accumulate over the K dimension in tiles of BK.
    for i_k in range(tl.cdiv(K, BK)):
        # (BT, BK) tile of q; rows are Hg * K elements apart.
        p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
        # (BK, BT) tile of k, addressed as a (K, T) view.
        p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
        # (BV, BK) tile of the state, addressed as a (V, K) row-major block.
        p_h = tl.make_block_ptr(h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0))
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # (BT, BK) @ (BK, BV) -> (BT, BV)
        b_o += tl.dot(b_q, tl.trans(b_h))
        # (BT, BK) @ (BK, BT) -> (BT, BT)
        b_A += tl.dot(b_q, b_k)

    if USE_G:
        # g is read with stride H, one value per position for head i_h.
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        # Scale each output row by exp(g) of that row.
        b_o = b_o * exp(b_g)[:, None]
        # Scale entry (i, j) by exp(g_i - g_j); safe_exp yields exp(-inf)
        # wherever g_i - g_j > 0.
        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])

    # Keep only entries with row index >= column index (lower triangle,
    # diagonal included).
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)

    p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # Both terms are multiplied by `scale`; b_A is cast to v's dtype for the
    # dot product.
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = False,
    transpose_state_layout: bool = False,
    kernel_cfg: dict | None = None,
) -> torch.Tensor:
    """Host-side entry point for the chunked forward output computation.

    Args:
        q: 4-D tensor unpacked as ``(B, T, Hg, K)``.
        k: Tensor forwarded to the kernel; its last dimension sets the
            default ``scale``.
        v: Tensor whose last two dimensions are read as ``(H, V)``; the
            output has the same shape and dtype.
        h: Per-chunk state tensor forwarded to the kernel.
        g: Optional gate tensor forwarded to the kernel.
        scale: Output scale; defaults to ``k.shape[-1] ** -0.5``.
        cu_seqlens: Optional cumulative sequence boundaries; when given the
            launch runs in variable-length mode.
        chunk_size: Chunk length, passed to the kernel as ``BT``.
        chunk_indices: Optional precomputed chunk table; built with
            ``prepare_chunk_indices`` when omitted and ``cu_seqlens`` is set.
        kernel_cfg: Optional launch config forwarded to
            ``launch_chunk_fwd_kernel_o``.

    Returns:
        A new tensor allocated with ``torch.empty_like(v)`` and filled by the
        kernel.

    NOTE(review): ``g_gamma``, ``use_exp2`` and ``transpose_state_layout``
    are accepted but never read in this function — confirm callers do not
    rely on them having an effect.
    """
    B, T, Hg, K = q.shape
    H, V = v.shape[-2], v.shape[-1]
    BT = chunk_size

    is_varlen = cu_seqlens is not None
    if is_varlen and chunk_indices is None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)

    # Number of chunk programs along grid axis 1.
    if is_varlen:
        NT = len(chunk_indices)
    else:
        NT = triton.cdiv(T, BT)

    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)
    launch_chunk_fwd_kernel_o(
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        NT=NT,
        B=B,
        kernel_cfg=kernel_cfg,
    )
    return o
Prev
1
…
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment