Commit ec5e299c authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.7.3' into v0.7.3-dev

parents 47bd229c ed6e9075
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 32,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
......@@ -128,7 +128,7 @@
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 32,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 32,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
......@@ -609,7 +609,7 @@ def moe_align_block_size(
dtype=torch.int32,
device=topk_ids.device)
if num_experts >= 224:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON:
if envs.VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON or num_experts != 256:
moe_align_block_size_triton(
topk_ids,
num_experts,
......@@ -619,6 +619,7 @@ def moe_align_block_size(
num_tokens_post_pad,
)
else:
# Currently requires num_experts=256
ops.sgl_moe_align_block_size(
topk_ids,
num_experts,
......@@ -1023,15 +1024,17 @@ def grouped_topk(hidden_states: torch.Tensor,
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
num_token = scores.shape[0]
if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores = scores
scores = scores + e_score_correction_bias.unsqueeze(0)
num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_scores = (scores.view(num_token, num_expert_group,
-1).topk(2, dim=-1)[0].sum(dim=-1))
else:
group_scores = scores.view(num_token, num_expert_group,
-1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1,
sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
......@@ -1039,7 +1042,8 @@ def grouped_topk(hidden_states: torch.Tensor,
score_mask = group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(),
float("-inf")) # [n, e]
if e_score_correction_bias is not None:
topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)[1]
......
......@@ -360,8 +360,8 @@ class FusedMoE(torch.nn.Module):
"use_nn_moe": self.use_nn_moe,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
......
......@@ -309,29 +309,30 @@ class ColumnParallelLinear(LinearBase):
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
assert self.quant_method is not None
self.output_size_per_partition = divide(self.output_size, tp_size)
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
# If QKV or MergedColumn, use output size of each partition.
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, tp_size)
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
self.gather_output = gather_output
if output_sizes is None:
output_sizes = [output_size]
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
......@@ -354,6 +355,12 @@ class ColumnParallelLinear(LinearBase):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
......@@ -362,13 +369,12 @@ class ColumnParallelLinear(LinearBase):
# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
final_shape = list(loaded_weight.shape)
if output_dim is not None:
tp_size = get_tensor_model_parallel_world_size()
assert final_shape[output_dim] % tp_size == 0
final_shape[output_dim] = final_shape[output_dim] // tp_size
param.materialize(final_shape, dtype=loaded_weight.dtype)
param_data = param.data
if output_dim is not None and not is_sharded_weight:
......@@ -1058,22 +1064,24 @@ class RowParallelLinear(LinearBase):
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
# Divide the weight matrix along the first dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config, prefix)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
# Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=[self.output_size],
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
......
......@@ -51,7 +51,6 @@ class LogitsProcessor(nn.Module):
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
parallel_config = get_current_vllm_config().parallel_config
self.use_all_gather = current_platform.is_tpu() \
or envs.VLLM_USE_V1 \
......@@ -88,17 +87,8 @@ class LogitsProcessor(nn.Module):
return logits
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
def _gather_logits(self, logits: torch.Tensor) -> torch.Tensor:
"""gather/all-gather the logits tensor across model parallel group."""
if self.use_all_gather:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
......@@ -109,6 +99,22 @@ class LogitsProcessor(nn.Module):
else:
# None may be returned for rank > 0
logits = tensor_model_parallel_gather(logits)
return logits
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
# Gather logits for TP
logits = self._gather_logits(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[..., :self.org_vocab_size]
......
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
import torch
from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
# Added by the IBM Team, 2024
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.full_hidden_size = full_hidden_size
self.group_size = full_hidden_size // full_n_groups
self.per_rank_hidden_size = full_hidden_size // self.tp_size
self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)})
assert self.full_hidden_size % self.tp_size== 0,\
"Tensor parallel world size must divide hidden size."
def forward_native(
self,
x: torch.Tensor,
gate: torch.Tensor,
):
# Three tensor-parallel cases:
# 1. n_groups is 1
# In this case we parallelize along the reduction dim.
# Each rank computes a local sum of squares followed by AllReduce
# 2. tp_size divides n_groups
# Each rank only reduces within its local group(s).
# No collective ops necessary.
# 3. The general case can be pretty complicated so we AllGather
# the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if self.n_groups == 1:
if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = (global_sums / count)
else:
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
else:
redundant_tp: bool = self.n_groups % self.tp_size != 0
if redundant_tp:
# To handle the general case, redundantly apply the variance
x = tensor_model_parallel_all_gather(x, -1)
*prefix_dims, hidden_dim = x.shape
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance +
self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
if redundant_tp:
start = self.per_rank_hidden_size * self.tp_rank
end = start + self.per_rank_hidden_size
x = x[..., start:end]
return self.weight * x.to(input_dtype)
def forward_cuda(
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
from vllm import _custom_ops as ops
# cast x and gate to float32 before silu
out = torch.empty_like(x)
y = x * nn.functional.silu(gate.to(torch.float32))
ops.rms_norm(
out,
y.to(x.dtype),
self.weight.data,
self.variance_epsilon,
)
return out
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
"""Compute the increase in group numbers to account for
replication in order to accompany the head shards."""
# in the case ngoups % tp_size == 0, this will be zero
if ngroups % tp_size == 0:
return 0
return tp_size - ngroups % tp_size
def mamba_v2_sharded_weight_loader(
shard_spec: List[Tuple[int, int, float]],
tp_size: int,
tp_rank: int,
) -> LoaderFunction:
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures the the all the groups corresponding to a head shard is placed
together with it.
"""
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
# - track boundary of (sharded) param, and loaded_weight, respectively
boundary, loaded_boundary = 0, 0
# - iterate over the shard specs
for full_dim, extra, ratio in shard_spec:
# - full dim is the model dim (before TP).
# - extra > 0, means there is expected overall increase
# of dimensions. This is so because of replication.
# - ratio is used map the tp_rank to the actual shard
# rank. This is useful when there is replication of
# groups to accompany head shards.
# - size of the loaded shard
shard_size = full_dim // tp_size
# - compute the rank into the loaded shard.
# - if there is replication, different TP shards will
# take from the same rank.
rank = tp_rank // ratio
# - leftmost boundary index into loaded weight.
loaded_skip = rank * shard_size
loaded_start_idx = loaded_boundary + loaded_skip
# - take these many dims from the loaded weight.
take = min(shard_size, full_dim - extra - loaded_skip)
# - always shard on dim 0
# - the ignore is for a mundane mypy error as it does not
# seem to handle slices well.
# https://github.com/python/mypy/issues/2410
param.data[
boundary:(boundary + take), # type: ignore[misc]
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc]
loaded_start_idx + take)] # type: ignore[misc]
# move indexing boundaries
boundary += shard_size
loaded_boundary += (full_dim - extra)
return loader
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer2")
class MambaMixer2(CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation="silu",
chunk_size: int = 256,
quant_config: Optional[QuantizationConfig] = None):
super().__init__()
# For TP, the sharding plan is as follows:
# - for the conv modules, since
# conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
# we shard intermediate_size and n_groups
# - since intermediate_size = n_heads * head_dim, sharding on
# intermediate_size is achieved by sharding on n_heads.
# - IF, world_size divides groups, then sharding
# (n_groups / world_size, n_heads / world_size)
# also maintains the invariant n_heads % n_groups == 0
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
# to allocate extra space in the shard, such that groups
# may be replicated to follow the head shard.
self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
assert num_heads % self.tp_size == 0, \
"Tensor parallel world size must divide num heads."
self.ssm_state_size = ssm_state_size
self.activation = activation
self.chunk_size = chunk_size
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
self.n_groups = n_groups
if n_groups % self.tp_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
self.n_groups = n_groups + extra_groups_for_head_shards(
n_groups, self.tp_size)
self.conv_dim = (intermediate_size +
2 * self.n_groups * ssm_state_size)
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = ColumnParallelLinear(input_size=hidden_size,
output_size=intermediate_size +
self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config)
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
# for conv1d.bias, covn1d.weight and in_proj.weight
# - need to set these settings, to assign the groups to the head shards
group_shard_settings = (
self.n_groups * self.ssm_state_size, # expected model size
(self.n_groups - n_groups) *
self.ssm_state_size, # extra dims assigned
self.num_heads //
n_groups, # ratio for mapping back to original group
)
intermediate_settings = (intermediate_size, 0, 1)
head_setings = (self.num_heads, 0, 1)
# - the weight already has a "weight_loader" attribute
# which set_weight_attrs will raise if we do not
# delete before trying to override it
# - ditto for the otther two weights below
delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs(
self.conv1d.bias, {
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings,
group_shard_settings,
group_shard_settings,
],
self.tp_size,
tp_rank,
)
})
delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs(
self.conv1d.weight, {
"weight_loader":
mamba_v2_sharded_weight_loader([
intermediate_settings,
group_shard_settings,
group_shard_settings,
], self.tp_size, tp_rank)
})
delattr(self.in_proj.weight, "weight_loader")
set_weight_attrs(
self.in_proj.weight,
{
"weight_loader":
mamba_v2_sharded_weight_loader(
[
intermediate_settings, # for gate
intermediate_settings,
group_shard_settings,
group_shard_settings,
head_setings, # for dt
],
self.tp_size,
tp_rank)
})
# - these are TPed by heads to reduce the size of the
# temporal shape
self.A = nn.Parameter(
torch.empty(
divide(num_heads, self.tp_size),
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader(
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=use_bias,
input_is_parallel=True,
quant_config=quant_config)
self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups,
eps=rms_norm_eps)
def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, ssm_state: torch.Tensor):
pass
def forward_cuda(
self,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams,
sequence_idx: Optional[torch.Tensor] = None,
):
seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
# detect if there are prefills
has_prefill = attn_metadata.num_prefills > 0
# - also need flags to indicate if there are initial states
# - currently we really only support the FlashAttention backend
has_initial_states = None
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None):
has_initial_states = attn_metadata.context_lens_tensor > 0
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
gate, hidden_states_B_C, dt = torch.split(
projected_states,
[
self.intermediate_size // self.tp_size,
self.conv_dim // self.tp_size,
self.num_heads // self.tp_size,
],
dim=-1,
)
# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))
if has_prefill:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "mamba_cache_params.state_indices_tensor"
hidden_states_B_C = causal_conv1d_fn(
hidden_states_B_C.transpose(0, 1),
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=has_initial_states,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc).transpose(
0, 1)[:seq_len]
# TODO: Why is this needed?
hidden_states_B_C = hidden_states_B_C.contiguous()
else:
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
# - get hidden_states, B and C after depthwise convolution.
hidden_states, B, C = torch.split(
hidden_states_B_C,
[
self.intermediate_size // self.tp_size,
groups_time_state_size // self.tp_size,
groups_time_state_size // self.tp_size,
],
dim=-1,
)
# 3. State Space Model sequence transformation
if has_prefill:
initial_states = None
if has_initial_states is not None and any(has_initial_states):
for idx in mamba_cache_params.state_indices_tensor[
~has_initial_states]:
mamba_cache_params.ssm_state[idx].zero_()
initial_states = mamba_cache_params.ssm_state[
mamba_cache_params.state_indices_tensor]
scan_output, varlen_state = mamba_chunk_scan_combined(
hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
self.head_dim),
dt.unsqueeze(0),
self.A,
B.view(1, seq_len, self.n_groups // self.tp_size, -1),
C.view(1, seq_len, self.n_groups // self.tp_size, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
seq_idx=sequence_idx,
cu_seqlens=attn_metadata.query_start_loc,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
)
# update ssm states
# - varlen state is a (batch, nheads, headdim, dstate) tensor
for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])
# - reshape
hidden_states = scan_output.view(seq_len, -1)
else:
n_groups = self.n_groups // self.tp_size
A = self.A[:, None, ...][:, :, None].expand(
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states_reshaped = hidden_states.view(
-1, self.num_heads // self.tp_size, self.head_dim)
# - the hidden is reshaped into number of current batches
# - in this case there is no more prefill, so the batches gen
# 1 token at a time
# - thus hidden will be (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# using "mamba_cache_params.state_indices_tensor", just as
# above in the prefill case
hidden_states = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states_reshaped,
dt,
A,
B,
C,
D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor,
)
hidden_states = hidden_states.view(
-1, (self.num_heads // self.tp_size) * self.head_dim)
# # 4. gated MLP
hidden_states = self.norm(hidden_states, gate)
# # 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch
import triton
......
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=2),
],
key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
out_ptr,
seq_idx_ptr,
# Matrix dimensions
seqlen,
chunk_size,
K,
ngroups,
stride_a_batch,
stride_a_seqlen,
stride_a_head,
stride_ak,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_bk,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_outm,
stride_outn,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
dot_dtype: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_ch = tl.program_id(axis=2).to(tl.int64)
pid_c = pid_ch // ngroups
pid_h = pid_ch - pid_c * ngroups
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
if IS_CAUSAL:
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
return
a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
offs_k[None, :] * stride_ak)
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
offs_n[None, :] * stride_b_seqlen)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0).to(dot_dtype)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
(offs_n[None, :] < chunk_size_limit),
other=0.0).to(dot_dtype)
acc += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_SEQ_IDX:
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
mask=offs_n < chunk_size_limit,
other=-2)
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
out = acc.to(out_ptr.dtype.element_ty)
out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
offs_n[None, :] * stride_outn)
tl.store(out_ptrs,
out,
mask=(offs_m[:, None] < chunk_size) &
(offs_n[None, :] < chunk_size))
def _bmm_chunk_fwd(a,
b,
chunk_size,
seq_idx=None,
causal=False,
output_dtype=None):
"""
Argument:
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
guaranteed to be correct.
Return:
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
"""
# Check constraints.
has_groups = a.dim() == 4
if not has_groups:
batch, seqlen, k = a.shape
else:
batch, seqlen, ngroups, k = a.shape
assert b.shape == a.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if a.stride(-1) != 1 and a.stride(1) != 1:
a = a.contiguous()
if b.stride(-1) != 1 and b.stride(1) != 1:
b = b.contiguous()
nchunks = math.ceil(seqlen / chunk_size)
# Allocates output.
out_dtype = a.dtype if output_dtype is None else output_dtype
out = torch.empty(
(batch, nchunks, chunk_size, chunk_size) if not has_groups else
(batch, nchunks, ngroups, chunk_size, chunk_size),
device=a.device,
dtype=out_dtype)
dot_dtype = (tl.bfloat16
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
(tl.float16 if a.dtype == torch.float16
or b.dtype == torch.float16 else tl.float32))
grid = lambda META: (triton.cdiv(
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
if not has_groups else nchunks * ngroups)
with torch.cuda.device(a.device.index):
_bmm_chunk_fwd_kernel[grid](
a,
b,
out,
seq_idx,
seqlen,
chunk_size,
k,
ngroups if has_groups else 1,
a.stride(0),
a.stride(1),
0 if not has_groups else a.stride(2),
a.stride(-1),
b.stride(0),
b.stride(1),
0 if not has_groups else b.stride(2),
b.stride(-1),
out.stride(0),
out.stride(1),
0 if not has_groups else out.stride(2),
out.stride(-2),
out.stride(-1),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
causal,
dot_dtype,
HAS_SEQ_IDX=seq_idx is not None,
)
return out
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
# ruff: noqa: E501,SIM102
import math
import torch
import triton
import triton.language as tl
from packaging import version
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 64
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 64
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=2),
],
key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
)
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices
cb_ptr,
x_ptr,
z_ptr,
out_ptr,
out_x_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
C_ptr,
states_ptr,
D_ptr,
initstates_ptr,
chunk_indices_ptr,
chunk_offsets_ptr,
chunk_meta_num,
# Matrix dimensions
chunk_size,
hdim,
dstate,
batch,
seqlen,
nheads_ngroups_ratio,
# Strides
stride_cb_batch,
stride_cb_chunk,
stride_cb_head,
stride_cb_csize_m,
stride_cb_csize_k,
stride_x_batch,
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_z_batch,
stride_z_seqlen,
stride_z_head,
stride_z_hdim,
stride_out_batch,
stride_out_seqlen,
stride_out_head,
stride_out_hdim,
stride_dt_batch,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
stride_C_batch,
stride_C_seqlen,
stride_C_head,
stride_C_dstate,
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_init_states_batch,
stride_init_states_head,
stride_init_states_hdim,
stride_init_states_dstate,
stride_D_head,
# Meta-parameters
IS_CAUSAL: tl.constexpr,
HAS_D: tl.constexpr,
D_HAS_HDIM: tl.constexpr,
HAS_Z: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
BLOCK_SIZE_DSTATE: tl.constexpr,
IS_TRITON_22: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
pid_bc = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
if not HAS_INITSTATES:
c_idx = pid_c
c_off = 0
else:
c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + (
pid_h // nheads_ngroups_ratio) * stride_cb_head
x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_C_head
# M-block offsets and prev states
# - logic in next block may override these if there is an active offset
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head
prev_states_hdim = stride_states_hdim
prev_states_dstate = stride_states_dstate
chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
# - we only need seq_idx_prev to be aligned to chunk boundary
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
mask=c_idx >= 1,
other=0)
if HAS_INITSTATES:
# if there are init states, we only need seq_idx_m to point
# what is the current seq_idx
# get current seq idx
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
seq_idx_m = tl.load(
seq_idx_ptr +
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
# - recall that in ssd_state_passing, for the case c_off == 0
# i.e., the very first sequence, we made states_ptr hold its initial state
# so this edge case is taken care of
if ((c_off == 0) and
(seq_idx_prev != seq_idx_m
) # if a seq is changed exactly on boundary
or (c_off > 0) # implies a new example (pseudo chunk)
):
# - replace prev_states_ptr with init_states
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
prev_states_hdim = stride_init_states_hdim # override strides
prev_states_dstate = stride_init_states_dstate
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
mask=offs_m < chunk_size,
other=0.0).to(tl.float32)
# - handle chunk state limit
if HAS_INITSTATES:
# have to split this if otherwise compilation will have problems
dA_cs_m_boundary = 0.0
# get the c_idx for the next (logica) chunk
c_idx_n = tl.load(
chunk_indices_ptr + (pid_c + 1),
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
other=-1 # to trigger different chunk
)
# - there are things to consider
# A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
# contribution of past states
# B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
# encroach into the next sequence, where c_off_n is the offset of the next
# (logical) chunk.
# An equivalent check for B is c_idx == c_idx_n, where there is repetition in
# (logical) chunk indices.
if (c_idx == c_idx_n) or c_off > 0:
# get the next offset
c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
other=chunk_size)
# in this case, adjust down the chunk_size_limit
if c_idx == c_idx_n:
chunk_size_limit = min(c_off_n, chunk_size_limit)
# get the cs at the offset boundary
# - c_off == 0 is a passthrough
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr +
(pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1)
and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
# - handle seq idx when HAS_INITSTATES==False
if not HAS_INITSTATES:
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
mask=offs_m < chunk_size_limit,
other=-1)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Without the if (pid_c > -1), with Triton 2.1.0, I get
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
# With Triton 2.2.0, this works
if IS_TRITON_22 or c_idx > -1:
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
offs_k_dstate = tl.arange(
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
offs_k_dstate[None, :] * stride_C_dstate)
prev_states_ptrs = prev_states_ptr + (
offs_n[None, :] * prev_states_hdim +
offs_k_dstate[:, None] * prev_states_dstate)
if HAS_SEQ_IDX:
if not HAS_INITSTATES:
# - this is for continuous batching where there is no init states
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m),
0.0)
else:
# - if there is initstates, we will rely on prev_states, no zeroing
# required.
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
else:
scale_m = tl.exp(dA_cs_m)
if BLOCK_SIZE_DSTATE <= 128:
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k_dstate[None, :] < dstate),
other=0.0)
prev_states = tl.load(prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate) &
(offs_n[None, :] < hdim),
other=0.0)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc = tl.dot(C, prev_states) * scale_m[:, None]
else:
for k in range(0, dstate, BLOCK_SIZE_K):
C = tl.load(C_ptrs,
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_k_dstate[None, :] < dstate - k),
other=0.0)
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
prev_states = tl.load(
prev_states_ptrs,
mask=(offs_k_dstate[:, None] < dstate - k) &
(offs_n[None, :] < hdim),
other=0.0)
prev_states = prev_states.to(C_ptr.dtype.element_ty)
acc += tl.dot(C, prev_states)
C_ptrs += BLOCK_SIZE_K
prev_states_ptrs += BLOCK_SIZE_K
acc *= scale_m[:, None]
offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
offs_k[None, :] * stride_cb_csize_k)
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
offs_n[None, :] * stride_x_hdim)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
K_MAX = chunk_size_limit if not IS_CAUSAL else min(
(pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
for k in range(0, K_MAX, BLOCK_SIZE_K):
cb = tl.load(cb_ptrs,
mask=(offs_m[:, None] < chunk_size) &
(offs_k[None, :] < chunk_size - k),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
# So we don't need masking wrt seq_idx here.
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
other=0.0).to(tl.float32)
cb *= dt_k
if IS_CAUSAL:
mask = offs_m[:, None] >= k + offs_k[None, :]
cb = tl.where(mask, cb, 0.0)
cb = cb.to(x_ptr.dtype.element_ty)
x = tl.load(x_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < hdim),
other=0.0)
acc += tl.dot(cb, x)
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if HAS_D:
if D_HAS_HDIM:
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n,
mask=offs_n < hdim,
other=0.0).to(tl.float32)
else:
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen +
offs_n[None, :] * stride_x_hdim),
mask=(offs_m[:, None] < chunk_size_limit) &
(offs_n[None, :] < hdim),
other=0.0).to(tl.float32)
acc += x_residual * D
if HAS_Z:
out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :])
tl.store(out_x_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim))
z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
stride_z_hdim * offs_out_n[None, :])
z = tl.load(z_ptrs,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim),
other=0.0).to(tl.float32)
acc *= z * tl.sigmoid(z)
out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
offs_out_n[None, :] * stride_out_hdim)
tl.store(out_ptrs,
acc,
mask=(offs_out_m[:, None] < chunk_size_limit) &
(offs_out_n[None, :] < hdim))
def _seq_idx_to_chunk_indices_offsets(seq_idx, chunk_size: int):
# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_, cu_seqlens = torch.where(seq_idx.diff())
cu_seqlens += 1
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = math.ceil(seq_idx.shape[-1] / chunk_size) + (cu_seqlens % chunk_size
> 0).sum()
chunk_indices = torch.arange(N, dtype=torch.int, device=seq_idx.device)
chunk_offsets = torch.zeros((N, ), dtype=torch.int, device=seq_idx.device)
cu_seqlens = cu_seqlens.tolist() + [seq_idx.shape[-1]]
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += (s % chunk_size > 0)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size
> 0)
# adjust inidces and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
def _chunk_scan_fwd(
cb,
x,
dt,
dA_cumsum,
C,
states,
D=None,
z=None,
seq_idx=None,
initial_states=None,
):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = C.shape
assert nheads % ngroups == 0
assert C.shape == (batch, seqlen, ngroups, dstate)
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
chunk_indices, chunk_offsets = None, None
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if initial_states is not None:
# with initial states, we need to take care of how
# seq_idx crosses the boundaries
assert batch == 1, "chunk scan only supports initial states with batch 1"
assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
headdim, dstate)
if initial_states.shape[0] == 1:
# no in this case no point to use initial states
initial_states = None
else:
chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
seq_idx, chunk_size)
# Allocates output.
out = torch.empty(batch,
seqlen,
nheads,
headdim,
device=x.device,
dtype=x.dtype)
if z is not None:
out_x = torch.empty(batch,
seqlen,
nheads,
headdim,
device=x.device,
dtype=x.dtype)
assert out_x.stride() == out.stride()
else:
out_x = None
grid = lambda META: (
triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
headdim, META['BLOCK_SIZE_N']), batch * nchunks
if chunk_offsets is None else len(chunk_offsets), nheads)
z_strides = ((z.stride(0), z.stride(1), z.stride(2),
z.stride(3)) if z is not None else (0, 0, 0, 0))
_chunk_scan_fwd_kernel[grid](
cb,
x,
z,
out,
out_x,
dt,
dA_cumsum,
seq_idx,
C,
states,
D,
initial_states,
chunk_indices,
chunk_offsets,
len(chunk_indices) if chunk_indices is not None else 0,
chunk_size,
headdim,
dstate,
batch,
seqlen,
nheads // ngroups,
cb.stride(0),
cb.stride(1),
cb.stride(2),
cb.stride(3),
cb.stride(4),
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
z_strides[0],
z_strides[1],
z_strides[2],
z_strides[3],
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
dt.stride(0),
dt.stride(2),
dt.stride(1),
dt.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else
(0, 0)),
C.stride(0),
C.stride(1),
C.stride(2),
C.stride(3),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
states.stride(4),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
D.stride(0) if D is not None else 0,
True,
D is not None,
D.dim() == 2 if D is not None else True,
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
HAS_Z=z is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_TRITON_22=TRITON_22,
HAS_INITSTATES=initial_states is not None,
)
return out, out_x
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
# ruff: noqa: E501
import math
import torch
import triton
import triton.language as tl
from .mamba_ssm import softplus
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_H': 1}),
triton.Config({'BLOCK_SIZE_H': 2}),
triton.Config({'BLOCK_SIZE_H': 4}),
triton.Config({'BLOCK_SIZE_H': 8}),
triton.Config({'BLOCK_SIZE_H': 16}),
triton.Config({'BLOCK_SIZE_H': 32}),
triton.Config({'BLOCK_SIZE_H': 64}),
],
key=['chunk_size', 'nheads'],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
# Pointers to matrices
dt_ptr,
A_ptr,
dt_bias_ptr,
dt_out_ptr,
dA_cumsum_ptr,
# Matrix dimension
batch,
seqlen,
nheads,
chunk_size,
dt_min,
dt_max,
# Strides
stride_dt_batch,
stride_dt_seqlen,
stride_dt_head,
stride_A_head,
stride_dt_bias_head,
stride_dt_out_batch,
stride_dt_out_chunk,
stride_dt_out_head,
stride_dt_out_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
# Meta-parameters
DT_SOFTPLUS: tl.constexpr,
HAS_DT_BIAS: tl.constexpr,
BLOCK_SIZE_H: tl.constexpr,
BLOCK_SIZE_CHUNK: tl.constexpr,
):
pid_b = tl.program_id(axis=0)
# if dt is long, may cause problems, so use 64 bit
# https://github.com/triton-lang/triton/issues/1058
pid_c = tl.program_id(axis=1).to(tl.int64)
pid_h = tl.program_id(axis=2)
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head +
offs_c[None, :] * stride_dt_seqlen)
A_ptrs = A_ptr + offs_h * stride_A_head
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head +
offs_c[None, :] * stride_dt_out_csize)
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
offs_c[None, :] * stride_dA_cs_csize)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
dt = tl.load(dt_ptrs,
mask=(offs_h[:, None] < nheads) &
(offs_c[None, :] < chunk_size_limit),
other=0.0).to(tl.float32)
if HAS_DT_BIAS:
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head,
mask=offs_h < nheads,
other=0.0).to(tl.float32)
dt += dt_bias[:, None]
if DT_SOFTPLUS:
dt = tl.where(dt <= 20.0, softplus(dt), dt)
# As of Triton 2.2.0, tl.clamp is not available yet
# dt = tl.clamp(dt, dt_min, dt_max)
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
dt = tl.where(
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
0.0)
tl.store(dt_out_ptrs,
dt,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
dA = dt * A[:, None]
dA_cs = tl.cumsum(dA, axis=1)
tl.store(dA_cs_ptrs,
dA_cs,
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=2),
],
key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_fwd_kernel(
# Pointers to matrices
x_ptr,
b_ptr,
states_ptr,
dt_ptr,
dA_cumsum_ptr,
seq_idx_ptr,
# Matrix dimensions
hdim,
dstate,
chunk_size,
batch,
seqlen,
nheads_ngroups_ratio,
# Strides
stride_x_batch,
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_b_batch,
stride_b_seqlen,
stride_b_head,
stride_b_dstate,
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_dt_batch,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
HAS_SEQ_IDX: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
pid_bc = tl.program_id(axis=1).to(tl.int64)
pid_c = pid_bc // batch
pid_b = pid_bc - pid_c * batch
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
offs_k[:, None] * stride_b_seqlen)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr +
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
if HAS_SEQ_IDX:
seq_idx_last = tl.load(seq_idx_ptr +
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_k[None, :] < chunk_size_limit - k),
other=0.0)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
if HAS_SEQ_IDX:
seq_idx_k = tl.load(seq_idx_ptrs,
mask=offs_k < chunk_size_limit - k,
other=-1)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
if not HAS_SEQ_IDX:
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
else:
scale = tl.where(seq_idx_k == seq_idx_last,
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
if HAS_SEQ_IDX:
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
states = acc.to(states_ptr.dtype.element_ty)
states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
offs_n[None, :] * stride_states_dstate)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
@triton.autotune(
configs=[
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 64
},
num_stages=3,
num_warps=8),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=4),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 32,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=5,
num_warps=2),
triton.Config(
{
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32
},
num_stages=4,
num_warps=2),
],
key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices
x_ptr,
b_ptr,
dt_ptr,
dA_cumsum_ptr,
chunk_states_ptr,
cu_seqlens_ptr,
states_ptr,
initstates_ptr,
# Matrix dimensions
hdim,
dstate,
chunk_size,
seqlen,
nheads_ngroups_ratio,
# Strides
stride_x_seqlen,
stride_x_head,
stride_x_hdim,
stride_b_seqlen,
stride_b_head,
stride_b_dstate,
stride_dt_chunk,
stride_dt_head,
stride_dt_csize,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_dA_cs_csize,
stride_chunk_states_chunk,
stride_chunk_states_head,
stride_chunk_states_hdim,
stride_chunk_states_dstate,
stride_states_batch,
stride_states_head,
stride_states_hdim,
stride_states_dstate,
stride_init_states_batch,
stride_init_states_head,
stride_init_states_hdim,
stride_init_states_dstate,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
HAS_INITSTATES: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
pid_m = tl.program_id(axis=0) // num_pid_n
pid_n = tl.program_id(axis=0) % num_pid_n
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
pid_c = (end_idx - 1) // chunk_size
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
pid_h // nheads_ngroups_ratio) * stride_b_head
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
if HAS_INITSTATES:
# if there are init states provided, we differentiate between states (which
# are boundary conditions at a chunk boundary) and initstates (which are boundary
# conditions when a new example in a cont batch starts)
initstates_ptr += pid_h * stride_init_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
offs_k[None, :] * stride_x_seqlen)
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
offs_k[:, None] * stride_b_seqlen)
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
stride_dA_cs_csize).to(tl.float32)
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
chunk_size_limit = end_idx - pid_c * chunk_size
start_idx = tl.load(cu_seqlens_ptr + pid_b)
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
x = tl.load(x_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_k[None, :] < chunk_size_limit - k) &
(offs_k[None, :] >= start_idx_cur - k),
other=0.0)
b = tl.load(b_ptrs,
mask=(offs_k[:, None] < chunk_size_limit - k) &
(offs_n[None, :] < dstate) &
(offs_k[:, None] >= start_idx_cur - k),
other=0.0).to(tl.float32)
dA_cs_k = tl.load(dA_cumsum_ptrs,
mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
other=0.0).to(tl.float32)
scale = tl.where(
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
b *= scale[:, None]
b = b.to(x_ptr.dtype.element_ty)
acc += tl.dot(x, b)
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
# If HAS_INITSTATES==True need to consider two possiblties
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
# - if state_idx >= pid * chunk_size, then we need to insert initstates
if ((start_idx < pid_c * chunk_size) # first chunk
or (HAS_INITSTATES)):
dA_cs_boundary = 0.0 # default
if not HAS_INITSTATES:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim +
offs_n[None, :] * stride_chunk_states_dstate)
else:
# - this seems repetitve, buts its to help the compiler
if start_idx < pid_c * chunk_size:
past_states_ptrs = chunk_states_ptr + (
offs_m[:, None] * stride_chunk_states_hdim +
offs_n[None, :] * stride_chunk_states_dstate)
else:
past_states_ptrs = initstates_ptr + (
pid_b * stride_init_states_batch +
offs_m[:, None] * stride_init_states_hdim +
offs_n[None, :] * stride_init_states_dstate)
# need to adjust the boundary
if start_idx > pid_c * chunk_size:
dA_cs_boundary = tl.load(dA_cumsum_ptr +
(start_idx - pid_c * chunk_size -
1) * stride_dA_cs_csize).to(
tl.float32)
past_states = tl.load(past_states_ptrs,
mask=(offs_m[:, None] < hdim) &
(offs_n[None, :] < dstate),
other=0.0).to(tl.float32)
scale = tl.exp(dA_cs_last - dA_cs_boundary)
acc += past_states * scale
states = acc.to(states_ptr.dtype.element_ty)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
offs_n[None, :] * stride_states_dstate)
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
tl.store(states_ptrs, states, mask=c_mask)
def _chunk_cumsum_fwd(dt,
A,
chunk_size,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads = dt.shape
assert A.shape == (nheads, )
if dt_bias is not None:
assert dt_bias.shape == (nheads, )
nchunks = math.ceil(seqlen / chunk_size)
dt_out = torch.empty(batch,
nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
dA_cumsum = torch.empty(batch,
nheads,
nchunks,
chunk_size,
device=dt.device,
dtype=torch.float32)
grid_chunk_cs = lambda META: (batch, nchunks,
triton.cdiv(nheads, META['BLOCK_SIZE_H']))
with torch.cuda.device(dt.device.index):
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
dt,
A,
dt_bias,
dt_out,
dA_cumsum,
batch,
seqlen,
nheads,
chunk_size,
dt_limit[0],
dt_limit[1],
dt.stride(0),
dt.stride(1),
dt.stride(2),
A.stride(0),
dt_bias.stride(0) if dt_bias is not None else 0,
dt_out.stride(0),
dt_out.stride(2),
dt_out.stride(1),
dt_out.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
dt_softplus,
HAS_DT_BIAS=dt_bias is not None,
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
)
return dA_cumsum, dt_out
def _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
seq_idx=None,
states=None,
states_in_fp32=True):
batch, seqlen, nheads, headdim = x.shape
_, _, nchunks, chunk_size = dt.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert dt.shape == (batch, nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if states is not None:
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
else:
states_dtype = torch.float32 if states_in_fp32 else B.dtype
states = torch.empty((batch, nchunks, nheads, headdim, dstate),
device=x.device,
dtype=states_dtype)
grid = lambda META: (
triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
dstate, META['BLOCK_SIZE_N']), batch * nchunks, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_fwd_kernel[grid](
x,
B,
states,
dt,
dA_cumsum,
seq_idx,
headdim,
dstate,
chunk_size,
batch,
seqlen,
nheads // ngroups,
x.stride(0),
x.stride(1),
x.stride(2),
x.stride(3),
B.stride(0),
B.stride(1),
B.stride(2),
B.stride(-1),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
states.stride(4),
dt.stride(0),
dt.stride(2),
dt.stride(1),
dt.stride(3),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(3),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_SEQ_IDX=seq_idx is not None,
)
return states
def chunk_state_varlen(B,
x,
dt,
dA_cumsum,
cu_seqlens,
chunk_states,
initial_states=None):
total_seqlen, nheads, headdim = x.shape
_, nchunks, chunk_size = dt.shape
_, ngroups, dstate = B.shape
batch = cu_seqlens.shape[0] - 1
cu_seqlens = cu_seqlens.contiguous()
assert nheads % ngroups == 0
assert B.shape == (total_seqlen, ngroups, dstate)
assert dt.shape == (nheads, nchunks, chunk_size)
assert dA_cumsum.shape == dt.shape
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
if initial_states is not None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
states = torch.empty(batch,
nheads,
headdim,
dstate,
dtype=chunk_states.dtype,
device=chunk_states.device)
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x,
B,
dt,
dA_cumsum,
chunk_states,
cu_seqlens,
states,
initial_states,
headdim,
dstate,
chunk_size,
total_seqlen,
nheads // ngroups,
x.stride(0),
x.stride(1),
x.stride(2),
B.stride(0),
B.stride(1),
B.stride(2),
dt.stride(1),
dt.stride(0),
dt.stride(2),
dA_cumsum.stride(1),
dA_cumsum.stride(0),
dA_cumsum.stride(2),
chunk_states.stride(0),
chunk_states.stride(1),
chunk_states.stride(2),
chunk_states.stride(3),
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2),
initial_states.stride(3)) if initial_states is not None else
(0, 0, 0, 0)),
HAS_INITSTATES=initial_states is not None)
return states
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
# ruff: noqa: E501
import torch
import triton
from einops import rearrange
from packaging import version
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
chunk_state_varlen)
from .ssd_state_passing import _state_passing_fwd
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
def _mamba_chunk_scan_combined_fwd(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf"))):
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
assert nheads % ngroups == 0
assert B.shape == (batch, seqlen, ngroups, dstate)
assert x.shape == (batch, seqlen, nheads, headdim)
assert dt.shape == (batch, seqlen, nheads)
assert A.shape == (nheads, )
assert C.shape == B.shape
if z is not None:
assert z.shape == x.shape
if D is not None:
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
if seq_idx is not None:
assert seq_idx.shape == (batch, seqlen)
if B.stride(-1) != 1:
B = B.contiguous()
if C.stride(-1) != 1:
C = C.contiguous()
if x.stride(-1) != 1 and x.stride(
1) != 1: # Either M or K dimension should be contiguous
x = x.contiguous()
if z is not None and z.stride(-1) != 1 and z.stride(
1) != 1: # Either M or K dimension should be contiguous
z = z.contiguous()
if D is not None and D.stride(-1) != 1:
D = D.contiguous()
if initial_states is not None:
if cu_seqlens is None:
assert initial_states.shape == (batch, nheads, headdim, dstate)
else:
assert initial_states.shape == (len(cu_seqlens) - 1, nheads,
headdim, dstate)
# This function executes 5 sub-functions for computing mamba
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
# which has a minimal implementation to understand the below operations
# - as explained by the blog, mamba is a special case of causal attention
# - the idea is to chunk the attention matrix and compute each
# submatrix separately using different optimizations.
# - see the blog and paper for a visualization of the submatrices
# which we refer to in the comments below
# 1. Compute chunked cumsum of A * dt
# - here dt may go through a softplus activation
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
A,
chunk_size,
dt_bias=dt_bias,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
# 2. Compute the state for each intra-chunk
# (right term of low-rank factorization of off-diagonal blocks; B terms)
states = _chunk_state_fwd(B,
x,
dt,
dA_cumsum,
seq_idx=seq_idx,
states_in_fp32=True)
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) has_cu_seqlens to be all specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum[:, :, :, -1],
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=C.dtype,
is_cont_batched=cu_seqlens is not None)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
# 4. Compute batched matrix multiply for C_j^T B_i terms
CB = _bmm_chunk_fwd(C,
B,
chunk_size,
seq_idx=seq_idx,
output_dtype=torch.float32)
# 5. Scan and compute the diagonal blocks, taking into
# account past causal states.
# - if initial states are provided, then states information will be
# augmented with initial_states.
# - to do this properly, we need to account for example changes in
# the continuous batch, therefore we introduce pseudo chunks, which is
# a chunk that is split up each time an example changes.
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
# a seq_idx change, in which case we take states information from
# init_states.
out, out_x = _chunk_scan_fwd(
CB,
x,
dt,
dA_cumsum,
C,
states,
D=D,
z=z,
seq_idx=seq_idx,
initial_states=initial_states,
)
if cu_seqlens is None:
return out, out_x, dt, dA_cumsum, states, final_states
else:
assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
varlen_states = chunk_state_varlen(
B.squeeze(0),
x.squeeze(0),
dt.squeeze(0),
dA_cumsum.squeeze(0),
cu_seqlens,
states.squeeze(0),
initial_states=initial_states,
)
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states
def mamba_chunk_scan_combined(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_final_states=False,
return_varlen_states=False):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
dt: (batch, seqlen, nheads)
A: (nheads)
B: (batch, seqlen, ngroups, dstate)
C: (batch, seqlen, ngroups, dstate)
chunk_size: int
D: (nheads, headdim) or (nheads,)
z: (batch, seqlen, nheads, headdim)
dt_bias: (nheads,)
initial_states: (batch, nheads, headdim, dstate)
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
Return:
out: (batch, seqlen, nheads, headdim)
"""
if not return_varlen_states:
cu_seqlens = None
else:
assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
out, out_x, dt_out, dA_cumsum, states, final_states, *rest = _mamba_chunk_scan_combined_fwd(
x,
dt,
A,
B,
C,
chunk_size,
D=D,
z=z,
dt_bias=dt_bias,
initial_states=initial_states,
seq_idx=seq_idx,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
if not return_varlen_states:
return out if not return_final_states else (out, final_states)
else:
varlen_states = rest[0]
return (out,
varlen_states) if not return_final_states else (out,
final_states,
varlen_states)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
# ruff: noqa: E501
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE': 64}),
triton.Config({'BLOCK_SIZE': 128}),
triton.Config({'BLOCK_SIZE': 256}),
triton.Config({'BLOCK_SIZE': 512}),
triton.Config({'BLOCK_SIZE': 1024}),
triton.Config({'BLOCK_SIZE': 2048}),
],
key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices
states_ptr,
out_ptr,
final_states_ptr,
dA_cs_ptr,
initstates_ptr,
seq_idx_ptr,
# Matrix dimensions
dim,
nchunks,
seqlen,
chunk_size,
# Strides
stride_states_batch,
stride_states_chunk,
stride_states_head,
stride_states_dim,
stride_out_batch,
stride_out_chunk,
stride_out_head,
stride_out_dim,
stride_final_states_batch,
stride_final_states_head,
stride_final_states_dim,
stride_dA_cs_batch,
stride_dA_cs_chunk,
stride_dA_cs_head,
stride_initstates_batch,
stride_initstates_head,
stride_initstates_dim,
stride_seq_idx_batch,
stride_seq_idx_seqlen,
# Meta-parameters
HAS_INITSTATES: tl.constexpr,
HAS_SEQ_IDX: tl.constexpr,
IS_CONT_BATCHED: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_b = tl.program_id(axis=1)
pid_h = tl.program_id(axis=2)
pid_m = tl.program_id(axis=0)
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
if HAS_INITSTATES:
initstates_ptr += pid_h * stride_initstates_head
if not IS_CONT_BATCHED:
initstates_ptr += pid_b * stride_initstates_batch
if HAS_SEQ_IDX:
seq_idx_ptr += pid_b * stride_seq_idx_batch
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
states_ptrs = states_ptr + offs_m * stride_states_dim
out_ptrs = out_ptr + offs_m * stride_out_dim
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
# - states will be the past state of the sequence that continues on the current check
if not HAS_INITSTATES:
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
else:
initstates_ptr += offs_m * stride_initstates_dim
initstates_ptrs = initstates_ptr
# - for cont batches, for the first chunk mean it will be the first batch's
# init state
states = tl.load(initstates_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
tl.store(out_ptrs, states, mask=offs_m < dim)
out_ptrs += stride_out_chunk
seq_idx = 0
for c in range(nchunks):
new_states = tl.load(states_ptrs, mask=offs_m < dim,
other=0.0).to(tl.float32)
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
scale = tl.exp(dA_cs)
if HAS_SEQ_IDX:
# - the seq to pass forward is the one that is flushed to the right
# boundary.
# - that is given by seq_idx_new below.
seq_idx_new = tl.load(seq_idx_ptr +
(min((c + 1) * chunk_size, seqlen) - 1) *
stride_seq_idx_seqlen)
if HAS_INITSTATES:
if IS_CONT_BATCHED and seq_idx != seq_idx_new:
# this means in the current chunk the rightmost flushed seq
# has changed.
# - so we do not propagate the state from previous chunk
# - but rather we load that sequence's init state
initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch
# - update state with seq_idx_new's init state
states = tl.load(initstates_ptrs,
mask=offs_m < dim,
other=0.0).to(tl.float32)
else:
scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)
seq_idx = seq_idx_new
states = scale * states + new_states
if c < nchunks - 1:
tl.store(out_ptrs, states, mask=offs_m < dim)
else:
tl.store(final_states_ptrs, states, mask=offs_m < dim)
states_ptrs += stride_states_chunk
dA_cs_ptr += stride_dA_cs_chunk
out_ptrs += stride_out_chunk
def _state_passing_fwd(
states,
dA_chunk_cumsum,
initial_states=None,
seq_idx=None,
chunk_size=None,
out_dtype=None,
is_cont_batched=False,
):
batch, nchunks, nheads, dim = states.shape
assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
if initial_states is not None:
if is_cont_batched:
# - if cu_seqlens is provided, then the initial states
# are used for continuous batching. In which case we
# require seq_idx to be provided
assert seq_idx is not None, ""
assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
dim)
else:
# - this is the regular batching case, where initial
# states are used are for each example of the batch.
assert initial_states.shape == (batch, nheads, dim)
if seq_idx is not None:
assert chunk_size is not None
seqlen = seq_idx.shape[-1]
assert seq_idx.shape == (batch, seqlen)
out_dtype = states.dtype if out_dtype is None else out_dtype
out = torch.empty((batch, nchunks, nheads, dim),
device=states.device,
dtype=out_dtype)
final_states = torch.empty((batch, nheads, dim),
device=states.device,
dtype=torch.float32)
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
with torch.cuda.device(states.device.index):
_state_passing_fwd_kernel[grid](
states,
out,
final_states,
dA_chunk_cumsum,
initial_states,
seq_idx,
dim,
nchunks,
seqlen if seq_idx is not None else 0,
chunk_size if seq_idx is not None else 0,
states.stride(0),
states.stride(1),
states.stride(2),
states.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
final_states.stride(0),
final_states.stride(1),
final_states.stride(2),
dA_chunk_cumsum.stride(0),
dA_chunk_cumsum.stride(2),
dA_chunk_cumsum.stride(1),
*((initial_states.stride(0), initial_states.stride(1),
initial_states.stride(2)) if initial_states is not None else
(0, 0, 0)),
*((seq_idx.stride(0),
seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
HAS_INITSTATES=initial_states is not None,
HAS_SEQ_IDX=seq_idx is not None,
IS_CONT_BATCHED=is_cont_batched,
)
return out, final_states
......@@ -11,6 +11,7 @@ QUANTIZATION_METHODS: List[str] = [
"deepspeedfp",
"tpu_int8",
"fp8",
"ptpc_fp8",
"fbgemm_fp8",
"modelopt",
# The order of gptq methods is important for config.py iteration over
......@@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from .modelopt import ModelOptFp8Config
from .moe_wna16 import MoeWNA16Config
from .neuron_quant import NeuronQuantConfig
from .ptpc_fp8 import PTPCFp8Config
from .qqq import QQQConfig
from .tpu_int8 import Int8TpuConfig
......@@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"ptpc_fp8": PTPCFp8Config,
"qqq": QQQConfig,
"hqq": HQQMarlinConfig,
"experts_int8": ExpertsInt8Config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment