Commit bb596f6e authored by xiaowei.zhang's avatar xiaowei.zhang
Browse files

1. Update MOE; 2. Update sglang mHC; 3. Update test scripts; 4 Add new

   ops.
parent d9ebb683
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 2,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
......@@ -35,8 +44,11 @@
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
......@@ -45,89 +57,116 @@
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 2,
"num_stages": 1
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
......@@ -135,28 +174,50 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
"num_stages": 2
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 2,
"num_stages": 2
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 4,
"num_stages": 2
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 4,
"num_stages": 2
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 16,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 16,
"GROUP_SIZE_M": 4,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
}
}
\ No newline at end of file
......@@ -2,151 +2,222 @@
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 16,
"num_stages": 2
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 16,
"num_stages": 2
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 16,
"num_stages": 2
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"instruction_sched_variant": "none",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"instruction_sched_variant": "none",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"instruction_sched_variant": "none",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
......@@ -143,20 +185,39 @@
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"num_stages": 2
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": true,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 4,
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 4,
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"8": {
......@@ -35,78 +44,102 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 8,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"num_stages": 2
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 8,
"num_stages": 2
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 2,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 8,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 4,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
......@@ -115,18 +148,24 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"num_warps": 2,
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
......@@ -135,7 +174,10 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
......@@ -145,18 +187,37 @@
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"num_warps": 2,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 512,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": true,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"32768": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"8192": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"16384": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
},
"32768": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "mmac5-ds10",
"kpack": 1,
"num_warps": 16,
"num_stages": 2
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 8,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 8,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 1,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 8,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 8,
"num_stages": 1
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "local-prefetch",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 2,
"num_stages": 1
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
},
"32768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"COMBINE_SCALE_LOAD": false,
"USE_MLS_LOAD": false,
"instruction_sched_variant": "none",
"sched_latency": "none",
"kpack": 2,
"num_warps": 4,
"num_stages": 1
}
}
\ No newline at end of file
......@@ -38,6 +38,8 @@ from aiter.ops.triton.activation import _tanh
import aiter.ops.triton.utils.arch_info as arch_info
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
from triton import __version__ as triton_version
# Minor component (second dot-separated field) of the imported Triton version
# string, parsed to an int once at import time.
triton_minor_version = int(triton_version.split(".", 2)[1])
@triton.jit
def _fwd_kernel(
......@@ -348,6 +350,10 @@ def _fwd_kernel_v2(
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
HAS_SINK: tl.constexpr,
head_num: tl.constexpr,
USE_MLS: tl.constexpr,
batch_size: tl.constexpr,
# max_len_extend: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -357,6 +363,24 @@ def _fwd_kernel_v2(
tl.assume(K_Extend.to(tl.int64) >= 0)
tl.assume(V_Extend.to(tl.int64) >= 0)
tl.assume(kv_group_num >= 0)
tl.assume(stride_qbs >= 0)
tl.assume(stride_qh >= 0)
tl.assume(stride_kbs >= 0)
tl.assume(stride_kh >= 0)
tl.assume(stride_vbs >= 0)
tl.assume(stride_vh >= 0)
tl.assume(stride_obs >= 0)
tl.assume(stride_oh >= 0)
tl.assume(stride_buf_kbs >= 0)
tl.assume(stride_buf_kh >= 0)
tl.assume(stride_buf_vbs >= 0)
tl.assume(stride_buf_vh >= 0)
tl.assume(head_num >= 0)
tl.assume(batch_size >= 0)
# tl.assume(max_len_extend >= 0)
kv_head_num = head_num // kv_group_num
cur_kv_head = cur_head // kv_group_num
cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
......@@ -380,6 +404,10 @@ def _fwd_kernel_v2(
mask_d = offs_d < Lq
mask_dv = offs_dv < Lv
ALL_MASK_M = tl.min(mask_m.to(tl.int32), axis=0) == 1
ALL_MASK_D = tl.min(mask_d.to(tl.int32), axis=0) == 1
ALL_MASK_DV = tl.min(mask_dv.to(tl.int32), axis=0) == 1
if xai_temperature_len > 0:
offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
......@@ -389,10 +417,534 @@ def _fwd_kernel_v2(
1.0,
)
offs_q = (
if USE_MLS:
q = tl.matrix_load(
Q_Extend + cur_head * stride_qh,
shape=(head_num, Lq),
strides=(stride_qbs, 1),
block_shape=(BLOCK_M, BLOCK_DMODEL),
offsets=((cur_seq_extend_start_idx + cur_block_m * BLOCK_M).to(tl.int32), 0),
)
if not (ALL_MASK_M & ALL_MASK_D):
q = tl.where((mask_m[:, None]) & (mask_d[None, :]), q, 0.0)
else:
offs_q = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_d[None, :]
)
q = tl.load(
Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
if USE_MLS:
qpe = tl.matrix_load(Q_Extend + cur_head * stride_qh,
shape=(head_num, Lq),
strides=(stride_qbs, 1),
block_shape=(BLOCK_M, BLOCK_DPE),
offsets=((cur_seq_extend_start_idx + cur_block_m * BLOCK_M).to(tl.int32),
BLOCK_DMODEL),
)
if not ALL_MASK_M:
qpe = tl.where(mask_m[:, None], qpe, 0.0)
else:
offs_qpe = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_dpe[None, :]
)
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
ALL_MASK_N = tl.min(mask_n.to(tl.int32), axis=0) == 1
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
if USE_MLS:
custom_mask = tl.matrix_load(
mask_ptr + cur_seq_mask_start_idx,
shape=(cur_seq_len_extend, cur_seq_len + window_kv_offset),
strides=(cur_seq_len + window_kv_offset, 1),
block_shape=(BLOCK_M, BLOCK_N),
offsets=((cur_block_m * BLOCK_M).to(tl.int32),
(window_kv_offset + start_n).to(tl.int32)),
)
if not (ALL_MASK_M & ALL_MASK_N):
custom_mask = tl.where((mask_m[:, None]) & (mask_n[None, :]), custom_mask, 0)
else:
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None])
* (cur_seq_len + window_kv_offset)
+ window_kv_offset
+ start_n
+ offs_n[None, :],
mask=(mask_m[:, None] & mask_n[None, :]),
other=0,
)
final_mask &= custom_mask
if SLIDING_WINDOW_SIZE > 0:
window_mask = (
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
final_mask &= window_mask
SKIP_TILE = False
if (USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK) or SLIDING_WINDOW_SIZE > 0:
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
if not SKIP_TILE:
offs_kv_loc = tl.load(
kv_indices + cur_seq_kv_start_idx + start_n + offs_n,
mask=mask_n,
other=0,
)
# offs_kv_next = offs_kv_loc[1:] # [1, 2, ..., N-1]
# offs_kv_curr = offs_kv_loc[:-1] # [0, 1, ..., N-2]
# diff = offs_kv_next - offs_kv_curr # length N-1
# is_continuous = tl.all(diff == 1)
# if USE_MLS and is_continuous:
# offs_kv_start_idx = tl.load(
# kv_indices + cur_seq_kv_start_idx + start_n,
# )
# k = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DMODEL, BLOCK_N),
# offsets=(0, offs_kv_start_idx.to(tl.int32)),
# # mask=(mask_m[:, None] & mask_n[None, :]),
# # boundary_check=(0, 1),
# )
# if not (ALL_MASK_N & ALL_MASK_D):
# k = tl.where((mask_n[None, :]) & (mask_d[:, None]), k, 0.0)
# else:
# offs_buf_k = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_d[:, None]
# )
# k = tl.load(
# K_Buffer + offs_buf_k,
# mask=(mask_n[None, :]) & (mask_d[:, None]),
# other=0.0,
# )
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_d[:, None]
)
k = tl.load(
K_Buffer + offs_buf_k,
mask=(mask_n[None, :]) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
# if USE_MLS:
# kpe = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DPE, BLOCK_N),
# offsets=(BLOCK_DMODEL, offs_kv_start_idx.to(tl.int32)),
# # mask=(mask_m[:, None] & mask_n[None, :]),
# # boundary_check=(0, 1),
# )
# if not ALL_MASK_N:
# kpe = tl.where((mask_n[None, :]), kpe, 0.0)
# else:
# offs_kpe = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_dpe[:, None]
# )
# kpe = tl.load(
# K_Buffer + offs_kpe,
# mask=mask_n[None, :],
# other=0.0,
# )
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale * k_scale
if logit_cap > 0:
qk = logit_cap * _tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
# if USE_MLS:
# v = tl.matrix_load(
# V_Buffer + cur_kv_head * stride_buf_vh,
# shape=(cur_seq_len_prefix.to(tl.int32), Lv),
# strides=(stride_buf_kbs, 1),
# block_shape=(BLOCK_N, BLOCK_DV),
# offsets=(offs_kv_start_idx.to(tl.int32), 0),
# # mask=(mask_m[:, None] & mask_n[None, :]),
# # boundary_check=(0, 1),
# )
# if not (ALL_MASK_N & ALL_MASK_DV):
# v = tl.where((mask_n[:, None] & mask_dv[None, :]), v, 0.0)
# else:
# offs_buf_v = (
# offs_kv_loc[:, None] * stride_buf_vbs
# + cur_kv_head * stride_buf_vh
# + offs_dv[None, :]
# )
# v = tl.load(
# V_Buffer + offs_buf_v,
# mask=mask_n[:, None] & mask_dv[None, :],
# other=0.0,
# )
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=mask_n[:, None] & mask_dv[None, :],
other=0.0,
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v) * v_scale
e_max = n_e_max
cur_block_m_end = (
cur_seq_len_extend
if not IS_CAUSAL
else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
ALL_MASK_N = tl.min(mask_n.to(tl.int32), axis=0) == 1
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK:
if USE_MLS:
custom_mask = tl.matrix_load(
mask_ptr + cur_seq_mask_start_idx,
shape=(cur_block_m_end, cur_seq_len + window_kv_offset),
strides=(cur_seq_len + window_kv_offset, 1),
block_shape=(BLOCK_M, BLOCK_N),
offsets=((cur_block_m * BLOCK_M).to(tl.int32),
(window_kv_offset + cur_seq_len_prefix + start_n).to(tl.int32)),
)
if not (ALL_MASK_M & ALL_MASK_N):
custom_mask = tl.where((mask_m[:, None]) & (mask_n[None, :]), custom_mask, 0)
else:
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None])
* (cur_seq_len + window_kv_offset)
+ window_kv_offset
+ cur_seq_len_prefix
+ start_n
+ offs_n[None, :],
mask=(mask_m[:, None] & mask_n[None, :]),
other=0,
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
final_mask &= custom_mask
elif IS_CAUSAL:
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual &= mask_m[:, None] & mask_n[None, :]
final_mask &= mask_causual
else:
mask_non_causal = mask_m[:, None] & mask_n[None, :]
final_mask &= mask_non_causal
if SLIDING_WINDOW_SIZE > 0:
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
)
final_mask &= window_mask
SKIP_TILE = False
if USE_CUSTOM_MASK or SLIDING_WINDOW_SIZE > 0:
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
if not SKIP_TILE:
if USE_MLS:
k = tl.matrix_load(
K_Extend + cur_kv_head * stride_kh,
shape=(kv_head_num, Lq),
strides=(1, stride_kbs),
block_shape=(BLOCK_DMODEL, BLOCK_N),
offsets=(0, (cur_seq_extend_start_idx + start_n).to(tl.int32)),
)
if not (ALL_MASK_N & ALL_MASK_D):
k = tl.where((mask_d[:, None]) & (mask_n[None, :]), k, 0.0)
else:
offs_k = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q.to(k.dtype), k, out_dtype=tl.float32)
if BLOCK_DPE > 0:
if USE_MLS:
kpe = tl.matrix_load(
K_Extend + cur_kv_head * stride_kh,
shape=(kv_head_num, Lq),
strides=(1, stride_kbs),
block_shape=(BLOCK_DPE, BLOCK_N),
offsets=(BLOCK_DMODEL, (cur_seq_extend_start_idx + start_n).to(tl.int32)),
)
if not ALL_MASK_N:
kpe = tl.where(mask_n[None, :], kpe, 0.0)
else:
offs_kpe = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * _tanh(qk / logit_cap)
if xai_temperature_len > 0:
qk *= xai_temperature_reg[:, None]
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
if USE_MLS:
v = tl.matrix_load(
V_Extend + cur_kv_head * stride_vh,
shape=(kv_head_num, Lv),
strides=(stride_vbs, 1),
block_shape=(BLOCK_N, BLOCK_DV),
offsets=((cur_seq_extend_start_idx + start_n).to(tl.int32), 0),
)
if not (ALL_MASK_N & ALL_MASK_DV):
v = tl.where((mask_n[:, None]) & (mask_dv[None, :]), v, 0.0)
else:
offs_v = (
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
deno += tl.exp(cur_sink - e_max)
offs_o = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
* stride_obs
+ cur_head * stride_oh
+ offs_dv[None, :]
)
if STORE_TRANSPOSE:
tl.store(
O_Extend + offs_o.T,
(acc / deno[:, None]).T,
mask=(mask_m[:, None] & mask_dv[None, :]).T,
)
else:
tl.store(
O_Extend + offs_o,
acc / deno[:, None],
mask=mask_m[:, None] & mask_dv[None, :],
)
@triton.jit
def _fwd_kernel_v2_decode(
Q_Extend,
K_Extend,
V_Extend,
O_Extend,
K_Buffer,
V_Buffer,
qo_indptr,
kv_indptr,
kv_indices,
mask_ptr,
mask_indptr,
sink_ptr,
window_kv_offset_ptr,
sm_scale,
k_scale,
v_scale,
stride_qbs,
stride_qh,
stride_kbs,
stride_kh,
stride_vbs,
stride_vh,
stride_obs,
stride_oh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
SLIDING_WINDOW_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
xai_temperature_len: tl.constexpr,
Lq: tl.constexpr,
Lv: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
USE_CUSTOM_MASK: tl.constexpr,
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
HAS_SINK: tl.constexpr,
kv_group_num: tl.constexpr,
num_query_heads: tl.constexpr,
USE_MLS: tl.constexpr,
batch_size: tl.constexpr,
# max_len_extend: tl.constexpr,
):
"""
v2 decode: grid (batch, num_kv_heads, cdiv(max_len_extend, Q_SEQ)) with
Q_SEQ = BLOCK_M // kv_group_num (same as unified ``BLOCK_Q``; floor, not ceil). If BLOCK_M is not
a multiple of G, adjacent ``cur_block_m`` may overlap in query_pos like unified, but `mask_m`
and sequence bounds keep correctness. BLOCK_M is a power of 2 (host). Require BLOCK_M // G >= 1 to launch.
"""
# Per unified: BLOCK_Q = BLOCK_M // G; stride in query token index for each +1 of cur_block_m.
Q_SEQ: tl.constexpr = BLOCK_M // kv_group_num
cur_seq = tl.program_id(0)
cur_kv_head = tl.program_id(1)
cur_block_m = tl.program_id(2)
tl.assume(Q_Extend.to(tl.int64) >= 0)
tl.assume(K_Extend.to(tl.int64) >= 0)
tl.assume(V_Extend.to(tl.int64) >= 0)
tl.assume(stride_qbs >= 0)
tl.assume(stride_qh >= 0)
tl.assume(stride_kbs >= 0)
tl.assume(stride_kh >= 0)
tl.assume(stride_vbs >= 0)
tl.assume(stride_vh >= 0)
tl.assume(stride_obs >= 0)
tl.assume(stride_oh >= 0)
tl.assume(stride_buf_kbs >= 0)
tl.assume(stride_buf_kh >= 0)
tl.assume(stride_buf_vbs >= 0)
tl.assume(stride_buf_vh >= 0)
tl.assume(batch_size >= 0)
# tl.assume(max_len_extend >= 0)
kv_head_num = num_query_heads // kv_group_num
cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq)
cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
if cur_block_m * Q_SEQ >= cur_seq_len_extend:
return
if USE_CUSTOM_MASK:
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
window_kv_offset = 0
if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
offs_m = tl.arange(0, BLOCK_M)
# unified_attention-style: per offs_m, row = offs_m // G, h = offs_m % G (G = kv_group_num)
query_pos = cur_block_m * Q_SEQ + (offs_m // kv_group_num)
q_head_in_group = offs_m % kv_group_num
query_offset_0 = cur_seq_extend_start_idx + query_pos
query_offset_1 = cur_kv_head * kv_group_num + q_head_in_group
mask_m = (query_pos < cur_seq_len_extend) & (query_offset_1 < num_query_heads)
mask_d = offs_d < Lq
mask_dv = offs_dv < Lv
ALL_MASK_M = tl.min(mask_m.to(tl.int32), axis=0) == 1
ALL_MASK_D = tl.min(mask_d.to(tl.int32), axis=0) == 1
ALL_MASK_DV = tl.min(mask_dv.to(tl.int32), axis=0) == 1
if xai_temperature_len > 0:
offs_qidx = cur_seq_len_prefix + query_pos
xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
xai_temperature_reg = tl.where(
offs_qidx > xai_temperature_len,
tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
1.0,
)
offs_q = (
query_offset_0[:, None] * stride_qbs
+ query_offset_1[:, None] * stride_qh
+ offs_d[None, :]
)
q = tl.load(
......@@ -402,9 +954,8 @@ def _fwd_kernel_v2(
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
offs_qpe = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
query_offset_0[:, None] * stride_qbs
+ query_offset_1[:, None] * stride_qh
+ offs_dpe[None, :]
)
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
......@@ -418,14 +969,38 @@ def _fwd_kernel_v2(
for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_seq_len_prefix
ALL_MASK_N = tl.min(mask_n.to(tl.int32), axis=0) == 1
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
# if USE_MLS:
# group_id = offs_m // kv_group_num
# custom_mask_group = tl.matrix_load(
# mask_ptr + cur_seq_mask_start_idx,
# shape=(cur_seq_len_prefix, cur_seq_len + window_kv_offset),
# strides=(cur_seq_len + window_kv_offset, 1),
# block_shape=(BLOCK_M // kv_group_num, BLOCK_N),
# offsets=((cur_block_m * Q_SEQ).to(tl.int32),
# (window_kv_offset + start_n).to(tl.int32)),
# )
# custom_mask = custom_mask_group[group_id[:, None], offs_n[None, :]]
# if not (ALL_MASK_M & ALL_MASK_N):
# custom_mask = tl.where((mask_m[:, None] & mask_n[None, :]), custom_mask, 0)
# else:
# custom_mask = tl.load(
# mask_ptr
# + cur_seq_mask_start_idx
# + (query_pos[:, None]) * (cur_seq_len + window_kv_offset)
# + window_kv_offset
# + start_n
# + offs_n[None, :],
# mask=(mask_m[:, None] & mask_n[None, :]),
# other=0,
# )
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None])
* (cur_seq_len + window_kv_offset)
+ (query_pos[:, None]) * (cur_seq_len + window_kv_offset)
+ window_kv_offset
+ start_n
+ offs_n[None, :],
......@@ -435,7 +1010,7 @@ def _fwd_kernel_v2(
final_mask &= custom_mask
if SLIDING_WINDOW_SIZE > 0:
window_mask = (
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
cur_seq_len_prefix + query_pos[:, None]
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
final_mask &= window_mask
......@@ -450,6 +1025,31 @@ def _fwd_kernel_v2(
other=0,
)
# if USE_MLS:
# offs_kv_start_idx = tl.load(
# kv_indices + cur_seq_kv_start_idx + start_n,
# )
# k = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DMODEL, BLOCK_N),
# offsets=(0, offs_kv_start_idx.to(tl.int32)),
# )
# if not (ALL_MASK_N & ALL_MASK_D):
# k = tl.where((mask_n[None, :]) & (mask_d[:, None]), k, 0.0)
# else:
# offs_buf_k = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_d[:, None]
# )
# k = tl.load(
# K_Buffer + offs_buf_k,
# mask=(mask_n[None, :]) & (mask_d[:, None]),
# other=0.0,
# )
offs_buf_k = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
......@@ -460,9 +1060,29 @@ def _fwd_kernel_v2(
mask=(mask_n[None, :]) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q.to(k.dtype), k)
if BLOCK_DPE > 0:
# if USE_MLS:
# kpe = tl.matrix_load(
# K_Buffer + cur_kv_head * stride_buf_kh,
# shape=(Lq, cur_seq_len_prefix.to(tl.int32)),
# strides=(1, stride_buf_kbs),
# block_shape=(BLOCK_DPE, BLOCK_N),
# offsets=(BLOCK_DMODEL, offs_kv_start_idx.to(tl.int32)),
# )
# if not ALL_MASK_N:
# kpe = tl.where((mask_n[None, :]), kpe, 0.0)
# else:
# offs_kpe = (
# offs_kv_loc[None, :] * stride_buf_kbs
# + cur_kv_head * stride_buf_kh
# + offs_dpe[:, None]
# )
# kpe = tl.load(
# K_Buffer + offs_kpe,
# mask=mask_n[None, :],
# other=0.0,
# )
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
......@@ -484,30 +1104,40 @@ def _fwd_kernel_v2(
qk = tl.where(final_mask, qk, float("-inf"))
# row_max_fixed avoids exp(-inf - (-inf)) when a row is all -inf in this tile;
# only needed under sliding window or custom mask (plain causal matches v1).
if SLIDING_WINDOW_SIZE > 0 or (
USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK
):
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
else:
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
# if USE_MLS:
# v = tl.matrix_load(
# V_Buffer + cur_kv_head * stride_buf_vh,
# shape=(cur_seq_len_prefix.to(tl.int32), Lv),
# strides=(stride_buf_kbs, 1),
# block_shape=(BLOCK_N, BLOCK_DV),
# offsets=(offs_kv_start_idx.to(tl.int32), 0),
# )
# if not (ALL_MASK_N & ALL_MASK_DV):
# v = tl.where((mask_n[:, None] & mask_dv[None, :]), v, 0.0)
# else:
# offs_buf_v = (
# offs_kv_loc[:, None] * stride_buf_vbs
# + cur_kv_head * stride_buf_vh
# + offs_dv[None, :]
# )
# v = tl.load(
# V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
# )
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Buffer + offs_buf_v,
mask=mask_n[:, None] & mask_dv[None, :],
other=0.0,
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v) * v_scale
......@@ -517,19 +1147,44 @@ def _fwd_kernel_v2(
cur_block_m_end = (
cur_seq_len_extend
if not IS_CAUSAL
else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * Q_SEQ)
)
for start_n in range(0, cur_block_m_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask_n = (start_n + offs_n) < cur_block_m_end
ALL_MASK_N = tl.min(mask_n.to(tl.int32), axis=0) == 1
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK:
# if USE_MLS:
# group_id = offs_m // kv_group_num
# custom_mask_group = tl.matrix_load(
# mask_ptr + cur_seq_mask_start_idx,
# shape=(cur_block_m_end, cur_seq_len + window_kv_offset),
# strides=(cur_seq_len + window_kv_offset, 1),
# block_shape=(BLOCK_M // kv_group_num, BLOCK_N),
# offsets=((cur_block_m * Q_SEQ).to(tl.int32),
# (window_kv_offset + cur_seq_len_prefix + start_n).to(tl.int32)),
# )
# custom_mask = custom_mask_group[group_id[:, None], offs_n[None, :]]
# if not (ALL_MASK_M & ALL_MASK_N):
# custom_mask = tl.where((mask_m[:, None] & mask_n[None, :]), custom_mask, 0)
# else:
# custom_mask = tl.load(
# mask_ptr
# + cur_seq_mask_start_idx
# + (query_pos[:, None]) * (cur_seq_len + window_kv_offset)
# + window_kv_offset
# + cur_seq_len_prefix
# + start_n
# + offs_n[None, :],
# mask=(mask_m[:, None] & mask_n[None, :]),
# other=0,
# )
custom_mask = tl.load(
mask_ptr
+ cur_seq_mask_start_idx
+ (cur_block_m * BLOCK_M + offs_m[:, None])
* (cur_seq_len + window_kv_offset)
+ (query_pos[:, None]) * (cur_seq_len + window_kv_offset)
+ window_kv_offset
+ cur_seq_len_prefix
+ start_n
......@@ -540,9 +1195,7 @@ def _fwd_kernel_v2(
custom_mask &= mask_m[:, None] & mask_n[None, :]
final_mask &= custom_mask
elif IS_CAUSAL:
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual = query_pos[:, None] >= (start_n + offs_n[None, :])
mask_causual &= mask_m[:, None] & mask_n[None, :]
final_mask &= mask_causual
else:
......@@ -550,7 +1203,7 @@ def _fwd_kernel_v2(
final_mask &= mask_non_causal
if SLIDING_WINDOW_SIZE > 0:
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
window_mask = query_pos[:, None] <= (
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
)
final_mask &= window_mask
......@@ -560,27 +1213,49 @@ def _fwd_kernel_v2(
SKIP_TILE = tl.max(tl.max(final_mask.to(tl.int32), axis=1), axis=0) == 0
if not SKIP_TILE:
offs_k = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_d[:, None]
)
k = tl.load(
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q.to(k.dtype), k, out_dtype=tl.float32)
if BLOCK_DPE > 0:
offs_kpe = (
if USE_MLS:
k = tl.matrix_load(
K_Extend + cur_kv_head * stride_kh,
shape=(kv_head_num, Lq),
strides=(1, stride_kbs),
block_shape=(BLOCK_DMODEL, BLOCK_N),
offsets=(0, (cur_seq_extend_start_idx + start_n).to(tl.int32)),
)
if not (ALL_MASK_D & ALL_MASK_N):
k = tl.where((mask_d[:, None] & (mask_n[None, :])), k, 0.0)
else:
offs_k = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
+ offs_d[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
k = tl.load(
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
)
qk = tl.dot(q.to(k.dtype), k, out_dtype=tl.float32)
if BLOCK_DPE > 0:
if USE_MLS:
kpe = tl.matrix_load(
K_Extend + cur_kv_head * stride_kh,
shape=(kv_head_num, Lq),
strides=(1, stride_kbs),
block_shape=(BLOCK_DPE, BLOCK_N),
offsets=(BLOCK_DMODEL, (cur_seq_extend_start_idx + start_n).to(tl.int32)),
)
if not ALL_MASK_N:
kpe = tl.where(mask_n[None, :], kpe, 0.0)
else:
offs_kpe = (
(cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe.to(kpe.dtype), kpe)
qk *= sm_scale
......@@ -593,38 +1268,49 @@ def _fwd_kernel_v2(
qk = tl.where(final_mask, qk, float("-inf"))
if SLIDING_WINDOW_SIZE > 0 or USE_CUSTOM_MASK:
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
else:
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
offs_v = (
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
if USE_MLS:
v = tl.matrix_load(
V_Extend + cur_kv_head * stride_vh,
shape=(kv_head_num, Lv),
strides=(stride_vbs, 1),
block_shape=(BLOCK_N, BLOCK_DV),
offsets=((cur_seq_extend_start_idx + start_n).to(tl.int32), 0),
)
if not (ALL_MASK_N & ALL_MASK_DV):
v = tl.where((mask_n[:, None] & mask_dv[None, :]), v, 0.0)
else:
offs_v = (
(cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_dv[None, :]
)
v = tl.load(
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
)
p = p.to(v.dtype)
acc = acc * re_scale[:, None] + tl.dot(p, v)
e_max = n_e_max
if HAS_SINK:
cur_sink = tl.load(sink_ptr + cur_head)
cur_sink = tl.load(
sink_ptr + cur_kv_head * kv_group_num + q_head_in_group,
mask=mask_m,
other=0.0,
)
deno += tl.exp(cur_sink - e_max)
offs_o = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
+ cur_head * stride_oh
query_offset_0[:, None] * stride_obs
+ query_offset_1[:, None] * stride_oh
+ offs_dv[None, :]
)
if STORE_TRANSPOSE:
......@@ -700,13 +1386,118 @@ def _load_config_v2():
raise ValueError(
f"{dev}-EXTEND_ATTENTION-V2-FP16.json keys must be 7-tuples matching runtime "
f"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
f"SLIDING_WINDOW_SIZE); got length {len(tup)} for {k!r}"
f"USE_SLIDING_WINDOW); got length {len(tup)} for {k!r}"
)
res["keys"].append(tup)
return res
def _load_config_v2_decode():
    """Load the autotuned block sizes for :func:`_fwd_kernel_v2_decode`.

    Reads ``{dev}-EXTEND_ATTENTION-V2-DECODE-FP16.json`` from
    ``AITER_TRITON_CONFIGS_PATH``. A missing file is not an error: empty
    tables are returned instead.

    Returns:
        A dict with four entries: ``config`` (the JSON ``config`` mapping),
        ``path`` (the JSON ``path`` mapping, ``{}`` when absent), ``key``
        (the config key strings in file order) and ``keys`` (the same keys
        parsed by ``create_tuple``, index-aligned with ``key``).

    Raises:
        ValueError: if a config key does not parse to a 7-tuple.
    """
    dev = arch_info.get_device()
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/{dev}-EXTEND_ATTENTION-V2-DECODE-FP16.json"
    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        return {"config": {}, "path": {}, "key": [], "keys": []}

    configs = data["config"]
    paths = data.get("path", {})
    key_strings = list(configs.keys())

    parsed_keys = []
    for k in key_strings:
        tup = create_tuple(k)
        # Lookups compare against a 7-element tuple, so reject anything else
        # while loading rather than silently never matching.
        if len(tup) != 7:
            raise ValueError(
                f"{dev}-EXTEND_ATTENTION-V2-DECODE-FP16.json keys must be 7-tuples matching runtime "
                f"want7 (kv_group_num, Lq, Lv, USE_CUSTOM_MASK, IS_CAUSAL, HAS_SINK, "
                f"USE_SLIDING_WINDOW); got length {len(tup)} for {k!r}"
            )
        parsed_keys.append(tup)

    return {
        "config": configs,
        "path": paths,
        "key": key_strings,
        "keys": parsed_keys,
    }
TORCH_DTYPE_TO_DTYPE = {
torch.float32: "f32",
torch.float: "f32",
torch.float16: "f16",
torch.half: "f16",
torch.bfloat16: "bf16",
torch.float64: "f64",
torch.double: "f64",
torch.float8_e4m3fn: "f8_e4m3fn",
torch.float8_e5m2: "f8_e5m2",
torch.int8: "i8",
torch.int16: "i16",
torch.int32: "i32",
torch.int64: "i64",
torch.long: "i64",
torch.uint8: "u8",
torch.bool: "bool",
}
@functools.lru_cache
def get_gpu_label():
    """Return a ``"<arch>_cu<N>"`` label for the active GPU.

    ``<arch>`` is the ``arch`` attribute of Triton's current target and
    ``<N>`` is the ``multi_processor_count`` torch reports for the current
    CUDA device. The result is memoised by ``lru_cache``, so the label is
    computed once and later device changes are not reflected.
    """
    target = triton.runtime.driver.active.get_current_target()
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    return f"{target.arch}_cu{props.multi_processor_count}"
def _load_config_v3():
    """Load the per-device autotuned configs for :func:`_fwd_kernel_v2`.

    Reads ``extend_attn/_fwd_kernel_v2-device=<label>.json`` under
    ``AITER_TRITON_CONFIGS_PATH``, where ``<label>`` comes from
    :func:`get_gpu_label`. A missing file is not an error: empty tables are
    returned instead.

    Config keys are parsed with :func:`create_tuple`, but their length is not
    validated here. :func:`_get_config_v3` treats the first element of each
    parsed key as the batch size and the remainder as the exact-match part.

    Returns:
        A dict with ``key_name`` (the JSON ``key`` entry), ``fpath`` (the
        path that was searched, also when the file is missing), ``config``
        (the JSON ``config`` mapping), ``path`` (the JSON ``path`` mapping,
        ``{}`` when absent), ``key`` (config key strings in file order) and
        ``keys`` (the parsed keys, index-aligned with ``key``).
    """
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/extend_attn/_fwd_kernel_v2-device={get_gpu_label()}.json"
    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        # Keep the searched path in the fallback so the "config not found"
        # warnings printed by the lookup name the file instead of "{}".
        return {"key_name": {}, "fpath": fpath, "config": {}, "path": {}, "key": [], "keys": []}

    key_name = data["key"]
    configs = data["config"]
    key_strings = list(configs.keys())
    return {
        "key_name": key_name,
        "fpath": fpath,
        "config": configs,
        "path": data.get("path", {}),
        "key": key_strings,
        "keys": [create_tuple(k) for k in key_strings],
    }
def _load_config_v3_decode():
    """Load the per-device autotuned configs for :func:`_fwd_kernel_v2_decode`.

    Reads ``extend_attn/_fwd_kernel_v2_decode-device=<label>.json`` under
    ``AITER_TRITON_CONFIGS_PATH``, where ``<label>`` comes from
    :func:`get_gpu_label`. A missing file is not an error: empty tables are
    returned instead.

    Config keys are parsed with :func:`create_tuple`; their length is not
    validated here.

    Returns:
        A dict with ``key_name`` (the JSON ``key`` entry), ``fpath`` (the
        path that was searched, also when the file is missing), ``config``
        (the JSON ``config`` mapping), ``path`` (the JSON ``path`` mapping,
        ``{}`` when absent), ``key`` (config key strings in file order) and
        ``keys`` (the parsed keys, index-aligned with ``key``).
    """
    fpath = f"{AITER_TRITON_CONFIGS_PATH}/extend_attn/_fwd_kernel_v2_decode-device={get_gpu_label()}.json"
    try:
        with open(fpath, "r") as file:
            data = json.load(file)
    except FileNotFoundError:
        # Keep the searched path in the fallback so the "config not found"
        # warnings printed by the lookup name the file instead of "{}".
        return {"key_name": {}, "fpath": fpath, "config": {}, "path": {}, "key": [], "keys": []}

    key_name = data["key"]
    configs = data["config"]
    key_strings = list(configs.keys())
    return {
        "key_name": key_name,
        "fpath": fpath,
        "config": configs,
        "path": data.get("path", {}),
        "key": key_strings,
        "keys": [create_tuple(k) for k in key_strings],
    }
# Module-level tuning tables: each is loaded once, at import time, and then
# consulted by the matching _get_config_* lookup helper.
global_config_v2 = _load_config_v2()
global_config_v2_decode = _load_config_v2_decode()
global_config_v3 = _load_config_v3()
global_config_v3_decode = _load_config_v3_decode()
default_config = {
"BLOCK_M": 32,
......@@ -715,7 +1506,8 @@ default_config = {
"matrix_instr_nonkdim": 16,
"kpack": 2,
"num_warps": 4,
"num_stages": 1
"num_stages": 2,
"USE_MLS": False,
}
......@@ -744,20 +1536,21 @@ def _get_config_v2(
use_custom_mask,
is_causal,
has_sink: bool,
sliding_window_size: int,
use_sliding_window: bool,
):
"""
Lookup order for ``_fwd_kernel_v2`` block sizes:
1. ``want7 = (kv_group_num, Lq, Lv, use_custom_mask, is_causal, has_sink, sliding_window_size)``
1. ``want7 = (kv_group_num, Lq, Lv, use_custom_mask, is_causal, has_sink, USE_SLIDING_WINDOW)``
against ``{arch}-EXTEND_ATTENTION-V2-FP16.json``. JSON keys must be **7-tuple** strings,
same shape as ``want7`` (see :func:`_load_config_v2`).
same shape as ``want7`` (see :func:`_load_config_v2`). The last element is a bool:
same tuning bucket for any ``sliding_window_size > 0``; use ``False`` when disabled (``<= 0``).
2. If no V2 entry matches, :data:`default_config` (no fallback to v1 JSON).
Log field mapping (typical): ``kv_group_num = q_extend.size(-2) // k_extend.size(-2)``,
``Lq = q_extend.size(-1)``, ``Lv = v_extend.size(-1)``,
``use_custom_mask = custom_mask is not None``, ``is_causal`` as passed,
``has_sink = sinks is not None``, ``sliding_window_size`` as passed (use ``-1`` if disabled).
``has_sink = sinks is not None``, ``USE_SLIDING_WINDOW = (sliding_window_size > 0)``.
"""
want7 = (
kv_group_num,
......@@ -766,7 +1559,7 @@ def _get_config_v2(
use_custom_mask,
is_causal,
has_sink,
sliding_window_size,
use_sliding_window,
)
for i, keys in enumerate(global_config_v2["keys"]):
if keys == want7:
......@@ -777,10 +1570,137 @@ def _get_config_v2(
return default_config, None
@functools.lru_cache(maxsize=1024)
def _get_config_v2_decode(
    kv_group_num,
    Lq,
    Lv,
    use_custom_mask,
    is_causal,
    has_sink: bool,
    use_sliding_window: bool,
):
    """
    Resolve the block-size config for :func:`_fwd_kernel_v2_decode`.

    The seven arguments form the same ``want7`` tuple used by :func:`_get_config_v2`,
    but the lookup runs against the decode table (``global_config_v2_decode``).

    Returns ``(config, path)``. When no entry matches, a warning is printed and
    ``(default_config, None)`` is returned.
    """
    table = global_config_v2_decode
    want7 = (
        kv_group_num,
        Lq,
        Lv,
        use_custom_mask,
        is_causal,
        has_sink,
        use_sliding_window,
    )
    # Position of the first parsed key equal to want7, or None when absent.
    hit = next(
        (pos for pos, parsed in enumerate(table["keys"]) if parsed == want7),
        None,
    )
    if hit is None:
        print("WARNING: optimal V2 decode config not found, just use default config")
        return default_config, None
    # "key" holds the raw JSON key strings, index-aligned with the parsed "keys".
    raw_key = table["key"][hit]
    return table["config"][raw_key], table["path"].get(raw_key)
def find_closest_index(target, lst):
    """
    Pick the entry of ``lst`` whose value is closest to ``target`` and return its index field.

    ``lst`` holds ``(index, batch_size)`` pairs. Entries with ``batch_size >= target``
    always win over smaller ones; within each group the smallest distance wins, and
    the earliest entry wins ties.
    """

    def rank(entry):
        gap = entry[1] - target
        # Group 0: at or above the target; group 1: below it.
        return (0, gap) if gap >= 0 else (1, -gap)

    return min(lst, key=rank)[0]
@functools.lru_cache(maxsize=1024)
def _get_config_v3(key):
    """
    Resolve the V3 config for ``key`` from ``global_config_v3``.

    ``key`` is a tuple whose first element is the batch size. Lookup order:
    the exact ``str(key)`` entry; otherwise the entry that matches on every field
    except batch size and is nearest per :func:`find_closest_index`; otherwise
    ``default_config``. Both fallbacks print a warning.
    """
    _key = str(key)
    _configs = global_config_v3["config"]
    if _key in _configs:
        return _configs[_key]

    # key = (batch_size, *other): keep the entries that differ only in batch size.
    bs = [
        (pos, parsed[0])
        for pos, parsed in enumerate(global_config_v3["keys"])
        if parsed[1:] == key[1:]
    ]
    if not bs:
        print(f'WARNING: Not found optimal config from {global_config_v3["fpath"]} with key {str(key)}, just use default config')
        return default_config

    __key = global_config_v3["key"][find_closest_index(key[0], bs)]
    print(f'WARNING: Not found key {_key} from {global_config_v3["fpath"]}, mapping to key {__key}')
    return _configs[__key]
@functools.lru_cache(maxsize=1024)
def _get_config_v3_decode(key):
    """
    Resolve the V3 decode config for ``key`` from ``global_config_v3_decode``.

    ``key`` is a tuple whose first element is the batch size. Lookup order:

    1. the exact ``str(key)`` entry;
    2. the entry that matches on every field except batch size and is nearest
       per :func:`find_closest_index` (a warning is printed);
    3. ``default_config`` (a warning is printed).
    """
    _key = str(key)
    _configs = global_config_v3_decode["config"]
    if _key in _configs:
        return _configs[_key]
    # find the nearest batch size
    # key = (batch_size, *other)
    bs = [
        (i, k[0])
        for i, k in enumerate(global_config_v3_decode["keys"])
        if k[1:] == key[1:]
    ]
    if bs:
        __key = global_config_v3_decode["key"][find_closest_index(key[0], bs)]
        # The warning must name the decode table's file, since that is the table searched.
        print(f'WARNING: Not found key {_key} from {global_config_v3_decode["fpath"]}, mapping to key {__key}')
        return _configs[__key]
    print(f'WARNING: Not found optimal config from {global_config_v3_decode["fpath"]} with key {str(key)}, just use default config')
    return default_config
def has_kernel_cache(path):
    """Return True only when ``path`` is non-empty and ``{cache_knob.dir}/{path}`` is an existing directory."""
    if not path:
        return False
    return os.path.isdir(f'{cache_knob.dir}/{path}')
@functools.lru_cache(maxsize=1024)
def get_v2_decode_final_grid(
    max_len_extend: int,
    kv_group_num: int,
    block_m_cfg: int,
    batch_size: int,
    kv_head_num: int,
) -> tuple[int, tuple[int, int, int]]:
    """Decode path: power-of-2 ``block_m_decode``; grid-3 cdiv matches kernel q_seq stride.

    Derives the effective ``BLOCK_M`` for the decode kernel from the configured
    ``block_m_cfg`` and returns ``(block_m_decode, grid)`` where
    ``grid = (batch_size, kv_head_num, cdiv(max_len_extend, block_m_decode // kv_group_num))``.
    Results are memoised per argument tuple.
    """
    # NOTE(review): indentation was reconstructed from a flattened listing; the
    # block_count adjustment below is assumed to run for every branch — confirm.
    prod = max_len_extend * kv_group_num
    npo2 = triton.next_power_of_2(prod)
    kv_group_num_align = triton.next_power_of_2(kv_group_num)
    block_m_decode = block_m_cfg
    if prod < 16:
        # Fewer than 16 query rows in total: use a fixed block of 16.
        block_m_decode = 16
    elif block_m_cfg > npo2:
        # Configured block is larger than the padded row count: shrink to it.
        block_m_decode = npo2
    else:
        # Never go below the (power-of-2 padded) number of query heads per KV head.
        block_m_decode = max(block_m_cfg, kv_group_num_align)
    block_count = batch_size * kv_head_num * triton.cdiv(
        max_len_extend, block_m_decode // kv_group_num
    )
    if block_count <= 32:
        # Few programs in the launch: halve the block, floored at 16 and at kv_group_num_align.
        block_m_decode = max(max(block_m_decode // 2, 16), kv_group_num_align)
    # Query positions covered by one program along grid axis 2.
    q_seq = block_m_decode // kv_group_num
    grid = (batch_size, kv_head_num, triton.cdiv(max_len_extend, q_seq))
    return block_m_decode, grid
def to_dtype(torch_dtype):
if torch_dtype == torch.float32:
return 'fp32'
......@@ -793,7 +1713,6 @@ def to_dtype(torch_dtype):
else:
return str(torch_dtype)
def extend_attention_fwd(
q_extend,
k_extend,
......@@ -818,6 +1737,7 @@ def extend_attention_fwd(
sinks=None,
window_kv_offsets=None,
xai_temperature_len=-1,
force_v2_prefill: bool = False,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
......@@ -828,7 +1748,11 @@ def extend_attention_fwd(
extensions follow with defaults. ``k_scale`` / ``v_scale`` must both be
``None`` or both set (``float`` / ``int`` like sglang, or 1-element
``torch.Tensor`` on device); if both are set, :func:`_fwd_kernel_v2` is used.
If ``force_v2_prefill`` is True and v2 is active, always use :func:`_fwd_kernel_v2`
even when ``max_len_extend < 32`` (for tests / parity vs :func:`_fwd_kernel_v2_decode`).
"""
# force_v2_prefill = True
Lq, Lv = (
q_extend.shape[-1],
v_extend.shape[-1],
......@@ -853,30 +1777,87 @@ def extend_attention_fwd(
sm_scale = sm_scale or 1.0 / (Lq**0.5)
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
kv_head_num = k_extend.shape[1]
kv_group_num = head_num // kv_head_num
USE_CUSTOM_MASK = custom_mask is not None
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
use_v2 = k_scale is not None or v_scale is not None
use_v2_decode = (
use_v2 and max_len_extend < 32 and not force_v2_prefill
)
USE_SLIDING_WINDOW = sliding_window_size > 0
if not USE_CUSTOM_MASK:
custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device)
mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device)
# custom_mask = torch.tensor([0], dtype=torch.bool, device=q_extend.device)
# mask_indptr = torch.tensor([0], dtype=torch.int32, device=q_extend.device)
# set to None to avoid capture cudagraph err
custom_mask = None
mask_indptr = None
if config is None:
if q_extend.dtype == torch.float16 or q_extend.dtype == torch.bfloat16:
if use_v2:
config, path = _get_config_v2(
kv_group_num,
Lq,
Lv,
USE_CUSTOM_MASK,
is_causal,
sinks is not None,
sliding_window_size,
)
if triton_minor_version >= 5: # >= 3.5
key = [
batch_size,
kv_group_num,
Lq,
Lv,
USE_CUSTOM_MASK,
is_causal,
skip_prefix_custom_mask,
sinks is not None,
sliding_window_size,
xai_temperature_len,
str(q_extend.dtype),
str(k_extend.dtype),
str(v_extend.dtype),
str(o_extend.dtype),
str(k_buffer.dtype),
str(v_buffer.dtype),
str(qo_indptr.dtype),
str(kv_indptr.dtype),
str(kv_indices.dtype),
]
for o in [
custom_mask,
mask_indptr,
sinks,
window_kv_offsets
]:
if o is not None:
if hasattr(o, 'dtype'):
key.append(str(o.dtype))
key = tuple(key)
if use_v2_decode:
config = _get_config_v3_decode(key)
else:
config = _get_config_v3(key)
else:
if use_v2_decode:
config, path = _get_config_v2_decode(
kv_group_num,
Lq,
Lv,
USE_CUSTOM_MASK,
is_causal,
sinks is not None,
USE_SLIDING_WINDOW,
)
else:
config, path = _get_config_v2(
kv_group_num,
Lq,
Lv,
USE_CUSTOM_MASK,
is_causal,
sinks is not None,
USE_SLIDING_WINDOW,
)
else:
keys = [kv_group_num, Lq, Lv, USE_CUSTOM_MASK, is_causal]
config, path = _get_config(*keys)
......@@ -884,7 +1865,17 @@ def extend_attention_fwd(
config, path = default_config, None
assert config is not None, "ERROR: optimal config not found"
grid = (batch_size, head_num, triton.cdiv(max_len_extend, config["BLOCK_M"]))
block_m_cfg = config["BLOCK_M"]
# Decode: block_m_decode is power of 2; Q_SEQ = block_m // G (floor), same as unified BLOCK_Q;
# grid-3 cdiv(max_len_extend, q_seq) matches kernel cur_block_m stride.
if use_v2_decode:
block_m_decode, grid = get_v2_decode_final_grid(
max_len_extend, kv_group_num, block_m_cfg, batch_size, kv_head_num
)
# print(f"{max_len_extend=}, {use_v2_decode=}, {block_m_decode=}, {grid=}")
else:
grid = (batch_size, head_num, triton.cdiv(max_len_extend, block_m_cfg))
# num_stages = 1
# extra_kargs = {}
......@@ -922,32 +1913,69 @@ def extend_attention_fwd(
HAS_SINK = sinks is not None
assert k_scale is not None and v_scale is not None, "k_scale and v_scale must both be set"
# k_scale / v_scale kept in Python API; v2 kernel TEMP omits them for perf vs v1.
_fwd_kernel_v2[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
sinks,
window_kv_offsets,
sm_scale,
k_scale,
v_scale,
kv_group_num,
*stride_args,
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
HAS_SINK=HAS_SINK,
block_const_v2 = {
**block_const,
**config,
)
}
if use_v2_decode:
block_const_v2 = {**block_const_v2, "BLOCK_M": block_m_decode}
_fwd_kernel_v2_decode[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
sinks,
window_kv_offsets,
sm_scale,
k_scale,
v_scale,
*stride_args,
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
HAS_SINK=HAS_SINK,
kv_group_num=kv_group_num,
num_query_heads=head_num,
**block_const_v2,
batch_size=batch_size,
# USE_MLS=True if os.getenv("TRITON_USE_MLS", "0") == "1" and kv_group_num == head_num else False,
)
else:
_fwd_kernel_v2[grid](
q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
qo_indptr,
kv_indptr,
kv_indices,
custom_mask,
mask_indptr,
sinks,
window_kv_offsets,
sm_scale,
k_scale,
v_scale,
kv_group_num,
*stride_args,
SLIDING_WINDOW_SIZE=sliding_window_size,
logit_cap=logit_cap,
xai_temperature_len=xai_temperature_len,
HAS_SINK=HAS_SINK,
**block_const_v2,
head_num=head_num,
batch_size=batch_size,
# USE_MLS=True if os.getenv("TRITON_USE_MLS", "0") == "1" and kv_group_num == head_num else False,
)
return
fn = (
......
# SPDX-License-Identifier: MIT
import functools
import json
import os
from typing import Tuple
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils.arch_info as arch_info
from aiter import logger
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
# HAS_DUMPED_PACKED_DECODE_KERNEL_METADATA = False
# When the TRITON_CONFIG_CHECK environment variable is "1", config-lookup
# fallbacks (nearest-B / default) are reported through logger.warning.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Baseline launch parameters; a tuned entry is merged over these, so any key it
# omits keeps the value given here.
_DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG = {
    "BV": 32,
    "num_warps": 1,
    "num_stages": 1,
}
@functools.lru_cache(maxsize=1)
def _load_fused_recurrent_packed_decode_configs() -> dict:
    """Load the per-architecture tuning table for the packed-decode kernel.

    Returns the ``"config"`` mapping of the JSON file, or ``{}`` when the file is
    missing (a warning is logged), the payload is not a dict, or it has no
    ``"config"`` key. The result is cached after the first call.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "fused_recurrent_gated_delta_rule_packed_decode",
        f"fused_recurrent_gated_delta_rule_packed_decode-{device_name}.json",
    )
    if os.path.exists(path):
        with open(path) as f:
            payload = json.load(f)
        if not isinstance(payload, dict):
            return {}
        return payload.get("config", {})

    logger.warning(
        f"fused_recurrent_gated_delta_rule_packed_decode config not found at {path}, "
        f"using default {_DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG}."
    )
    return {}
@functools.lru_cache
def _get_fused_recurrent_packed_decode_config(B: int, H: int, HV: int) -> dict:
    """Select launch parameters for the packed-decode kernel.

    Lookup order: the exact ``"B=..,H=..,HV=.."`` entry; otherwise the entry with
    the same ``H``/``HV`` whose ``B`` is numerically closest; otherwise the table's
    ``"default"`` entry or the built-in default. The chosen entry is merged over
    the built-in default so every expected key is present.
    """
    cfgs = _load_fused_recurrent_packed_decode_configs()
    key = f"B={B},H={H},HV={HV}"
    cfg = cfgs.get(key)

    if cfg is None:
        candidates = []
        for name, entry in cfgs.items():
            if name == "default":
                continue
            try:
                parts = {}
                for field in name.split(","):
                    pieces = field.split("=")
                    parts[pieces[0]] = int(pieces[1])
            except Exception:
                # Keys that are not of the form "NAME=int,..." are ignored.
                continue
            same_heads = parts.get("H") == H and parts.get("HV") == HV
            if same_heads and "B" in parts:
                candidates.append((abs(parts["B"] - B), parts["B"], entry))
        if candidates:
            # min() returns the first entry among equal distances, like a stable sort.
            _, nearest_b, cfg = min(candidates, key=lambda item: item[0])
            if TRITON_CONFIG_CHECK:
                logger.warning(
                    f"fused_recurrent_packed_decode config key '{key}' not found, "
                    f"using nearest-B config with B={nearest_b}: {cfg}."
                )

    if cfg is None:
        default_cfg = cfgs.get("default", _DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                f"fused_recurrent_packed_decode config key '{key}' not found, "
                f"using default config: {default_cfg}."
            )
        cfg = default_cfg

    merged = _DEFAULT_FUSED_RECURRENT_PACKED_DECODE_CONFIG.copy()
    merged.update(cfg)
    return merged
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
    mixed_qkv,
    a,
    b,
    A_log,
    dt_bias,
    o,
    h0,
    ht,
    ssm_state_indices,
    scale,
    stride_mixed_qkv_tok: tl.constexpr,
    stride_a_tok: tl.constexpr,
    stride_b_tok: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    SOFTPLUS_THRESHOLD: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
    """
    Single-step gated delta-rule update reading q/k/v from one packed row.

    Grid: axis 0 tiles V in chunks of ``BV``; axis 1 enumerates
    ``(sequence i_n, value head i_hv)`` pairs. Each program loads a ``[BV, BK]``
    state tile from ``h0``, applies the decay and delta-rule update, stores the
    output slice to ``o`` and the updated tile to ``ht``.

    Within a ``mixed_qkv`` row the kernel reads q at offset ``i_h * K``, k at
    ``H * K + i_h * K`` and v at ``2 * H * K + i_hv * V``.
    """
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head to the q/k head it shares (HV // H value heads per q/k head).
    i_h = i_hv // (HV // H)
    o_k = tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    mask_k = o_k < K
    mask_v = o_v < V
    state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
    p_o = o + (i_n * HV + i_hv) * V + o_v
    if state_idx < 0:
        # Negative state index: emit zeros for this slice and leave the state untouched.
        zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
        tl.store(p_o, zero, mask=mask_v)
        return
    p_h0 = h0 + state_idx * stride_init_state_token
    p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    # [BV, BK]
    b_h = tl.load(p_h0, mask=(mask_v[:, None] & mask_k[None, :]), other=0).to(tl.float32)
    p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
    k_off = (H * K) + i_h * K + o_k
    v_off = (2 * H * K) + i_hv * V + o_v
    b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
    b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
    if USE_QK_L2NORM_IN_KERNEL:
        k_norm_inv = tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        b_k = b_k * k_norm_inv
    # Gate: g = -exp(A_log) * softplus(a + dt_bias); above the threshold softplus(x) is taken as x.
    x = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
    x += tl.load(dt_bias + i_hv).to(tl.float32)
    softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
    g_val = -tl.exp(tl.load(A_log + i_hv).to(tl.float32)) * softplus_x
    beta_val = tl.sigmoid(tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32))
    # Decay the state, then apply the delta rule: v' = beta * (v - h k); h += v' k^T.
    b_h *= tl.exp(g_val)
    b_v -= tl.sum(b_h * b_k[None, :], 1)
    b_v *= beta_val
    b_h += b_v[:, None] * b_k[None, :]
    q_off = i_h * K + o_k
    b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
    if USE_QK_L2NORM_IN_KERNEL:
        q_norm_inv = tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
        b_q = b_q * q_norm_inv
    # Output uses the already-updated state: o = scale * (h q).
    b_o = tl.sum(b_h * b_q[None, :], 1)
    b_o = b_o * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
    # Write the updated state tile to the same slot index in ht.
    p_ht = ht + state_idx * stride_final_state_token
    p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=(mask_v[:, None] & mask_k[None, :]))
def fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
A_log: torch.Tensor,
dt_bias: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
out: torch.Tensor,
ssm_state_indices: torch.Tensor,
use_qk_l2norm_in_kernel: bool = False,
kernel_cfg: dict | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
global HAS_DUMPED_PACKED_DECODE_KERNEL_METADATA
if mixed_qkv.ndim != 2:
raise ValueError(f"`mixed_qkv` must be 2D, got ndim={mixed_qkv.ndim}.")
if a.ndim != 2 or b.ndim != 2:
raise ValueError(f"`a` and `b` must be 2D, got a.ndim={a.ndim}, b.ndim={b.ndim}.")
if A_log.ndim != 1 or dt_bias.ndim != 1:
raise ValueError("`A_log` and `dt_bias` must be 1D.")
if ssm_state_indices.ndim != 1:
raise ValueError("`ssm_state_indices` must be 1D.")
if initial_state.ndim != 4:
raise ValueError(f"`initial_state` must be 4D, got ndim={initial_state.ndim}.")
dev = mixed_qkv.device
if any(t.device != dev for t in (a, b, A_log, dt_bias, initial_state, out, ssm_state_indices)):
raise ValueError("All tensors must be on the same device.")
B = mixed_qkv.shape[0]
if a.shape[0] != B or b.shape[0] != B or ssm_state_indices.shape[0] != B:
raise ValueError("Batch dimensions of mixed_qkv/a/b/ssm_state_indices must match.")
HV, V, K = initial_state.shape[-3:]
if a.shape[1] != HV or b.shape[1] != HV:
raise ValueError("`a` and `b` second dim must match HV from initial_state.")
if A_log.numel() != HV or dt_bias.numel() != HV:
raise ValueError("`A_log` and `dt_bias` numel must equal HV.")
if out.shape != (B, 1, HV, V):
raise ValueError(f"`out` must have shape {(B, 1, HV, V)}, got {tuple(out.shape)}.")
qkv_dim = mixed_qkv.shape[1]
qk_dim = qkv_dim - HV * V
if qk_dim <= 0 or qk_dim % 2 != 0:
raise ValueError("Invalid mixed_qkv layout for packed decode.")
q_dim = qk_dim // 2
if q_dim % K != 0:
raise ValueError("Inferred q_dim must be divisible by K.")
H = q_dim // K
if H <= 0 or HV % H != 0:
raise ValueError(f"Invalid inferred heads: H={H}, HV={HV}.")
BK = triton.next_power_of_2(K)
cfg = kernel_cfg if kernel_cfg is not None else _get_fused_recurrent_packed_decode_config(B, H, HV)
BV = min(triton.next_power_of_2(V), int(cfg["BV"]))
stride_mixed_qkv_tok = mixed_qkv.stride(0)
stride_a_tok = a.stride(0)
stride_b_tok = b.stride(0)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = initial_state.stride(0)
stride_indices_seq = ssm_state_indices.stride(0)
NV = triton.cdiv(V, BV)
grid = (NV, B * HV)
launch_kwargs = dict(
mixed_qkv=mixed_qkv,
a=a,
b=b,
A_log=A_log,
dt_bias=dt_bias,
o=out,
h0=initial_state,
ht=initial_state,
ssm_state_indices=ssm_state_indices,
scale=scale,
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
stride_a_tok=stride_a_tok,
stride_b_tok=stride_b_tok,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
SOFTPLUS_THRESHOLD=20.0,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=cfg["num_warps"],
num_stages=cfg["num_stages"],
)
compiled_kernel = fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](**launch_kwargs)
'''
if not HAS_DUMPED_PACKED_DECODE_KERNEL_METADATA and compiled_kernel is not None:
print("packed decode kernel metadata")
print(f" grid: {grid}")
print(f" registers: {compiled_kernel.n_regs}")
print(f" spills: {compiled_kernel.n_spills}")
print(f" shared memory: {compiled_kernel.metadata.shared} bytes")
HAS_DUMPED_PACKED_DECODE_KERNEL_METADATA = True
'''
return out, initial_state
# SPDX-License-Identifier: MIT
from typing import Tuple
import functools
import json
import os
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils.arch_info as arch_info
from aiter import logger
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
# HAS_DUMPED_SIGMOID_GATING_KERNEL_METADATA = False
# When the TRITON_CONFIG_CHECK environment variable is "1", config-lookup
# fallbacks (nearest-T / default) are reported through logger.warning.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Baseline launch parameters; a tuned entry is merged over these, so any key it
# omits keeps the value given here.
_DEFAULT_FUSED_SIGMOID_GATING_CONFIG = {
    "BV": 32,
    "num_warps": 1,
}
@functools.lru_cache(maxsize=1)
def _load_fused_sigmoid_gating_configs() -> dict:
    """Load the per-architecture tuning table for the sigmoid-gating update kernel.

    Returns the ``"config"`` mapping of the JSON file, or ``{}`` when the file is
    missing (a warning is logged), the payload is not a dict, or it has no
    ``"config"`` key. The result is cached after the first call.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "fused_sigmoid_gating_delta_rule_update",
        f"fused_sigmoid_gating_delta_rule_update-{device_name}.json",
    )
    if os.path.exists(path):
        with open(path) as f:
            payload = json.load(f)
        if not isinstance(payload, dict):
            return {}
        return payload.get("config", {})

    logger.warning(
        f"fused_sigmoid_gating_delta_rule_update config not found at {path}, "
        f"using default {_DEFAULT_FUSED_SIGMOID_GATING_CONFIG}."
    )
    return {}
@functools.lru_cache
def _get_fused_sigmoid_gating_config(T: int, H: int, HV: int) -> dict:
    """Select launch parameters for the sigmoid-gating update kernel.

    Lookup order: the exact ``"T=..,H=..,HV=.."`` entry; otherwise the entry with
    the same ``H``/``HV`` whose ``T`` is numerically closest; otherwise the table's
    ``"default"`` entry or the built-in default. The chosen entry is merged over
    the built-in default so every expected key is present.
    """
    cfgs = _load_fused_sigmoid_gating_configs()
    key = f"T={T},H={H},HV={HV}"
    cfg = cfgs.get(key)

    if cfg is None:
        candidates = []
        for name, entry in cfgs.items():
            if name == "default":
                continue
            try:
                parts = {}
                for field in name.split(","):
                    pieces = field.split("=")
                    parts[pieces[0]] = int(pieces[1])
            except Exception:
                # Keys that are not of the form "NAME=int,..." are ignored.
                continue
            same_heads = parts.get("H") == H and parts.get("HV") == HV
            if same_heads and "T" in parts:
                candidates.append((abs(parts["T"] - T), parts["T"], entry))
        if candidates:
            # min() returns the first entry among equal distances, like a stable sort.
            _, nearest_t, cfg = min(candidates, key=lambda item: item[0])
            if TRITON_CONFIG_CHECK:
                logger.warning(
                    f"fused_sigmoid_gating config key '{key}' not found, "
                    f"using nearest-T config with T={nearest_t}: {cfg}."
                )

    if cfg is None:
        default_cfg = cfgs.get("default", _DEFAULT_FUSED_SIGMOID_GATING_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                f"fused_sigmoid_gating config key '{key}' not found, "
                f"using default config: {default_cfg}."
            )
        cfg = default_cfg

    merged = _DEFAULT_FUSED_SIGMOID_GATING_CONFIG.copy()
    merged.update(cfg)
    return merged
@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    b,
    dt_bias,
    beta,
    threshold,
    q,
    k,
    v,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,
    T: tl.int64,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_FINAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    """
    Fused sigmoid-gating + recurrent gated delta-rule update.

    Grid: ``(K tiles, V tiles, N * HV)``. Each program walks the tokens of one
    sequence for one value head, keeping a ``[BV, BK]`` state tile in registers.
    For every token it decays the state by ``exp(g)``, applies the delta rule,
    writes the output slice and stores the state (to the slot named by
    ``ssm_state_indices`` when ``INPLACE_FINAL_STATE``, otherwise to row
    ``bos + i_t`` of ``ht``).

    With ``IS_KDA`` the gate inputs ``a`` / ``dt_bias`` carry one value per K
    element (per-token stride ``HV * K``); otherwise one scalar per value head
    (per-token stride ``HV``).
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head to the q/k head it shares.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # Empty sequence: nothing to read or write.
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    p_A_log = A_log + i_hv
    if not IS_KDA:
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    else:
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    p_b = b + bos * HV + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    # Invariant across timesteps.
    b_A_log = tl.exp(tl.load(p_A_log).to(tl.float32))
    if not IS_KDA:
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                tl.int64
            )
            if state_idx < 0:
                # Negative slot index: skip this sequence entirely.
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)

        if not IS_KDA:
            x = tl.load(p_a).to(tl.float32) + b_dt_bias
        else:
            # Per-K gate inputs: mask the lanes beyond K, since BK may exceed K.
            x = tl.load(p_a, mask=mask_k, other=0).to(tl.float32) + tl.load(
                p_dt_bias, mask=mask_k, other=0
            ).to(tl.float32)
        # softplus with slope beta; above the threshold softplus(x) is taken as x.
        softplus_x = tl.where(
            beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
        )
        b_g = -b_A_log * softplus_x
        b_beta = tl.sigmoid(b_b)

        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale

        # Decay the state, then delta rule: v' = beta * (v - h k); h += v' k^T.
        if not IS_KDA:
            b_h *= tl.exp(b_g)
        else:
            b_h *= tl.exp(b_g[None, :])
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        b_v *= b_beta
        b_h += b_v[:, None] * b_k[None, :]
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # Persist the state after every token.
        if INPLACE_FINAL_STATE:
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t
            ).to(tl.int64)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        # Advance to the next token.
        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        p_b += HV
        # The token stride of `a` must match how p_a was laid out above.
        if IS_KDA:
            p_a += HV * K
        else:
            p_a += HV
def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
scale: float | None = None,
initial_state: torch.Tensor | None = None,
inplace_final_state: bool = True,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
is_kda: bool = False,
kernel_cfg: dict | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
global HAS_DUMPED_SIGMOID_GATING_KERNEL_METADATA
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK = triton.next_power_of_2(K)
cfg = kernel_cfg if kernel_cfg is not None else _get_fused_sigmoid_gating_config(T, H, HV)
BV = min(triton.next_power_of_2(V), int(cfg["BV"]))
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
num_warps = int(cfg["num_warps"])
if NK != 1:
raise ValueError(f"NK > 1 is not supported (K={K}, BK={BK}, NK={NK}).")
if scale is None:
scale = K**-0.5
elif scale <= 0:
raise ValueError("scale must be positive.")
if initial_state is None:
raise ValueError("initial_state must not be None.")
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"q.shape[0] must be 1 when using cu_seqlens, got {q.shape[0]}."
)
o = q.new_empty(NK, *v.shape)
final_state = initial_state if inplace_final_state else q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
if ssm_state_indices is None:
stride_indices_seq, stride_indices_tok = 1, 1
elif ssm_state_indices.ndim == 1:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
elif ssm_state_indices.ndim == 2:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
else:
raise ValueError(
f"ssm_state_indices must be 1D/2D when provided, got ndim={ssm_state_indices.ndim}."
)
grid = (NK, NV, N * HV)
compiled_kernel = fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a.contiguous(),
b=b.contiguous(),
dt_bias=dt_bias,
beta=beta,
threshold=threshold,
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
scale=scale,
N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
stride_indices_tok=stride_indices_tok,
INPLACE_FINAL_STATE=inplace_final_state,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=is_kda,
num_warps=num_warps,
num_stages=1,
)
'''
if not HAS_DUMPED_SIGMOID_GATING_KERNEL_METADATA and compiled_kernel is not None:
print("sigmoid gating kernel metadata")
print(f" grid: {grid}")
print(f" registers: {compiled_kernel.n_regs}")
print(f" spills: {compiled_kernel.n_spills}")
print(f" shared memory: {compiled_kernel.metadata.shared} bytes")
HAS_DUMPED_SIGMOID_GATING_KERNEL_METADATA = True
'''
return o.squeeze(0), final_state
from typing import Optional
import functools
import json
import os
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils.arch_info as arch_info
from aiter import logger
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
# HAS_DUMPED_SIGMOID_GATING_REC_KERNEL_METADATA = False
# When the TRITON_CONFIG_CHECK environment variable is "1", config-lookup
# fallbacks (nearest-T / default) are reported through logger.warning.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Baseline launch parameters; a tuned entry is merged over these, so any key it
# omits keeps the value given here.
_DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG = {
    "BV": 32,
    "num_warps": 1,
}
@functools.lru_cache(maxsize=1)
def _load_fused_sigmoid_gating_recurrent_configs() -> dict:
    """Load the per-architecture tuning table for the recurrent sigmoid-gating kernel.

    Returns the ``"config"`` mapping of the JSON file, or ``{}`` when the file is
    missing (a warning is logged), the payload is not a dict, or it has no
    ``"config"`` key. The result is cached after the first call.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "fused_sigmoid_gating_delta_rule_update_recurrent",
        f"fused_sigmoid_gating_delta_rule_update_recurrent-{device_name}.json",
    )
    if os.path.exists(path):
        with open(path) as f:
            payload = json.load(f)
        if not isinstance(payload, dict):
            return {}
        return payload.get("config", {})

    logger.warning(
        f"fused_sigmoid_gating_delta_rule_update_recurrent config not found at {path}, "
        f"using default {_DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG}."
    )
    return {}
@functools.lru_cache
def _get_fused_sigmoid_gating_recurrent_config(T: int, H: int, HV: int) -> dict:
    """Select launch parameters for the recurrent sigmoid-gating kernel.

    Lookup order: the exact ``"T=..,H=..,HV=.."`` entry; otherwise the entry with
    the same ``H``/``HV`` whose ``T`` is numerically closest; otherwise the table's
    ``"default"`` entry or the built-in default. The chosen entry is merged over
    the built-in default so every expected key is present.
    """
    cfgs = _load_fused_sigmoid_gating_recurrent_configs()
    key = f"T={T},H={H},HV={HV}"
    cfg = cfgs.get(key)

    if cfg is None:
        candidates = []
        for name, entry in cfgs.items():
            if name == "default":
                continue
            try:
                parts = {}
                for field in name.split(","):
                    pieces = field.split("=")
                    parts[pieces[0]] = int(pieces[1])
            except Exception:
                # Keys that are not of the form "NAME=int,..." are ignored.
                continue
            same_heads = parts.get("H") == H and parts.get("HV") == HV
            if same_heads and "T" in parts:
                candidates.append((abs(parts["T"] - T), parts["T"], entry))
        if candidates:
            # min() returns the first entry among equal distances, like a stable sort.
            _, nearest_t, cfg = min(candidates, key=lambda item: item[0])
            if TRITON_CONFIG_CHECK:
                logger.warning(
                    f"fused_sigmoid_gating_recurrent config key '{key}' not found, "
                    f"using nearest-T config with T={nearest_t}: {cfg}."
                )

    if cfg is None:
        default_cfg = cfgs.get("default", _DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                f"fused_sigmoid_gating_recurrent config key '{key}' not found, "
                f"using default config: {default_cfg}."
            )
        cfg = default_cfg

    merged = _DEFAULT_FUSED_SIGMOID_GATING_REC_CONFIG.copy()
    merged.update(cfg)
    return merged
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
    A_log,
    a,
    dt_bias,
    softplus_beta,
    softplus_threshold,
    q,
    k,
    v,
    b,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    # Parameters for target_verify support (unused for decode)
    intermediate_states_buffer,
    intermediate_state_indices,
    cache_steps,
    retrieve_parent_token_ptr,
    stride_retrieve_parent_token_seq: tl.constexpr,
    stride_retrieve_parent_token_token: tl.constexpr,
    # ================================================
    scale,
    T,
    stride_q,
    stride_k,
    stride_v,
    stride_b,
    NP2_T: tl.constexpr,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_KDA: tl.constexpr,
    # Optional flags for target_verify support (default False for decode)
    DISABLE_STATE_UPDATE: tl.constexpr = False,
    CACHE_INTERMEDIATE_STATES: tl.constexpr = False,
    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr = False,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.

    Grid: ``(K tiles, V tiles, N * HV)``. Each program walks the tokens of one
    sequence for one value head with a ``[BK, BV]`` state tile in registers.
    The state is addressed with K as the fastest-varying dimension
    (``o_v * K + o_k``) inside a slot of ``HV * K * V`` elements.

    Optional behaviour selected by compile-time flags:
    - ``CACHE_INTERMEDIATE_STATES``: store the state after every step into
      ``intermediate_states_buffer`` at the row given by ``intermediate_state_indices``.
    - ``HAS_EAGLE_TREE_CUSTOM_ATTN_MASK``: for steps after the first, reload the
      state cached for the step's parent token before updating.
    - ``DISABLE_STATE_UPDATE``: skip the final write-back into ``h0_source``.
    """
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Map the value head to the q/k head it shares.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    p_q = q + bos * stride_q + i_h * K + o_k
    p_k = k + bos * stride_k + i_h * K + o_k
    p_v = v + bos * stride_v + i_hv * V + o_v
    p_b = b + bos * stride_b + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
    # Gating computation pointers
    p_A_log = A_log + i_hv
    if IS_KDA:
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    else:
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        # A negative index leaves the state at zero instead of loading a slot.
        if idx >= 0:
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_v[None, :] * K
                + o_k[:, None]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
    # Preload tree attention data if needed
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        token_indices = tl.arange(0, NP2_T)
        mask_retrieve = token_indices < T
        retrieve_parent_token_base = (
            retrieve_parent_token_ptr
            + (i_n * stride_retrieve_parent_token_seq)
            + token_indices * stride_retrieve_parent_token_token
        )
        parent_idx_tokens = tl.load(
            retrieve_parent_token_base, mask=mask_retrieve, other=0
        )
    # Prepare intermediate state cache index if enabled
    cache_idx = -1
    if CACHE_INTERMEDIATE_STATES:
        cache_idx = tl.load(intermediate_state_indices + i_n)
    # Invariant across timesteps.
    b_A = tl.exp(tl.load(p_A_log).to(tl.float32))
    if not IS_KDA:
        b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
    step_idx = 0
    for _ in range(0, T):
        # Tree attention: load parent's cached state
        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
            # step_idx == 0 uses b_h from USE_INITIAL_STATE
            if step_idx != 0 and cache_idx >= 0:
                # Pick parent_idx_tokens[step_idx] via a masked sum.
                parent_step_idx = tl.sum(
                    tl.where(token_indices == step_idx, parent_idx_tokens, 0)
                )
                step_offset = parent_step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)
        # Load k first; q is loaded later right before output to reduce register live range.
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        # Compute sigmoid gating
        # Load gating parameters
        if IS_KDA:
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias, mask=mask_k, other=0).to(tl.float32)
        else:
            b_a = tl.load(p_a).to(tl.float32)
        # Compute g with tighter live ranges for intermediates.
        x = b_a + b_dt_bias
        x_scaled = softplus_beta * x
        # softplus with slope softplus_beta; above the threshold softplus(x) is taken as x.
        x = tl.where(
            x_scaled <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(x_scaled)),
            x,
        )
        b_g = -b_A * x
        # Apply L2 normalization to k early; q normalization is deferred until q is loaded.
        if USE_QK_L2NORM_IN_KERNEL:
            b_k = b_k * tl.rsqrt(tl.sum(b_k * b_k) + 1e-6)
        # Apply gating to hidden state: h *= exp(g)
        if IS_KDA:
            b_h *= tl.exp(b_g[:, None])
        else:
            b_h *= tl.exp(b_g)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        # Apply beta gating: v *= beta
        b_v *= tl.sigmoid(tl.load(p_b).to(tl.float32))
        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # Load q late to shorten q live range and lower peak register pressure.
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q * tl.rsqrt(tl.sum(b_q * b_q) + 1e-6)
        b_q = b_q * scale
        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Cache intermediate states if enabled
        if CACHE_INTERMEDIATE_STATES:
            if cache_idx >= 0:
                step_offset = step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)
        step_idx += 1
        # Update pointers for next timestep
        p_q += stride_q
        p_k += stride_k
        p_v += stride_v
        p_b += stride_b
        p_o += HV * V
        # `a` advances by one token: HV * K values with IS_KDA, HV scalars otherwise.
        if IS_KDA:
            p_a += HV * K
        else:
            p_a += HV
    # Store final state back to h0_source with bounds checking
    if not DISABLE_STATE_UPDATE:
        if USE_INITIAL_STATE:
            idx = tl.load(h0_indices + i_n)
            if idx >= 0:
                p_h0 = (
                    h0_source
                    + idx * HV * K * V
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: Optional[float] = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    is_kda: bool = False,
    # Optional parameters for target_verify support
    disable_state_update: bool = False,
    intermediate_states_buffer: Optional[torch.Tensor] = None,
    intermediate_state_indices: Optional[torch.Tensor] = None,
    cache_steps: Optional[int] = None,
    retrieve_parent_token: Optional[torch.Tensor] = None,
    kernel_cfg: dict | None = None,
):
    """
    Fused triton implementation of sigmoid gating delta rule update.

    This function uses a single fused kernel that combines both sigmoid gating
    computation and the recurrent delta rule update for better performance.

    Supports both decode and target_verify modes:
    - decode: standard single-step update with state write-back
    - target_verify: multi-step with intermediate state caching, optional tree
      attention, and optional state update disable

    Args:
        A_log, a, dt_bias: gating inputs, forwarded to the kernel unchanged.
        softplus_beta, softplus_threshold: softplus parameters, forwarded unchanged.
        q, k, v, b: kernel inputs. ``k`` must be 4-D and is unpacked as
            ``(B, T, H, K)``; ``v.shape[2]`` is used as ``HV`` and
            ``v.shape[-1]`` as ``V``.
        initial_state_source: forwarded as ``h0_source``; the kernel's
            ``USE_INITIAL_STATE`` flag is set when this is not None.
        initial_state_indices: forwarded as ``h0_indices``.
        scale: multiplier applied to q in the kernel; defaults to ``K ** -0.5``
            and must be positive when given.
        use_qk_l2norm_in_kernel: enables the in-kernel L2 normalization flag.
        cu_seqlens: when given, enables the variable-length path and the number
            of sequences is ``len(cu_seqlens) - 1`` instead of ``B``.
        is_kda: sets the kernel's ``IS_KDA`` flag.
        disable_state_update: sets the kernel's ``DISABLE_STATE_UPDATE`` flag.
        intermediate_states_buffer: when not None, enables
            ``CACHE_INTERMEDIATE_STATES``.
        intermediate_state_indices: forwarded to the kernel unchanged.
        cache_steps: forwarded to the kernel; ``None`` is passed as 0.
        retrieve_parent_token: when not None, enables
            ``HAS_EAGLE_TREE_CUSTOM_ATTN_MASK``; its first two strides are
            forwarded to the kernel.
        kernel_cfg: mapping providing ``"BV"`` and ``"num_warps"``. When None,
            the config is looked up from ``(T, H, HV)``.

    Returns:
        Tensor with the shape of ``v``, allocated with the dtype/device of ``q``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    stride_q = q.stride()[1]
    stride_k = k.stride()[1]
    stride_v = v.stride()[1]
    stride_b = b.stride()[-2]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1

    # Tile sizes: K is covered by a single power-of-two tile, V is split
    # into tiles no larger than the configured BV.
    BK = triton.next_power_of_2(K)
    cfg = (
        kernel_cfg
        if kernel_cfg is not None
        else _get_fused_sigmoid_gating_recurrent_config(T, H, HV)
    )
    BV = min(triton.next_power_of_2(V), int(cfg["BV"]))
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_warps = int(cfg["num_warps"])

    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"

    # Leading NK axis (always 1, see the assert above) is squeezed on return.
    o = q.new_empty(NK, *v.shape)

    # Prepare retrieve_parent_token strides
    if retrieve_parent_token is not None:
        stride_retrieve_parent_token_seq = retrieve_parent_token.stride(0)
        stride_retrieve_parent_token_token = retrieve_parent_token.stride(1)
    else:
        stride_retrieve_parent_token_seq = 0
        stride_retrieve_parent_token_token = 0

    NP2_T = triton.next_power_of_2(T)
    grid = (NK, NV, N * HV)
    fused_sigmoid_gating_delta_rule_update_kernel[grid](
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        o=o,
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        intermediate_states_buffer=intermediate_states_buffer,
        intermediate_state_indices=intermediate_state_indices,
        cache_steps=0 if cache_steps is None else cache_steps,
        retrieve_parent_token_ptr=retrieve_parent_token,
        stride_retrieve_parent_token_seq=stride_retrieve_parent_token_seq,
        stride_retrieve_parent_token_token=stride_retrieve_parent_token_token,
        scale=scale,
        T=T,
        stride_q=stride_q,
        stride_k=stride_k,
        stride_v=stride_v,
        stride_b=stride_b,
        NP2_T=NP2_T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        USE_INITIAL_STATE=initial_state_source is not None,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        IS_VARLEN=cu_seqlens is not None,
        IS_KDA=is_kda,
        DISABLE_STATE_UPDATE=disable_state_update,
        CACHE_INTERMEDIATE_STATES=intermediate_states_buffer is not None,
        HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=retrieve_parent_token is not None,
        num_warps=num_warps,
        num_stages=1,
    )
    return o.squeeze(0)
from typing import Optional
import torch
import triton
import triton.language as tl
@triton.jit(do_not_specialize=["T"])
def fused_sigmoid_gating_delta_rule_update_kernel_ref(
    A_log,
    a,
    dt_bias,
    softplus_beta,
    softplus_threshold,
    q,
    k,
    v,
    b,
    o,
    h0_source,
    h0_indices,
    cu_seqlens,
    # Parameters for target_verify support (unused for decode)
    intermediate_states_buffer,
    intermediate_state_indices,
    cache_steps,
    retrieve_parent_token_ptr,
    stride_retrieve_parent_token_seq: tl.constexpr,
    stride_retrieve_parent_token_token: tl.constexpr,
    # ================================================
    scale,
    T,
    stride_q,
    stride_k,
    stride_v,
    stride_b,
    NP2_T: tl.constexpr,
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_KDA: tl.constexpr,
    # Optional flags for target_verify support (default False for decode)
    DISABLE_STATE_UPDATE: tl.constexpr = False,
    CACHE_INTERMEDIATE_STATES: tl.constexpr = False,
    HAS_EAGLE_TREE_CUSTOM_ATTN_MASK: tl.constexpr = False,
):
    """
    Fused kernel that combines sigmoid gating computation with recurrent delta rule update.

    Each program handles one (K tile, V tile, sequence * value-head) triple and
    walks the sequence one timestep at a time, keeping a [BK, BV] float32 state
    tile ``b_h`` in flight. Per step it computes the gate
    ``g = -exp(A_log) * softplus(a + dt_bias)``, decays the state by ``exp(g)``,
    applies the delta-rule correction to ``v``, scales it by ``sigmoid(b)``,
    adds the outer product ``k v^T`` to the state and writes ``o = h^T q``.
    """
    # Grid axes: K tile, V tile, fused (sequence, value head) index.
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Each group of HV // H consecutive value heads shares one q/k head.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        # Sequence boundaries come from cu_seqlens; T becomes this sequence's length.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        # NOTE: `all` shadows the Python builtin; it is the row count of one
        # K-tile slice of `o` and is only used in the p_o offset below.
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T
    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)
    # Pointers to the first timestep of this sequence/head.
    p_q = q + bos * stride_q + i_h * K + o_k
    p_k = k + bos * stride_k + i_h * K + o_k
    p_v = v + bos * stride_v + i_hv * V + o_v
    p_b = b + bos * stride_b + i_hv
    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
    # Gating computation pointers
    p_A_log = A_log + i_hv
    if IS_KDA:
        # KDA: gate inputs are per key channel (vector of length K per head).
        p_a = a + (bos * HV + i_hv) * K + o_k
        p_dt_bias = dt_bias + i_hv * K + o_k
    else:
        # Otherwise the gate is a single scalar per head.
        p_a = a + bos * HV + i_hv
        p_dt_bias = dt_bias + i_hv
    # Masks guard the padding introduced by power-of-two tile sizes.
    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]
    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        idx = tl.load(h0_indices + i_n)
        # A negative slot index leaves the state at zero.
        if idx >= 0:
            # State slots are addressed as [slot, HV, V, K] with K contiguous.
            p_h0 = (
                h0_source
                + idx * HV * K * V
                + i_hv * K * V
                + o_v[None, :] * K
                + o_k[:, None]
            )
            b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
    # Preload tree attention data if needed
    if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
        # Parent indices for all T tokens are loaded once as a vector of
        # length NP2_T (padded entries read as 0).
        token_indices = tl.arange(0, NP2_T)
        mask_retrieve = token_indices < T
        retrieve_parent_token_base = (
            retrieve_parent_token_ptr
            + (i_n * stride_retrieve_parent_token_seq)
            + token_indices * stride_retrieve_parent_token_token
        )
        parent_idx_tokens = tl.load(
            retrieve_parent_token_base, mask=mask_retrieve, other=0
        )
    # Prepare intermediate state cache index if enabled
    cache_idx = -1
    if CACHE_INTERMEDIATE_STATES:
        cache_idx = tl.load(intermediate_state_indices + i_n)
    step_idx = 0
    for _ in range(0, T):
        # Tree attention: load parent's cached state
        if HAS_EAGLE_TREE_CUSTOM_ATTN_MASK:
            # step_idx == 0 uses b_h from USE_INITIAL_STATE
            if step_idx != 0 and cache_idx >= 0:
                # Select parent_idx_tokens[step_idx] via a masked sum
                # (the vector cannot be indexed with a runtime scalar).
                parent_step_idx = tl.sum(
                    tl.where(token_indices == step_idx, parent_idx_tokens, 0)
                )
                step_offset = parent_step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                # Replace (not accumulate) the running state with the parent's state.
                b_h = tl.load(cache_ptr, mask=mask_h, other=0).to(tl.float32)
        # Load inputs
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_b = tl.load(p_b).to(tl.float32)
        # Compute sigmoid gating
        # Load gating parameters
        b_A_log = tl.load(p_A_log).to(tl.float32)
        if IS_KDA:
            b_a = tl.load(p_a, mask=mask_k, other=0).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias, mask=mask_k, other=0).to(tl.float32)
        else:
            b_a = tl.load(p_a).to(tl.float32)
            b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
        # Compute g = -exp(A_log) * softplus(a + dt_bias)
        x = b_a + b_dt_bias
        beta_x = softplus_beta * x
        # Apply softplus with numerical stability
        # (above the threshold softplus is replaced by the identity).
        softplus_x = tl.where(
            beta_x <= softplus_threshold,
            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
            x,
        )
        b_g = -tl.exp(b_A_log) * softplus_x
        # Compute beta = sigmoid(b)
        b_beta = 1.0 / (1.0 + tl.exp(-b_b))
        # Apply L2 normalization if enabled
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q) + 1e-6))
            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k) + 1e-6))
        b_q = b_q * scale
        # Apply gating to hidden state: h *= exp(g)
        if IS_KDA:
            # Per-key-channel decay, broadcast over the V axis.
            b_h *= tl.exp(b_g[:, None])
        else:
            b_h *= tl.exp(b_g)
        # Delta rule: v -= sum(h * k, dim=0)
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        # Apply beta gating: v *= beta
        b_v *= b_beta
        # Update hidden state: h += k[:, None] * v[None, :]
        b_h += b_k[:, None] * b_v[None, :]
        # Compute output: o = sum(h * q, dim=0)
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
        # Cache intermediate states if enabled
        if CACHE_INTERMEDIATE_STATES:
            if cache_idx >= 0:
                # The post-update state of this step goes to slot
                # [cache_idx, step_idx] of the intermediate buffer.
                step_offset = step_idx * HV * K * V
                cache_ptr = (
                    intermediate_states_buffer
                    + cache_idx * cache_steps * HV * K * V
                    + step_offset
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)
        step_idx += 1
        # Update pointers for next timestep
        p_q += stride_q
        p_k += stride_k
        p_v += stride_v
        p_b += stride_b
        p_o += HV * V
        if IS_KDA:
            p_a += HV * K
        else:
            p_a += HV
    # Store final state back to h0_source with bounds checking
    if not DISABLE_STATE_UPDATE:
        if USE_INITIAL_STATE:
            idx = tl.load(h0_indices + i_n)
            # Same slot the initial state was read from; negative indices are skipped.
            if idx >= 0:
                p_h0 = (
                    h0_source
                    + idx * HV * K * V
                    + i_hv * K * V
                    + o_v[None, :] * K
                    + o_k[:, None]
                )
                tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
def fused_sigmoid_gating_delta_rule_update(
    A_log: torch.Tensor,
    a: torch.Tensor,
    dt_bias: torch.Tensor,
    softplus_beta: float,
    softplus_threshold: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    b: torch.Tensor,
    initial_state_source: torch.Tensor,
    initial_state_indices: torch.Tensor,
    scale: Optional[float] = None,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: Optional[torch.Tensor] = None,
    is_kda: bool = False,
    # Optional parameters for target_verify support
    disable_state_update: bool = False,
    intermediate_states_buffer: Optional[torch.Tensor] = None,
    intermediate_state_indices: Optional[torch.Tensor] = None,
    cache_steps: Optional[int] = None,
    retrieve_parent_token: Optional[torch.Tensor] = None,
):
    """
    Reference host wrapper for the fused sigmoid-gating delta-rule kernel.

    Launches ``fused_sigmoid_gating_delta_rule_update_kernel_ref`` with fixed
    launch parameters (one warp, three stages, V tiles of at most 32).

    Two usage modes are supported:
    - decode: single-step update with state write-back
    - target_verify: multi-step with intermediate state caching, optional
      tree attention, and optional suppression of the state write-back

    Returns a tensor with the shape of ``v`` allocated with the dtype/device
    of ``q``.
    """
    stride_q = q.stride()[1]
    stride_k = k.stride()[1]
    stride_v = v.stride()[1]
    stride_b = b.stride()[-2]

    # k is unpacked as (batch, tokens, q/k heads, key dim).
    batch, seq_len, num_qk_heads, key_dim = k.shape
    value_dim = v.shape[-1]
    num_v_heads = v.shape[2]
    num_seqs = len(cu_seqlens) - 1 if cu_seqlens is not None else batch

    # One power-of-two tile spans the whole key dim; V is tiled by <= 32.
    block_k = triton.next_power_of_2(key_dim)
    block_v = min(triton.next_power_of_2(value_dim), 32)
    k_tiles = triton.cdiv(key_dim, block_k)
    v_tiles = triton.cdiv(value_dim, block_v)
    assert k_tiles == 1, "NK > 1 is not supported yet"

    if scale is not None:
        assert scale > 0, "scale must be positive"
    else:
        scale = k.shape[-1] ** -0.5

    # Output carries a leading K-tile axis that is dropped before returning.
    out = q.new_empty(k_tiles, *v.shape)

    # Strides of the optional parent-token table (zero when it is absent).
    has_parent_table = retrieve_parent_token is not None
    parent_seq_stride = retrieve_parent_token.stride(0) if has_parent_table else 0
    parent_tok_stride = retrieve_parent_token.stride(1) if has_parent_table else 0

    launch_grid = (k_tiles, v_tiles, num_seqs * num_v_heads)
    fused_sigmoid_gating_delta_rule_update_kernel_ref[launch_grid](
        A_log=A_log,
        a=a,
        dt_bias=dt_bias,
        softplus_beta=softplus_beta,
        softplus_threshold=softplus_threshold,
        q=q,
        k=k,
        v=v,
        b=b,
        o=out,
        h0_source=initial_state_source,
        h0_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        intermediate_states_buffer=intermediate_states_buffer,
        intermediate_state_indices=intermediate_state_indices,
        cache_steps=cache_steps if cache_steps is not None else 0,
        retrieve_parent_token_ptr=retrieve_parent_token,
        stride_retrieve_parent_token_seq=parent_seq_stride,
        stride_retrieve_parent_token_token=parent_tok_stride,
        scale=scale,
        T=seq_len,
        stride_q=stride_q,
        stride_k=stride_k,
        stride_v=stride_v,
        stride_b=stride_b,
        NP2_T=triton.next_power_of_2(seq_len),
        B=batch,
        H=num_qk_heads,
        HV=num_v_heads,
        K=key_dim,
        V=value_dim,
        BK=block_k,
        BV=block_v,
        USE_INITIAL_STATE=initial_state_source is not None,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        IS_VARLEN=cu_seqlens is not None,
        IS_KDA=is_kda,
        DISABLE_STATE_UPDATE=disable_state_update,
        CACHE_INTERMEDIATE_STATES=intermediate_states_buffer is not None,
        HAS_EAGLE_TREE_CUSTOM_ATTN_MASK=has_parent_table,
        num_warps=1,
        num_stages=3,
    )
    return out.squeeze(0)
# SPDX-License-Identifier: MIT
import functools
import json
import os
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils.arch_info as arch_info
from aiter import logger
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
# Opt-in diagnostics switch: true only when the TRITON_CONFIG_CHECK
# environment variable is exactly "1". It gates the missing-config warning
# and the one-time kernel metadata dump in this module.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Set to True after the first metadata dump so the dump happens at most once
# per process.
HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA = False
@triton.jit
def safe_exp(x):
    """Exponentiate ``x`` where ``x <= 0``; other entries are sent to ``-inf`` first."""
    non_positive = tl.where(x <= 0, x, float("-inf"))
    return exp(non_positive)
@triton.jit
def exp(x):
    """Device helper that forwards to ``tl.exp``."""
    result = tl.exp(x)
    return result
@triton.jit
def exp2(x):
    """Device helper that forwards to ``tl.math.exp2``."""
    result = tl.math.exp2(x)
    return result
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Enumerate every (sequence index, chunk index) pair of a packed batch.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries; sequence
            ``i`` spans ``cu_seqlens[i]`` to ``cu_seqlens[i + 1]``.
        chunk_size: number of tokens per chunk.

    Returns:
        ``(num_chunks, 2)`` int64 tensor on ``cu_seqlens.device``. Rows are
        ordered by sequence, then by chunk within the sequence. Sequences of
        length zero contribute no rows.
    """
    # Fetch all lengths with a single device-to-host transfer instead of
    # one blocking .item() call per sequence.
    seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    chunk_rows = []
    for seq_idx, seq_len in enumerate(seq_lens):
        # Ceiling division: a partial trailing chunk still counts.
        n_chunks = (int(seq_len) + chunk_size - 1) // chunk_size
        chunk_rows.extend([seq_idx, chunk_idx] for chunk_idx in range(n_chunks))
    if not chunk_rows:
        # torch.tensor([]) would be 1-D; keep the (0, 2) shape explicit.
        return torch.empty((0, 2), dtype=torch.long, device=cu_seqlens.device)
    return torch.tensor(chunk_rows, dtype=torch.long, device=cu_seqlens.device)
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """Return, per sequence, how many chunks belong to all earlier sequences.

    The result has one entry per sequence (``len(cu_seqlens) - 1``) and the
    same dtype as the per-sequence chunk counts; entry 0 is always 0.
    """
    lengths = cu_seqlens[1:] - cu_seqlens[:-1]
    chunks_per_seq = (lengths + chunk_size - 1) // chunk_size
    # Exclusive prefix sum = inclusive prefix sum minus the element itself.
    # cumsum may widen integer inputs, so cast back to the count dtype.
    running_total = torch.cumsum(chunks_per_seq, dim=0)
    return (running_total - chunks_per_seq).to(chunks_per_seq.dtype)
# Built-in launch parameters for chunk_gated_delta_rule_fwd_h. They are the
# base that every looked-up config is merged over, and the final fallback when
# neither a shape-specific nor a "default" entry exists in the config file.
_DEFAULT_CHUNK_DELTA_H_CONFIG = {
    "BV": 32,
    "num_warps": 8,
    "num_stages": 2,
}
@functools.lru_cache(maxsize=1)
def _load_chunk_delta_h_configs() -> dict:
    """Read the per-architecture config table for chunk_gated_delta_rule_fwd_h.

    The JSON file is located under ``AITER_TRITON_CONFIGS_PATH`` using the
    architecture name reported by ``arch_info.get_arch()``. The value of its
    top-level ``"config"`` key is returned. An empty dict is returned when the
    file is missing (a warning is logged), when the payload is not a JSON
    object, or when it has no ``"config"`` key. The result is cached after the
    first call.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "chunk_gated_delta_rule_fwd_h",
        f"chunk_gated_delta_rule_fwd_h-{device_name}.json",
    )
    if os.path.exists(path):
        with open(path) as f:
            payload = json.load(f)
        if isinstance(payload, dict):
            return payload.get("config", {})
        return {}
    logger.warning(
        f"chunk_gated_delta_rule_fwd_h config not found at {path}, using default {_DEFAULT_CHUNK_DELTA_H_CONFIG}."
    )
    return {}
@functools.lru_cache
def _get_chunk_delta_h_config(K: int, V: int, BT: int, H: int) -> dict:
    """Resolve launch parameters for one ``(K, V, BT, H)`` problem shape.

    Lookup order: the shape-specific entry, then the table's ``"default"``
    entry, then the built-in defaults. Whatever is found is overlaid on a copy
    of the built-in defaults, so missing keys are always filled in. A warning
    about the fallback is logged only when ``TRITON_CONFIG_CHECK`` is enabled.
    Results are cached per argument tuple.
    """
    table = _load_chunk_delta_h_configs()
    key = f"K={K},V={V},BT={BT},H={H}"
    selected = table.get(key)
    if selected is None:
        default_cfg = table.get("default", _DEFAULT_CHUNK_DELTA_H_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                "chunk_gated_delta_rule_fwd_h config missing for "
                f"{key}, using default config {default_cfg}."
            )
        selected = default_cfg
    resolved = _DEFAULT_CHUNK_DELTA_H_CONFIG.copy()
    resolved.update(selected)
    return resolved
def launch_chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    *,
    k: torch.Tensor,
    u: torch.Tensor,
    w: torch.Tensor,
    v_new: torch.Tensor | None,
    g: torch.Tensor | None,
    gk: torch.Tensor | None,
    h: torch.Tensor,
    initial_state: torch.Tensor | None,
    initial_state_indices: torch.Tensor | None,
    # final_state: torch.Tensor | None,
    cu_seqlens: torch.LongTensor | None,
    chunk_offsets: torch.LongTensor | None,
    N: int,
    T: int,
    H: int,
    Hg: int,
    K: int,
    V: int,
    BT: int,
    kernel_cfg: dict | None,
):
    """Launch ``chunk_gated_delta_rule_fwd_kernel_h_blockdim64`` on a (V tiles, N * H) grid.

    ``kernel_cfg`` supplies ``"BV"``, ``"num_warps"`` and ``"num_stages"``;
    when it is None the config is resolved from ``(K, V, BT, H)``. ``u`` is
    passed to the kernel as ``v`` and the kernel is always launched with
    ``INPLACE_UPDATE=True``. With ``TRITON_CONFIG_CHECK`` enabled, the compiled
    kernel's metadata is printed once per process.
    """
    global HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA

    def _grid(meta):
        return (triton.cdiv(V, meta["BV"]), N * H)

    cfg = kernel_cfg
    if cfg is None:
        cfg = _get_chunk_delta_h_config(K, V, BT, H)
    launch_grid = (triton.cdiv(V, cfg["BV"]), N * H)

    kernel_args = dict(
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        initial_state=initial_state,
        initial_state_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=cfg["BV"],
        INPLACE_UPDATE=True,
        num_warps=cfg["num_warps"],
        num_stages=cfg["num_stages"],
    )
    compiled_kernel = chunk_gated_delta_rule_fwd_kernel_h_blockdim64[_grid](**kernel_args)

    # Diagnostics are opt-in and emitted only for the first launch.
    if not TRITON_CONFIG_CHECK:
        return
    if HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA or compiled_kernel is None:
        return
    print("chunk_gated_delta_rule_fwd_kernel_h_blockdim64 metadata")
    print(f" grid: {launch_grid}")
    print(
        f" meta: BT={BT}, BV={cfg['BV']}, K={K}, V={V}, H={H}, Hg={Hg}, N={N}, T={T}, "
        f"num_warps={cfg['num_warps']}, num_stages={cfg['num_stages']}"
    )
    print(f" registers: {compiled_kernel.n_regs}")
    print(f" spills: {compiled_kernel.n_spills}")
    print(f" shared memory: {compiled_kernel.metadata.shared} bytes")
    HAS_DUMPED_CHUNK_DELTA_H_KERNEL_METADATA = True
# Compile-time flags derived from which optional tensors were passed.
@triton.heuristics({
    "USE_G": lambda args: args["g"] is not None,
    "USE_GK": lambda args: args["gk"] is not None,
    "USE_INITIAL_STATE": lambda args: args["initial_state"] is not None,
    # "USE_INITIAL_STATE_INDICES": lambda args: args["initial_state_indices"] is not None,
    # "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
    "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
    "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    initial_state,
    initial_state_indices,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    INPLACE_UPDATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Chunk-wise recurrence over the hidden state of the gated delta rule.

    Each program owns one V tile of one (sequence, head) pair and walks the
    sequence chunk by chunk. The state is held as up to four [BV, 64] float32
    tiles along K (so K <= 256). For every chunk the kernel stores the state
    seen at the start of the chunk into ``h``, forms ``v_new = v - w @ h^T``,
    applies the optional gates, and accumulates ``k^T @ v_new`` into the state.
    With ``INPLACE_UPDATE`` the final state is written back to the
    ``initial_state`` slot selected by ``initial_state_indices``.
    """
    # Grid axes: V tile, fused (sequence, head) index.
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Sequence bounds from cu_seqlens; boh is the index of this
        # sequence's first chunk inside `h`.
        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
            cu_seqlens + i_n + 1
        ).to(tl.int32)
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT
    # [BV, BK]
    # State tiles; tile j covers key channels [64 * (j - 1), 64 * j).
    b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
    # calculate offset
    # Base pointers are advanced to this sequence/head; offsets are widened
    # to int64 before being added.
    h += ((boh * H + i_h) * V * K).to(tl.int64)
    v += ((bos * H + i_h) * V).to(tl.int64)
    # k has Hg heads: each group of H // Hg heads shares one key head.
    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
    w += ((bos * H + i_h) * K).to(tl.int64)
    if SAVE_NEW_VALUE:
        v_new += ((bos * H + i_h) * V).to(tl.int64)
    # Row strides (per token, or per chunk for h).
    stride_v = H * V
    stride_h = H * V * K
    stride_k = Hg * K
    stride_w = H * K
    # NOTE(review): this load and the two pointer computations below are not
    # guarded by USE_INITIAL_STATE, so the kernel appears to require both
    # initial_state and initial_state_indices to be real tensors - confirm
    # against callers. The product index * stride_h is formed in int32.
    index = tl.load(initial_state_indices + i_n).to(tl.int32)
    h0 = initial_state + index * stride_h
    ht = initial_state + index * stride_h
    if USE_INITIAL_STATE:
        h0 = h0 + i_h * V * K
    if INPLACE_UPDATE:
        # Final state is written to the same slot the initial state came from.
        ht = ht + i_h * V * K
    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
    # main recurrence
    for i_t in range(NT):
        # Snapshot the state as it is *before* chunk i_t is applied.
        p_h1 = tl.make_block_ptr(
            h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
        )
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
        # Load the [BT, 64] tiles of w for this chunk.
        p_w1 = tl.make_block_ptr(
            w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
        )
        b_w1 = tl.load(p_w1, boundary_check=(0, 1))
        if K > 64:
            p_w2 = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
            )
            b_w2 = tl.load(p_w2, boundary_check=(0, 1))
        if K > 128:
            p_w3 = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
            )
            b_w3 = tl.load(p_w3, boundary_check=(0, 1))
        if K > 192:
            p_w4 = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
            )
            b_w4 = tl.load(p_w4, boundary_check=(0, 1))
        # b_v = w @ h^T, accumulated over the K tiles; the state is cast to
        # w's dtype for the dot.
        b_v = tl.dot(b_w1, tl.trans(b_h1).to(b_w1.dtype))
        if K > 64:
            b_v += tl.dot(b_w2, tl.trans(b_h2).to(b_w2.dtype))
        if K > 128:
            b_v += tl.dot(b_w3, tl.trans(b_h3).to(b_w3.dtype))
        if K > 192:
            b_v += tl.dot(b_w4, tl.trans(b_h4).to(b_w4.dtype))
        p_v = tl.make_block_ptr(
            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        # New value: v - w @ h^T.
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
        if SAVE_NEW_VALUE:
            # Stored before any gate scaling is applied below.
            p_v = tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
        # Index of the last valid token of this chunk (the final chunk may be partial).
        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            # Per-head scalar gate: rows of b_v are scaled by
            # exp(g_last - g_t) and the state by exp(g_last).
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(
                g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last
        if USE_GK:
            # Per-key-channel gate taken at the chunk's last token and
            # broadcast over the V axis of each state tile.
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(
                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
                mask=(o_k1 < K),
                other=0.0,
            )
            b_h1 *= exp(b_gk_last1)[None, :]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
                    mask=(o_k2 < K),
                    other=0.0,
                )
                b_h2 *= exp(b_gk_last2)[None, :]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
                    mask=(o_k3 < K),
                    other=0.0,
                )
                b_h3 *= exp(b_gk_last3)[None, :]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
                    mask=(o_k4 < K),
                    other=0.0,
                )
                b_h4 *= exp(b_gk_last4)[None, :]
        # Cast to k's element type so both dot operands match.
        b_v = b_v.to(k.dtype.element_ty)
        # k is read transposed: [64, BT] tiles of a (K, T) view.
        p_k1 = tl.make_block_ptr(
            k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
        )
        b_k1 = tl.load(p_k1, boundary_check=(0, 1))
        if K > 64:
            p_k2 = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
            )
            b_k2 = tl.load(p_k2, boundary_check=(0, 1))
        if K > 128:
            p_k3 = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
            )
            b_k3 = tl.load(p_k3, boundary_check=(0, 1))
        if K > 192:
            p_k4 = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
            )
            b_k4 = tl.load(p_k4, boundary_check=(0, 1))
        # State update: h += (k^T @ v_new)^T, one K tile at a time.
        b_h1 += tl.trans(tl.dot(b_k1, b_v))
        if K > 64:
            b_h2 += tl.trans(tl.dot(b_k2, b_v))
        if K > 128:
            b_h3 += tl.trans(tl.dot(b_k3, b_v))
        if K > 192:
            b_h4 += tl.trans(tl.dot(b_k4, b_v))
    # epilogue
    if INPLACE_UPDATE:
        # Write the final state back into the initial_state slot.
        p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    initial_state_indices: torch.Tensor | None = None,
    output_final_state: bool = True,
    chunk_size: int = 64,
    save_new_value: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = False,
    transpose_state_layout: bool = True,
    kernel_cfg: dict | None = None,
):
    """Compute per-chunk hidden states (and optionally new values) of the gated delta rule.

    Args:
        k: 4-D tensor unpacked as ``(B, T, Hg, K)``; ``K`` must be <= 256.
        w: forwarded to the kernel unchanged.
        u: values; ``u.shape[-2]`` is used as ``H`` and ``u.shape[-1]`` as ``V``.
        g, gk: optional gates, forwarded to the kernel unchanged.
        initial_state, initial_state_indices: forwarded to the kernel unchanged.
        output_final_state: accepted for interface compatibility; not used here.
        chunk_size: number of tokens per chunk (``BT``).
        save_new_value: when True, a ``v_new`` buffer shaped like ``u`` is
            allocated and filled by the kernel.
        cu_seqlens: cumulative sequence boundaries for the variable-length path.
        chunk_indices: optional precomputed result of
            ``prepare_chunk_indices(cu_seqlens, chunk_size)``. It is only
            consulted for its length and is computed here when omitted.
        use_exp2, transpose_state_layout: accepted for interface
            compatibility; not used here.
        kernel_cfg: optional launch config, forwarded to the launcher.

    Returns:
        ``(h, v_new)``. ``h`` has shape ``(B, NT, H, V, K)`` with the
        dtype/device of ``k``; ``v_new`` is None when ``save_new_value`` is False.
    """
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        # Reuse the caller's chunk index table when one is provided; building
        # it requires a device-to-host transfer of the sequence lengths.
        if chunk_indices is None:
            chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    assert K <= 256, "current kernel does not support head dimension larger than 256."
    h = k.new_empty(B, NT, H, V, K)
    v_new = torch.empty_like(u) if save_new_value else None
    launch_chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
        k=k,
        u=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        initial_state=initial_state,
        initial_state_indices=initial_state_indices,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        N=N,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        kernel_cfg=kernel_cfg,
    )
    return h, v_new
# SPDX-License-Identifier: MIT
import functools
import json
import os
import torch
import triton
import triton.language as tl
import aiter.ops.triton.utils.arch_info as arch_info
from aiter import logger
from aiter.ops.triton.utils.core import AITER_TRITON_CONFIGS_PATH
# Opt-in diagnostics switch: true only when the TRITON_CONFIG_CHECK
# environment variable is exactly "1". It gates the missing-config warning
# and the one-time kernel metadata dump in this module.
TRITON_CONFIG_CHECK = os.environ.get("TRITON_CONFIG_CHECK", "0") == "1"
# Set to True after the first metadata dump so the dump happens at most once
# per process.
HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA = False
@triton.jit
def safe_exp(x):
    """Exponentiate ``x`` where ``x <= 0``; other entries are sent to ``-inf`` first."""
    non_positive = tl.where(x <= 0, x, float("-inf"))
    return exp(non_positive)
@triton.jit
def exp(x):
    """Device helper that forwards to ``tl.exp``."""
    result = tl.exp(x)
    return result
@triton.jit
def exp2(x):
    """Device helper that forwards to ``tl.math.exp2``."""
    result = tl.math.exp2(x)
    return result
def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
    """List the (sequence index, chunk index) pair of every chunk in a packed batch.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence boundaries; sequence
            ``i`` spans ``cu_seqlens[i]`` to ``cu_seqlens[i + 1]``.
        chunk_size: number of tokens per chunk.

    Returns:
        ``(num_chunks, 2)`` int64 tensor on ``cu_seqlens.device``, ordered by
        sequence and then by chunk. Empty sequences contribute no rows.
    """
    # A single .tolist() replaces one blocking .item() call per sequence.
    seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
    chunk_rows = []
    for seq_idx, seq_len in enumerate(seq_lens):
        # Round up so a partial trailing chunk is included.
        n_chunks = (int(seq_len) + chunk_size - 1) // chunk_size
        chunk_rows.extend([seq_idx, chunk_idx] for chunk_idx in range(n_chunks))
    if not chunk_rows:
        # Keep the (0, 2) shape; torch.tensor([]) would be 1-D.
        return torch.empty((0, 2), dtype=torch.long, device=cu_seqlens.device)
    return torch.tensor(chunk_rows, dtype=torch.long, device=cu_seqlens.device)
# Built-in launch parameters for chunk_fwd_o. They are the base that every
# looked-up config is merged over, and the final fallback when neither a
# shape-specific nor a "default" entry exists in the config file.
_DEFAULT_CHUNK_O_CONFIG = {
    "BK": 128,
    "BV": 64,
    "num_warps": 4,
    "num_stages": 2,
}
@functools.lru_cache(maxsize=1)
def _load_chunk_o_configs() -> dict:
    """Read the per-architecture config table for chunk_fwd_o.

    The JSON file is located under ``AITER_TRITON_CONFIGS_PATH`` using the
    architecture name reported by ``arch_info.get_arch()``. The value of its
    top-level ``"config"`` key is returned. An empty dict is returned when the
    file is missing (a warning is logged), when the payload is not a JSON
    object, or when it has no ``"config"`` key. The result is cached after the
    first call.
    """
    device_name = arch_info.get_arch()
    path = os.path.join(
        AITER_TRITON_CONFIGS_PATH,
        "chunk_fwd_o",
        f"chunk_fwd_o-{device_name}.json",
    )
    if os.path.exists(path):
        with open(path) as f:
            payload = json.load(f)
        if isinstance(payload, dict):
            return payload.get("config", {})
        return {}
    logger.warning(
        f"chunk_fwd_o config not found at {path}, using default {_DEFAULT_CHUNK_O_CONFIG}."
    )
    return {}
@functools.lru_cache
def _get_chunk_o_config(K: int, V: int, BT: int) -> dict:
    """Resolve launch parameters for one ``(K, V, BT)`` problem shape.

    Lookup order: the shape-specific entry, then the table's ``"default"``
    entry, then the built-in defaults. Whatever is found is overlaid on a copy
    of the built-in defaults, so missing keys are always filled in. A warning
    about the fallback is logged only when ``TRITON_CONFIG_CHECK`` is enabled.
    Results are cached per argument tuple.
    """
    table = _load_chunk_o_configs()
    key = f"K={K},V={V},BT={BT}"
    selected = table.get(key)
    if selected is None:
        default_cfg = table.get("default", _DEFAULT_CHUNK_O_CONFIG)
        if TRITON_CONFIG_CHECK:
            logger.warning(
                "chunk_fwd_o config missing for "
                f"{key}, using default config {default_cfg}."
            )
        selected = default_cfg
    resolved = _DEFAULT_CHUNK_O_CONFIG.copy()
    resolved.update(selected)
    return resolved
def launch_chunk_fwd_kernel_o(
    *,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None,
    o: torch.Tensor,
    cu_seqlens: torch.LongTensor | None,
    chunk_indices: torch.LongTensor | None,
    scale: float,
    T: int,
    H: int,
    Hg: int,
    K: int,
    V: int,
    BT: int,
    NT: int,
    B: int,
    kernel_cfg: dict | None,
):
    """Launch ``chunk_fwd_kernel_o`` on a (V tiles, NT, B * H) grid.

    ``kernel_cfg`` supplies ``"BK"``, ``"BV"``, ``"num_warps"`` and
    ``"num_stages"``; when it is None the config is resolved from
    ``(K, V, BT)``. With ``TRITON_CONFIG_CHECK`` enabled, the compiled
    kernel's metadata is printed once per process.
    """
    global HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA

    def _grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    cfg = kernel_cfg
    if cfg is None:
        cfg = _get_chunk_o_config(K, V, BT)
    launch_grid = (triton.cdiv(V, cfg["BV"]), NT, B * H)

    kernel_args = dict(
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        o=o,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=cfg["BK"],
        BV=cfg["BV"],
        num_warps=cfg["num_warps"],
        num_stages=cfg["num_stages"],
    )
    compiled_kernel = chunk_fwd_kernel_o[_grid](**kernel_args)

    # Diagnostics are opt-in and emitted only for the first launch.
    if not TRITON_CONFIG_CHECK:
        return
    if HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA or compiled_kernel is None:
        return
    print("chunk_fwd_kernel_o metadata")
    print(f" grid: {launch_grid}")
    print(
        f" meta: BT={BT}, BK={cfg['BK']}, BV={cfg['BV']}, K={K}, V={V}, H={H}, Hg={Hg}, "
        f"NT={NT}, B={B}, T={T}, num_warps={cfg['num_warps']}, num_stages={cfg['num_stages']}"
    )
    print(f" registers: {compiled_kernel.n_regs}")
    print(f" spills: {compiled_kernel.n_spills}")
    print(f" shared memory: {compiled_kernel.metadata.shared} bytes")
    HAS_DUMPED_CHUNK_FWD_O_KERNEL_METADATA = True
# Compile-time flags derived from which optional tensors were passed.
@triton.heuristics({
    "USE_G": lambda args: args["g"] is not None,
    "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
})
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Compute one [BT, BV] output tile of chunked attention.

    The output combines an inter-chunk term ``q @ h^T`` (using the state
    snapshot of this chunk) with an intra-chunk term ``tril(q @ k^T) @ v``;
    both are multiplied by ``scale``. When ``g`` is given, the two terms are
    additionally weighted by ``exp(g)`` and ``exp(g_i - g_j)`` respectively.
    """
    # Grid axes: V tile, chunk, fused (batch, head) index.
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # i_tg is the global chunk index into h; chunk_indices maps it to
        # (sequence index, chunk index within that sequence).
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T
    # Advance base pointers to this sequence/head. q and k have Hg heads:
    # each group of H // Hg heads shares one q/k head.
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h).to(tl.int64) * V * K
    # b_o accumulates q @ h^T, b_A accumulates q @ k^T over the K tiles.
    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
        )
        # k is read transposed, as a (K, T) view.
        p_k = tl.make_block_ptr(
            k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
        )
        # The state tile is stored as (V, K) and transposed for the dot.
        p_h = tl.make_block_ptr(
            h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
        )
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h = tl.load(p_h, boundary_check=(0, 1))
        b_o += tl.dot(b_q, tl.trans(b_h))
        b_A += tl.dot(b_q, b_k)
    if USE_G:
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * exp(b_g)[:, None]
        # safe_exp maps positive differences (g_i > g_j) to exp(-inf) = 0.
        b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
    # Causal mask inside the chunk: row i only sees columns j <= i.
    o_i = tl.arange(0, BT)
    m_A = o_i[:, None] >= o_i[None, :]
    b_A = tl.where(m_A, b_A, 0)
    p_v = tl.make_block_ptr(
        v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    p_o = tl.make_block_ptr(
        o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # Both terms are scaled; b_A is cast to v's dtype for the dot.
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,
    g_gamma: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    chunk_indices: torch.LongTensor | None = None,
    use_exp2: bool = False,
    transpose_state_layout: bool = False,
    kernel_cfg: dict | None = None,
) -> torch.Tensor:
    """Compute the chunked attention output from q, k, v and per-chunk states ``h``.

    ``q`` is unpacked as ``(B, T, Hg, K)``; ``v.shape[-2]`` is used as ``H``
    and ``v.shape[-1]`` as ``V``. ``scale`` defaults to ``K ** -0.5`` taken
    from ``k``. On the variable-length path (``cu_seqlens`` given),
    ``chunk_indices`` is built on demand and its length is the number of
    chunks. ``g_gamma``, ``use_exp2`` and ``transpose_state_layout`` are
    accepted but not used here.

    Returns a new tensor shaped like ``v``.
    """
    batch, seq_len, num_qk_heads, key_dim = q.shape
    value_dim = v.shape[-1]
    num_heads = v.shape[-2]

    if cu_seqlens is None:
        num_chunks = triton.cdiv(seq_len, chunk_size)
    else:
        if chunk_indices is None:
            chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    if scale is None:
        scale = k.shape[-1] ** -0.5

    out = torch.empty_like(v)
    launch_chunk_fwd_kernel_o(
        q=q,
        k=k,
        v=v,
        h=h,
        g=g,
        o=out,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=seq_len,
        H=num_heads,
        Hg=num_qk_heads,
        K=key_dim,
        V=value_dim,
        BT=chunk_size,
        NT=num_chunks,
        B=batch,
        kernel_cfg=kernel_cfg,
    )
    return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment