Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
}
}
......@@ -57,13 +57,14 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
if not _valid_deep_gemm_shape(M, N, K):
logger.debug_once(
"DeepGemm disabled due to unaligned problem size. "
"M: %s, N: %s, K: %s. M should >= align size "
"and N and K must be multiples of %s."
"M: %s, N: %s, K: %s. M should >= %s "
"and N and K must be multiples of %s. "
"This is not an error and we will fall back to triton.",
M,
N,
K,
align,
align,
)
return False
elif N <= 512:
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from typing import Callable, Optional, Union
import deep_ep
import torch
......@@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.num_dispatchers_ = num_dispatchers
self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset
self.async_prepare = True
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
......@@ -47,19 +49,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return torch.int64
def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
if self.dp_size not in self.available_rank_configs:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_dispatch_config(self.dp_size)
return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_)
def _get_combine_config(self) -> Optional[deep_ep.Config]:
if self.dp_size not in self.available_rank_configs:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)
return deep_ep.Buffer.get_combine_config(self.num_dispatchers_)
def _do_dispatch(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int):
def _do_dispatch(
self,
tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None
......@@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
async_finish=False,
async_finish=self.async_prepare,
allocate_on_comm_stream=False)
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
a1_scale,
quant_config,
)
def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
expert_topk_ids: Optional[torch.Tensor],
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: Optional[torch.Tensor],
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if self.async_prepare:
event.current_stream_wait()
if has_scales:
expert_x, expert_x_scale = token_data
else:
......@@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where(
expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0,
......@@ -123,10 +159,28 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device)
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
def prepare(
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
......@@ -137,9 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> Callable:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
......@@ -159,37 +211,37 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
a1_post_scale = None
else:
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16
(expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
return self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
def finalize(
self,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union
from typing import Callable, Optional, Union
import deep_ep
import torch
......@@ -75,7 +75,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
......@@ -110,7 +109,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales
def prepare(
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
......@@ -121,9 +123,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.ReceiverType:
hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
......@@ -155,16 +155,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=False)
return_recv_hook=True)
return lambda: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)
def _receiver(
self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook()
expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)
expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
def finalize(
self,
......
......@@ -56,9 +56,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
......
......@@ -506,9 +506,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.PrepareResultType:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
......
......@@ -549,7 +549,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size.
# We assume that top_ids of each token is unique, so
# We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0),
......
......@@ -6,7 +6,7 @@ import os
from abc import abstractmethod
from collections.abc import Iterable
from enum import Enum
from typing import Callable, Literal, Optional, overload
from typing import Callable, Literal, Optional, Union, overload
import torch
import torch.nn.functional as F
......@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx,
from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up)
if current_platform.is_cuda_alike():
......@@ -217,6 +217,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
layer.shared_experts,
)
def select_gemm_impl(
......@@ -254,7 +255,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
raise NotImplementedError
......@@ -428,7 +429,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
......@@ -482,7 +483,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -570,7 +571,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
):
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
......@@ -617,7 +618,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
):
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:
......@@ -657,7 +658,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
......@@ -733,7 +734,7 @@ def determine_expert_map(
# Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
# Create a expert map for the local experts
# Create an expert map for the local experts
start_idx = ep_rank * base_experts + min(ep_rank, remainder)
expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
0, local_num_experts, dtype=torch.int32)
......@@ -778,7 +779,7 @@ class FusedMoE(CustomOp):
intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
"""
......@@ -809,6 +810,7 @@ class FusedMoE(CustomOp):
enable_eplb: bool = False,
num_redundant_experts: int = 0,
has_bias: bool = False,
is_sequence_parallel=False,
):
super().__init__()
if params_dtype is None:
......@@ -820,6 +822,10 @@ class FusedMoE(CustomOp):
dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size)
self.is_sequence_parallel = is_sequence_parallel
if self.is_sequence_parallel:
self.sp_size = tp_size_
vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make(
......@@ -829,11 +835,18 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts
# we padding globally so EP buffer allocation works
# we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501
should_use_flashinfer_mxfp4)
if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
from vllm.model_executor.layers.quantization.mxfp4 import (
Mxfp4Backend, get_mxfp4_backend)
current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op
......@@ -980,6 +993,10 @@ class FusedMoE(CustomOp):
dtype=moe.in_dtype,
device=torch.cuda.current_device())
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
......@@ -1444,6 +1461,7 @@ class FusedMoE(CustomOp):
return [
weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS
and not name.startswith("_shared_experts.")
]
def set_eplb_state(
......@@ -1626,25 +1644,52 @@ class FusedMoE(CustomOp):
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states:
hidden_states = F.pad(hidden_states,
(0, self.hidden_size - og_hidden_states),
mode='constant',
value=0.0)
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op.
if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits)
if self.shared_experts is None:
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name)
return fused_output[..., :og_hidden_states]
else:
return torch.ops.vllm.moe_forward(
hidden_states, router_logits,
self.layer_name)[..., :og_hidden_states]
if current_platform.is_tpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, self.layer_name)
return (shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states])
def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor):
def forward_cuda(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype
......@@ -1655,7 +1700,10 @@ class FusedMoE(CustomOp):
assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1))
full_final_hidden_states = torch.empty_like(full_hidden_states)
full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start
......@@ -1696,20 +1744,40 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count,
)
assert self.shared_experts is None or isinstance(
final_hidden_states, tuple)
if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_(
final_hidden_states, non_blocking=True)
if self.shared_experts is None:
full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states,
non_blocking=True)
else:
full_shared_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[0],
non_blocking=True)
full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[1],
non_blocking=True)
ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu
max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
self.sp_size)
num_tokens = full_hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)):
range(0, max_tokens_across_dispatchers,
moe_dp_chunk_size_per_rank)):
chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dp)
max_tokens_across_dispatchers)
# clamp start and end
chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens)
......@@ -1719,10 +1787,17 @@ class FusedMoE(CustomOp):
chunk_end,
skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states
if self.shared_experts is None:
return full_fused_final_hidden_states
else:
return (full_shared_final_hidden_states,
full_fused_final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
def forward_impl(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled.
......@@ -1738,6 +1813,16 @@ class FusedMoE(CustomOp):
self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_config.use_flashinfer_cutlass_kernels)
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if (not isinstance(self.quant_method.fused_experts,
FusedMoEModularKernel)
and self.shared_experts is not None):
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
......@@ -1767,14 +1852,32 @@ class FusedMoE(CustomOp):
use_nn_moe=self.use_nn_moe,
)
if do_naive_dispatch_combine:
final_hidden_states = get_ep_group().combine(final_hidden_states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
# Default set to False. (May have to add shared expert outputs.
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
if shared_output is not None:
assert not isinstance(final_hidden_states, tuple)
assert self.shared_experts is not None
final_hidden_states = (
shared_output,
final_hidden_states,
)
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
states = self.maybe_all_reduce_tensor_model_parallel(states)
return final_hidden_states
return states
if self.shared_experts is None:
assert not isinstance(final_hidden_states, tuple)
return reduce_output(final_hidden_states)
else:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
@classmethod
def make_expert_params_mapping(
......@@ -1829,17 +1932,22 @@ class FusedMoE(CustomOp):
return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
def moe_forward(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
def moe_forward_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -1852,6 +1960,37 @@ direct_register_custom_op(
tags=(torch.Tag.needs_fixed_stride_order, ),
)
def moe_forward_shared(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.shared_experts is not None
return self.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
shared_out = torch.empty_like(hidden_states)
fused_out = torch.empty_like(hidden_states)
return shared_out, fused_out
direct_register_custom_op(
op_name="moe_forward_shared",
op_func=moe_forward_shared,
mutates_args=["hidden_states"],
fake_impl=moe_forward_shared_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
......@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from math import prod
from typing import Optional, final
from typing import Callable, Optional, Union, final
import torch
......@@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
raise NotImplementedError
#
# PrepareResultType is a tuple of:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales.
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
# as big as the number of local experts with the information about the
# number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType = tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]
ReceiverType = Callable[[], PrepareResultType]
# TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC):
"""
......@@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
) -> PrepareResultType:
"""
Perform any quantization (and/or) dispatching needed
for this kernel.
Perform any quantization (and/or) dispatching needed for this kernel.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
......@@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
"""
raise NotImplementedError
def supports_async(self) -> bool:
"""
Indicates whether or not this class implements prepare_async.
"""
return False
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> ReceiverType:
"""
Perform any quantization (and/or) dispatching needed for this kernel
but do not wait for results from other workers.
- a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make
sure the quantization is consistent for both gemms.
- topk_ids: The topk ids.
- topk_weights: The topk weights.
- num_experts: The total number of experts in the global expert space.
- expert_map: A tensor mapping expert indices from the global expert
space to the local expert space of the expert parallel shard.
- apply_router_weight_on_input: When True, apply the weights to the
activations, before quantization + dispatching.
Returns a callback that when invoked waits for results from other
workers and has the same return signature as `prepare`, e.g.
receiver = obj.prepare_async(...)
a, a_scales, expert_meta, topk_ids, topk_weights = receiver()
is equivalent to:
a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)
"""
raise NotImplementedError
@abstractmethod
def finalize(
self,
......@@ -241,7 +302,7 @@ class FusedMoEPrepareAndFinalize(ABC):
def max_num_tokens_per_rank(self) -> Optional[int]:
"""
Some PrepareFinalize All2All implementations are batched. Meaning,
they can processes only as set of tokens at a time. This
they can process only as set of tokens at a time. This
function returns the batch size i.e the maximum number of tokens
the implementation can process at a time.
Return None if there are no such restrictions.
......@@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}."
......@@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism.
......@@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
shared_output: torch.Tensor
if (not self.prepare_finalize.supports_async()
or self.shared_experts is None):
# Run shared experts serially with dispatch.
if self.shared_experts is not None:
shared_output = self.shared_experts(a1)
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = self.prepare_finalize.prepare(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
else:
# Overlap shared expert compute with all2all dispatch.
receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
assert self.shared_experts is not None
shared_output = self.shared_experts(a1)
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
......@@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(),
)
return output
if self.shared_experts is None:
return output
else:
return shared_output, output
......@@ -7,7 +7,7 @@ import torch.nn.functional as F
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
"""
Compute the histogram of a int32 tensor. The bin edges are defined by the
Compute the histogram of an int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
......
......@@ -84,12 +84,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int32
return torch.uint32
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def prepare(
def supports_async(self) -> bool:
return True
def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
......@@ -100,9 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.ReceiverType:
num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K
......@@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
_validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
quant_config.block_shape)
orig_a_scale_block_shape: Optional[int] = None
if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1
......@@ -205,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids.view(dtype=torch.uint32),
indices=topk_ids,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
a1q,
a1q_scale,
topk_ids,
bound_m,
orig_a_scale_block_shape,
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor],
a1q: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor,
bound_m: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int],
) -> mk.PrepareResultType:
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m,
do_send=False,
do_recv=True,
)
if expert_x_scale is not None:
......@@ -218,6 +258,31 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
)
return receiver()
def finalize(
self,
output: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment