Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
}
}
...@@ -57,13 +57,14 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, ...@@ -57,13 +57,14 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor,
if not _valid_deep_gemm_shape(M, N, K): if not _valid_deep_gemm_shape(M, N, K):
logger.debug_once( logger.debug_once(
"DeepGemm disabled due to unaligned problem size. " "DeepGemm disabled due to unaligned problem size. "
"M: %s, N: %s, K: %s. M should >= align size " "M: %s, N: %s, K: %s. M should >= %s "
"and N and K must be multiples of %s." "and N and K must be multiples of %s. "
"This is not an error and we will fall back to triton.", "This is not an error and we will fall back to triton.",
M, M,
N, N,
K, K,
align, align,
align,
) )
return False return False
elif N <= 512: elif N <= 512:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional from typing import Callable, Optional, Union
import deep_ep import deep_ep
import torch import torch
...@@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -25,6 +25,8 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.num_dispatchers_ = num_dispatchers self.num_dispatchers_ = num_dispatchers
self.dp_size = dp_size self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset self.rank_expert_offset = rank_expert_offset
self.async_prepare = True
# The dispatch function returns a handle that the combine function # The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the # requires. We store the handle here so it is available to the
# combine function. # combine function.
...@@ -47,19 +49,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -47,19 +49,25 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return torch.int64 return torch.int64
def _get_dispatch_config(self) -> Optional[deep_ep.Config]: def _get_dispatch_config(self) -> Optional[deep_ep.Config]:
if self.dp_size not in self.available_rank_configs: if self.num_dispatchers_ not in self.available_rank_configs:
return None return None
return deep_ep.Buffer.get_dispatch_config(self.dp_size) return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_)
def _get_combine_config(self) -> Optional[deep_ep.Config]: def _get_combine_config(self) -> Optional[deep_ep.Config]:
if self.dp_size not in self.available_rank_configs: if self.num_dispatchers_ not in self.available_rank_configs:
return None return None
return deep_ep.Buffer.get_combine_config(self.dp_size) return deep_ep.Buffer.get_combine_config(self.num_dispatchers_)
def _do_dispatch(self, tokens: torch.Tensor, def _do_dispatch(
token_scales: Optional[torch.Tensor], self,
rank_topk_ids: torch.Tensor, tokens: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int): token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> Callable:
has_scales = token_scales is not None has_scales = token_scales is not None
...@@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -93,9 +101,36 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_alignment=1, expert_alignment=1,
config=self._get_dispatch_config(), config=self._get_dispatch_config(),
previous_event=None, previous_event=None,
async_finish=False, async_finish=self.async_prepare,
allocate_on_comm_stream=False) allocate_on_comm_stream=False)
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
a1_scale,
quant_config,
)
def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
expert_topk_ids: Optional[torch.Tensor],
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: Optional[torch.Tensor],
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if self.async_prepare:
event.current_stream_wait()
if has_scales: if has_scales:
expert_x, expert_x_scale = token_data expert_x, expert_x_scale = token_data
else: else:
...@@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -112,6 +147,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# DeepEP's topk_ids output refers to the local experts directly. Offset # DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns # the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces. # with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where( expert_topk_ids = torch.where(
expert_topk_ids == -1, expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0, num_experts - 1 if self.rank_expert_offset == 0 else 0,
...@@ -123,10 +159,28 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -123,10 +159,28 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list( expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device) expert_num_tokens_per_expert_list, device=expert_x.device)
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) expert_topk_weights)
def prepare( def supports_async(self) -> bool:
return True
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -137,9 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -137,9 +191,7 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> Callable:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -159,37 +211,37 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -159,37 +211,37 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
) )
if a1q_scale is not None and a1q_scale.numel() == 1: if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1) a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, a1_post_scale = None
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
else: else:
# Dispatch and Quant a1q = a1
# DeepEP kernels only support dispatching block-quantized a1q_scale = None
# activation scales. a1_post_scale = a1_scale
# Dispatch in bfloat16
(expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids, return self._do_dispatch(tokens=a1q,
expert_topk_weights) token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)
def prepare(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
receiver = self.prepare_async(a1, a1_scale, a2_scale, topk_weights,
topk_ids, num_experts, expert_map,
apply_router_weight_on_input,
quant_config)
return receiver()
def finalize( def finalize(
self, self,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, Union from typing import Callable, Optional, Union
import deep_ep import deep_ep
import torch import torch
...@@ -75,7 +75,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -75,7 +75,6 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self, self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype, a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None], quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool, per_act_token_quant: bool,
...@@ -110,7 +109,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -110,7 +109,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return x, x_scales return x, x_scales
def prepare( def supports_async(self) -> bool:
return True
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -121,9 +123,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -121,9 +123,7 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.ReceiverType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_size = a1.size(1) hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
...@@ -155,16 +155,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -155,16 +155,48 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
num_experts, num_experts,
use_fp8=self.use_fp8_dispatch, use_fp8=self.use_fp8_dispatch,
async_finish=False, async_finish=False,
return_recv_hook=False) return_recv_hook=True)
return lambda: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)
def _receiver(
self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook()
expert_x, expert_x_scale = self._do_quant( expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype, expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape) quant_config.per_act_token_quant, quant_config.block_shape)
expert_tokens_meta = mk.ExpertTokensMetadata( expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None) expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)
return (expert_x, expert_x_scale, expert_tokens_meta, None, None) return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """
    Blocking counterpart of `prepare_async`.

    Passes all arguments through to `prepare_async` in the same order,
    then calls the receiver it returns right away and returns the
    receiver's result.
    """
    wait_for_result = self.prepare_async(
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    return wait_for_result()
def finalize( def finalize(
self, self,
......
...@@ -56,9 +56,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -56,9 +56,7 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args # TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.PrepareResultType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
......
...@@ -506,9 +506,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -506,9 +506,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.PrepareResultType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
assert a1.dim() == 2 assert a1.dim() == 2
assert topk_ids.dim() == 2 assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0) assert topk_ids.size(0) == a1.size(0)
......
...@@ -549,7 +549,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -549,7 +549,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = sorted_token_ids.size(0) EM = sorted_token_ids.size(0)
if A.size(0) < config["BLOCK_SIZE_M"]: if A.size(0) < config["BLOCK_SIZE_M"]:
# optimize for small batch_size. # optimize for small batch_size.
# We assume that top_ids of each token is unique, so # We assume that top_ids of each token is unique,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M, # so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# and we can skip some invalid blocks. # and we can skip some invalid blocks.
EM = min(sorted_token_ids.size(0), EM = min(sorted_token_ids.size(0),
......
...@@ -6,7 +6,7 @@ import os ...@@ -6,7 +6,7 @@ import os
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from enum import Enum from enum import Enum
from typing import Callable, Literal, Optional, overload from typing import Callable, Literal, Optional, Union, overload
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -37,7 +37,7 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum from vllm.platforms.interface import CpuArchEnum
from vllm.utils import (direct_register_custom_op, has_deep_ep, has_pplx, from vllm.utils import (cdiv, direct_register_custom_op, has_deep_ep, has_pplx,
round_up) round_up)
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
...@@ -217,6 +217,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -217,6 +217,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
self.fused_experts = FusedMoEModularKernel( self.fused_experts = FusedMoEModularKernel(
prepare_finalize, prepare_finalize,
experts, experts,
layer.shared_experts,
) )
def select_gemm_impl( def select_gemm_impl(
...@@ -254,7 +255,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -254,7 +255,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
raise NotImplementedError raise NotImplementedError
...@@ -428,7 +429,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -428,7 +429,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb: if enable_eplb:
assert expert_load_view is not None assert expert_load_view is not None
assert logical_to_physical_map is not None assert logical_to_physical_map is not None
...@@ -482,7 +483,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -482,7 +483,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -570,7 +571,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -570,7 +571,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
): ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \ if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \ logical_to_physical_map is not None or \
logical_replica_count is not None: logical_replica_count is not None:
...@@ -617,7 +618,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -617,7 +618,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
): ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb is not False or expert_load_view is not None or \ if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \ logical_to_physical_map is not None or \
logical_replica_count is not None: logical_replica_count is not None:
...@@ -657,7 +658,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -657,7 +658,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
...@@ -733,7 +734,7 @@ def determine_expert_map( ...@@ -733,7 +734,7 @@ def determine_expert_map(
# Create a tensor of size num_experts filled with -1 # Create a tensor of size num_experts filled with -1
expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32) expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
# Create a expert map for the local experts # Create an expert map for the local experts
start_idx = ep_rank * base_experts + min(ep_rank, remainder) start_idx = ep_rank * base_experts + min(ep_rank, remainder)
expert_map[start_idx:start_idx + local_num_experts] = torch.arange( expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
0, local_num_experts, dtype=torch.int32) 0, local_num_experts, dtype=torch.int32)
...@@ -778,7 +779,7 @@ class FusedMoE(CustomOp): ...@@ -778,7 +779,7 @@ class FusedMoE(CustomOp):
intermediate_size: Intermediate size of the experts intermediate_size: Intermediate size of the experts
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
reduce_results: Whether to all all_reduce on the output of the layer reduce_results: Whether to all all_reduce on the output of the layer
renomalize: Whether to renormalize the logits in the fused_moe kernel renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure. quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer. enable_eplb: Whether to enable expert parallelism load balancer.
""" """
...@@ -809,6 +810,7 @@ class FusedMoE(CustomOp): ...@@ -809,6 +810,7 @@ class FusedMoE(CustomOp):
enable_eplb: bool = False, enable_eplb: bool = False,
num_redundant_experts: int = 0, num_redundant_experts: int = 0,
has_bias: bool = False, has_bias: bool = False,
is_sequence_parallel=False,
): ):
super().__init__() super().__init__()
if params_dtype is None: if params_dtype is None:
...@@ -820,6 +822,10 @@ class FusedMoE(CustomOp): ...@@ -820,6 +822,10 @@ class FusedMoE(CustomOp):
dp_size_ = (dp_size dp_size_ = (dp_size
if dp_size is not None else get_dp_group().world_size) if dp_size is not None else get_dp_group().world_size)
self.is_sequence_parallel = is_sequence_parallel
if self.is_sequence_parallel:
self.sp_size = tp_size_
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.moe_parallel_config: FusedMoEParallelConfig = ( self.moe_parallel_config: FusedMoEParallelConfig = (
FusedMoEParallelConfig.make( FusedMoEParallelConfig.make(
...@@ -829,11 +835,18 @@ class FusedMoE(CustomOp): ...@@ -829,11 +835,18 @@ class FusedMoE(CustomOp):
self.global_num_experts = num_experts + num_redundant_experts self.global_num_experts = num_experts + num_redundant_experts
# we padding globally so EP buffer allocation works # we are padding globally so EP buffer allocation works
if quant_config and quant_config.get_name() == "mxfp4": if quant_config and quant_config.get_name() == "mxfp4":
from vllm.model_executor.layers.quantization.mxfp4 import ( # noqa: E501 from vllm.model_executor.layers.quantization.mxfp4 import (
should_use_flashinfer_mxfp4) Mxfp4Backend, get_mxfp4_backend)
if current_platform.is_rocm() or should_use_flashinfer_mxfp4(): current_mxfp4_backend = get_mxfp4_backend()
if (current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS):
hidden_size = round_up(hidden_size, 128)
elif (current_platform.is_rocm() or current_mxfp4_backend
== Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM or
current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_BF16):
hidden_size = round_up(hidden_size, 256) hidden_size = round_up(hidden_size, 256)
# For smuggling this layer into the fused moe custom op # For smuggling this layer into the fused moe custom op
...@@ -980,6 +993,10 @@ class FusedMoE(CustomOp): ...@@ -980,6 +993,10 @@ class FusedMoE(CustomOp):
dtype=moe.in_dtype, dtype=moe.in_dtype,
device=torch.cuda.current_device()) device=torch.cuda.current_device())
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
    """
    Shared-experts module attached to this layer.

    This implementation always reports None, i.e. no shared experts.
    """
    return None
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
...@@ -1444,6 +1461,7 @@ class FusedMoE(CustomOp): ...@@ -1444,6 +1461,7 @@ class FusedMoE(CustomOp):
return [ return [
weight.view(self.local_num_experts, -1) for name, weight in weights weight.view(self.local_num_experts, -1) for name, weight in weights
if name not in NON_EXPERT_WEIGHTS if name not in NON_EXPERT_WEIGHTS
and not name.startswith("_shared_experts.")
] ]
def set_eplb_state( def set_eplb_state(
...@@ -1626,25 +1644,52 @@ class FusedMoE(CustomOp): ...@@ -1626,25 +1644,52 @@ class FusedMoE(CustomOp):
else: else:
return tensor_model_parallel_all_reduce(final_hidden_states) return tensor_model_parallel_all_reduce(final_hidden_states)
def forward(self, hidden_states: torch.Tensor, def forward_native(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
og_hidden_states = hidden_states.shape[-1] og_hidden_states = hidden_states.shape[-1]
if self.hidden_size != og_hidden_states: if self.hidden_size != og_hidden_states:
hidden_states = F.pad(hidden_states, hidden_states = F.pad(hidden_states,
(0, self.hidden_size - og_hidden_states), (0, self.hidden_size - og_hidden_states),
mode='constant', mode='constant',
value=0.0) value=0.0)
# TODO: Once the OOM issue for the TPU backend is resolved, we will
# switch to using the moe_forward custom op. if self.shared_experts is None:
if current_platform.is_tpu(): if current_platform.is_tpu():
return self.forward_impl(hidden_states, router_logits) # TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op.
fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple)
else:
fused_output = torch.ops.vllm.moe_forward(
hidden_states, router_logits, self.layer_name)
return fused_output[..., :og_hidden_states]
else: else:
return torch.ops.vllm.moe_forward( if current_platform.is_tpu():
hidden_states, router_logits, # TODO: Once the OOM issue for the TPU backend is resolved, we
self.layer_name)[..., :og_hidden_states] # will switch to using the moe_forward custom op.
shared_output, fused_output = self.forward_impl(
hidden_states, router_logits)
else:
shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
hidden_states, router_logits, self.layer_name)
return (shared_output[..., :og_hidden_states],
fused_output[..., :og_hidden_states])
def forward_impl_chunked(self, full_hidden_states: torch.Tensor, def forward_cuda(
full_router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
return self.forward_native(hidden_states, router_logits)
def forward_impl_chunked(
self,
full_hidden_states: torch.Tensor,
full_router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.batched_hidden_states is not None assert self.batched_hidden_states is not None
assert self.batched_router_logits is not None assert self.batched_router_logits is not None
assert self.batched_hidden_states.dtype == full_hidden_states.dtype assert self.batched_hidden_states.dtype == full_hidden_states.dtype
...@@ -1655,7 +1700,10 @@ class FusedMoE(CustomOp): ...@@ -1655,7 +1700,10 @@ class FusedMoE(CustomOp):
assert ( assert (
self.batched_router_logits.size(-1) == full_router_logits.size(-1)) self.batched_router_logits.size(-1) == full_router_logits.size(-1))
full_final_hidden_states = torch.empty_like(full_hidden_states) full_fused_final_hidden_states = torch.empty_like(full_hidden_states)
if self.shared_experts is not None:
full_shared_final_hidden_states = torch.empty_like(
full_hidden_states)
def process_chunk(chunk_start, chunk_end, skip_result_store=False): def process_chunk(chunk_start, chunk_end, skip_result_store=False):
chunk_size = chunk_end - chunk_start chunk_size = chunk_end - chunk_start
...@@ -1696,20 +1744,40 @@ class FusedMoE(CustomOp): ...@@ -1696,20 +1744,40 @@ class FusedMoE(CustomOp):
logical_replica_count=self.logical_replica_count, logical_replica_count=self.logical_replica_count,
) )
assert self.shared_experts is None or isinstance(
final_hidden_states, tuple)
if not skip_result_store: if not skip_result_store:
full_final_hidden_states[chunk_start:chunk_end, :].copy_( if self.shared_experts is None:
final_hidden_states, non_blocking=True) full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states,
non_blocking=True)
else:
full_shared_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[0],
non_blocking=True)
full_fused_final_hidden_states[
chunk_start:chunk_end, :].copy_(final_hidden_states[1],
non_blocking=True)
ctx = get_forward_context() ctx = get_forward_context()
# flashinfer_cutlass_kernels can handle: optional DP + TP/EP # flashinfer_cutlass_kernels can handle: optional DP + TP/EP
max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu max_tokens_across_dispatchers = ctx.dp_metadata.max_tokens_across_dp_cpu
moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens
# If the input to the MoE is sequence parallel then divide by sp_size
# to find the maximum number of tokens for any individual dispatcher.
if self.is_sequence_parallel:
max_tokens_across_dispatchers = cdiv(max_tokens_across_dispatchers,
self.sp_size)
num_tokens = full_hidden_states.size(0) num_tokens = full_hidden_states.size(0)
for chunk_idx, chunk_start_ in enumerate( for chunk_idx, chunk_start_ in enumerate(
range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank)): range(0, max_tokens_across_dispatchers,
moe_dp_chunk_size_per_rank)):
chunk_start = chunk_start_ chunk_start = chunk_start_
chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank, chunk_end = min(chunk_start + moe_dp_chunk_size_per_rank,
max_tokens_across_dp) max_tokens_across_dispatchers)
# clamp start and end # clamp start and end
chunk_start = min(chunk_start, num_tokens - 1) chunk_start = min(chunk_start, num_tokens - 1)
chunk_end = min(chunk_end, num_tokens) chunk_end = min(chunk_end, num_tokens)
...@@ -1719,10 +1787,17 @@ class FusedMoE(CustomOp): ...@@ -1719,10 +1787,17 @@ class FusedMoE(CustomOp):
chunk_end, chunk_end,
skip_result_store=chunk_start_ >= num_tokens) skip_result_store=chunk_start_ >= num_tokens)
return full_final_hidden_states if self.shared_experts is None:
return full_fused_final_hidden_states
else:
return (full_shared_final_hidden_states,
full_fused_final_hidden_states)
def forward_impl(self, hidden_states: torch.Tensor, def forward_impl(
router_logits: torch.Tensor): self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert self.quant_method is not None assert self.quant_method is not None
# Route to the chunked forward path using the FlashInfer Cutlass kernel # Route to the chunked forward path using the FlashInfer Cutlass kernel
# only when data parallelism (DP) is enabled. # only when data parallelism (DP) is enabled.
...@@ -1738,6 +1813,16 @@ class FusedMoE(CustomOp): ...@@ -1738,6 +1813,16 @@ class FusedMoE(CustomOp):
self.dp_size > 1 self.dp_size > 1
and not self.moe_parallel_config.use_deepep_ht_kernels and not self.moe_parallel_config.use_deepep_ht_kernels
and not self.moe_config.use_flashinfer_cutlass_kernels) and not self.moe_config.use_flashinfer_cutlass_kernels)
# If there are shared experts but we are not using a modular kernel, the
# shared experts must be called here
if (not isinstance(self.quant_method.fused_experts,
FusedMoEModularKernel)
and self.shared_experts is not None):
shared_output = self.shared_experts(hidden_states)
else:
shared_output = None
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits) hidden_states, router_logits)
...@@ -1767,14 +1852,32 @@ class FusedMoE(CustomOp): ...@@ -1767,14 +1852,32 @@ class FusedMoE(CustomOp):
use_nn_moe=self.use_nn_moe, use_nn_moe=self.use_nn_moe,
) )
if do_naive_dispatch_combine: if shared_output is not None:
final_hidden_states = get_ep_group().combine(final_hidden_states) assert not isinstance(final_hidden_states, tuple)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): assert self.shared_experts is not None
# Default set to False. (May have to add shared expert outputs. final_hidden_states = (
final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( shared_output,
final_hidden_states) final_hidden_states,
)
def reduce_output(states: torch.Tensor,
do_combine: bool = True) -> torch.Tensor:
if do_naive_dispatch_combine and do_combine:
states = get_ep_group().combine(states)
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
states = self.maybe_all_reduce_tensor_model_parallel(states)
return final_hidden_states return states
if self.shared_experts is None:
assert not isinstance(final_hidden_states, tuple)
return reduce_output(final_hidden_states)
else:
return (
reduce_output(final_hidden_states[0], do_combine=False),
reduce_output(final_hidden_states[1]),
)
@classmethod @classmethod
def make_expert_params_mapping( def make_expert_params_mapping(
...@@ -1829,17 +1932,22 @@ class FusedMoE(CustomOp): ...@@ -1829,17 +1932,22 @@ class FusedMoE(CustomOp):
return s return s
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward(
layer_name: str) -> torch.Tensor: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name] self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None assert self.shared_experts is None
return self.forward_impl(hidden_states, router_logits) return self.forward_impl(hidden_states, router_logits)
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor, def moe_forward_fake(
layer_name: str) -> torch.Tensor: hidden_states: torch.Tensor,
router_logits: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1852,6 +1960,37 @@ direct_register_custom_op( ...@@ -1852,6 +1960,37 @@ direct_register_custom_op(
tags=(torch.Tag.needs_fixed_stride_order, ), tags=(torch.Tag.needs_fixed_stride_order, ),
) )
def moe_forward_shared(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Run the MoE layer registered under `layer_name`, which must have
    shared experts.

    The layer is looked up in the current forward context's
    `no_compile_layers` table, and the result of its `forward_impl` on the
    given hidden states and router logits is returned as-is.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    # This entry point is only valid for layers that have shared experts.
    assert layer.shared_experts is not None
    return layer.forward_impl(hidden_states, router_logits)
def moe_forward_shared_fake(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    layer_name: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fake counterpart of `moe_forward_shared`: performs no computation.

    Returns two separate uninitialized tensors, each created with
    `torch.empty_like(hidden_states)`, in the order (shared output,
    fused output). `router_logits` and `layer_name` are accepted only to
    mirror the real signature and are not used.
    """
    return torch.empty_like(hidden_states), torch.empty_like(hidden_states)
# Register `moe_forward_shared` as a custom op named "moe_forward_shared",
# with `moe_forward_shared_fake` as its fake implementation.
# `hidden_states` is declared as an argument the op mutates, and the op is
# tagged as needing a fixed stride order.
direct_register_custom_op(
    op_name="moe_forward_shared",
    op_func=moe_forward_shared,
    mutates_args=["hidden_states"],
    fake_impl=moe_forward_shared_fake,
    dispatch_key=current_platform.dispatch_key,
    tags=(torch.Tag.needs_fixed_stride_order, ),
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters # Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code # to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from math import prod from math import prod
from typing import Optional, final from typing import Callable, Optional, Union, final
import torch import torch
...@@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC): ...@@ -141,6 +141,29 @@ class TopKWeightAndReduce(ABC):
raise NotImplementedError raise NotImplementedError
#
# PrepareResultType is the 5-tuple returned by `prepare`, in this order:
# - quantized + dispatched a.
# - quantized + dispatched a1_scales (optional).
# - Optional ExpertTokensMetadata containing gpu/cpu tensors
#   as big as the number of local experts with the information about the
#   number of tokens assigned to each local expert.
# - Optional dispatched expert topk IDs
# - Optional dispatched expert topk weight
#
# See `prepare` method below.
#
PrepareResultType = tuple[
    torch.Tensor,
    Optional[torch.Tensor],
    Optional[ExpertTokensMetadata],
    Optional[torch.Tensor],
    Optional[torch.Tensor],
]

# ReceiverType is a zero-argument callable that produces a
# PrepareResultType when invoked. See `prepare_async` method below.
ReceiverType = Callable[[], PrepareResultType]
# TODO: pass FusedMoEParallelConfig in as ctor parameter? # TODO: pass FusedMoEParallelConfig in as ctor parameter?
class FusedMoEPrepareAndFinalize(ABC): class FusedMoEPrepareAndFinalize(ABC):
""" """
...@@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -160,16 +183,9 @@ class FusedMoEPrepareAndFinalize(ABC):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[ ) -> PrepareResultType:
torch.Tensor,
Optional[torch.Tensor],
Optional[ExpertTokensMetadata],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
""" """
Perform any quantization (and/or) dispatching needed Perform any quantization (and/or) dispatching needed for this kernel.
for this kernel.
- a1: The (unquantized) input to the MoE layer. - a1: The (unquantized) input to the MoE layer.
- a1_scale: Optional scales for a1 - a1_scale: Optional scales for a1
- a2_scale: Optional scales for the second MoE gemm. Required to make - a2_scale: Optional scales for the second MoE gemm. Required to make
...@@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -193,6 +209,51 @@ class FusedMoEPrepareAndFinalize(ABC):
""" """
raise NotImplementedError raise NotImplementedError
def supports_async(self) -> bool:
    """
    Report whether this class provides a working `prepare_async`.

    The default answer is False.
    """
    return False
def prepare_async(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> ReceiverType:
    """
    Non-blocking variant of `prepare`: perform any quantization (and/or)
    dispatching needed for this kernel, but return without waiting for
    results from other workers.

    Arguments:
    - a1: The (unquantized) input to the MoE layer.
    - a1_scale: Optional scales for a1
    - a2_scale: Optional scales for the second MoE gemm. Required to make
      sure the quantization is consistent for both gemms.
    - topk_weights: The topk weights.
    - topk_ids: The topk ids.
    - num_experts: The total number of experts in the global expert space.
    - expert_map: A tensor mapping expert indices from the global expert
      space to the local expert space of the expert parallel shard.
    - apply_router_weight_on_input: When True, apply the weights to the
      activations, before quantization + dispatching.
    - quant_config: The FusedMoEQuantConfig for this kernel.

    Returns a callback that, when invoked, waits for results from other
    workers and has the same return signature as `prepare`, e.g.

    receiver = obj.prepare_async(...)
    a, a_scales, expert_meta, topk_ids, topk_weights = receiver()

    is equivalent to:

    a, a_scales, expert_meta, topk_ids, topk_weights = obj.prepare(...)

    This default implementation is not usable and raises
    NotImplementedError.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def finalize( def finalize(
self, self,
...@@ -241,7 +302,7 @@ class FusedMoEPrepareAndFinalize(ABC): ...@@ -241,7 +302,7 @@ class FusedMoEPrepareAndFinalize(ABC):
def max_num_tokens_per_rank(self) -> Optional[int]: def max_num_tokens_per_rank(self) -> Optional[int]:
""" """
Some PrepareFinalize All2All implementations are batched. Meaning, Some PrepareFinalize All2All implementations are batched. Meaning,
they can processes only as set of tokens at a time. This they can process only as set of tokens at a time. This
function returns the batch size i.e the maximum number of tokens function returns the batch size i.e the maximum number of tokens
the implementation can process at a time. the implementation can process at a time.
Return None if there are no such restrictions. Return None if there are no such restrictions.
...@@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -453,10 +514,12 @@ class FusedMoEModularKernel(torch.nn.Module):
self, self,
prepare_finalize: FusedMoEPrepareAndFinalize, prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: FusedMoEPermuteExpertsUnpermute, fused_experts: FusedMoEPermuteExpertsUnpermute,
shared_experts: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.prepare_finalize = prepare_finalize self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts self.fused_experts = fused_experts
self.shared_experts = shared_experts
assert prepare_finalize.activation_format == \ assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], ( fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}." f"{prepare_finalize.__class__.__name__}."
...@@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -692,7 +755,7 @@ class FusedMoEModularKernel(torch.nn.Module):
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
) -> torch.Tensor: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets This function computes a Mixture of Experts (MoE) layer using two sets
of weights, w1 and w2, and top-k gating mechanism. of weights, w1 and w2, and top-k gating mechanism.
...@@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -736,18 +799,46 @@ class FusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = local_num_experts global_num_experts = local_num_experts
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids, shared_output: torch.Tensor
_expert_topk_weights) = self.prepare_finalize.prepare(
a1, if (not self.prepare_finalize.supports_async()
a1_scale, or self.shared_experts is None):
a2_scale,
topk_weights, # Run shared experts serially with dispatch.
topk_ids, if self.shared_experts is not None:
global_num_experts, shared_output = self.shared_experts(a1)
expert_map,
apply_router_weight_on_input, (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
self.fused_experts.quant_config, _expert_topk_weights) = self.prepare_finalize.prepare(
) a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
else:
# Overlap shared expert compute with all2all dispatch.
receiver = self.prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
topk_weights,
topk_ids,
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
)
assert self.shared_experts is not None
shared_output = self.shared_experts(a1)
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver()
# Maybe prepare gathered topk_ids and topk_weights from other EP ranks. # Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
...@@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -795,4 +886,7 @@ class FusedMoEModularKernel(torch.nn.Module):
self.fused_experts.finalize_weight_and_reduce_impl(), self.fused_experts.finalize_weight_and_reduce_impl(),
) )
return output if self.shared_experts is None:
return output
else:
return shared_output, output
...@@ -7,7 +7,7 @@ import torch.nn.functional as F ...@@ -7,7 +7,7 @@ import torch.nn.functional as F
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
""" """
Compute the histogram of a int32 tensor. The bin edges are defined by the Compute the histogram of an int32 tensor. The bin edges are defined by the
min and max values, with step = 1. min and max values, with step = 1.
""" """
assert input.dtype == torch.int32, "input must be of torch.int32 dtype." assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
......
...@@ -84,12 +84,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -84,12 +84,15 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return self.max_num_tokens return self.max_num_tokens
def topk_indices_dtype(self) -> Optional[torch.dtype]: def topk_indices_dtype(self) -> Optional[torch.dtype]:
return torch.int32 return torch.uint32
def num_dispatchers(self) -> int: def num_dispatchers(self) -> int:
return self.num_dispatchers_ return self.num_dispatchers_
def prepare( def supports_async(self) -> bool:
return True
def prepare_async(
self, self,
a1: torch.Tensor, a1: torch.Tensor,
a1_scale: Optional[torch.Tensor], a1_scale: Optional[torch.Tensor],
...@@ -100,9 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -100,9 +103,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.ReceiverType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
num_tokens = a1.size(0) # M num_tokens = a1.size(0) # M
hidden_dim = a1.size(-1) # K hidden_dim = a1.size(-1) # K
...@@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -138,6 +139,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
_validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant, _validate_scale_shape(a1q, a1q_scale, quant_config.per_act_token_quant,
quant_config.block_shape) quant_config.block_shape)
orig_a_scale_block_shape: Optional[int] = None
if a1q_scale is not None: if a1q_scale is not None:
scalar_scales = a1q_scale.numel() == 1 scalar_scales = a1q_scale.numel() == 1
...@@ -205,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -205,8 +208,45 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
out_expert_x_scale=expert_x_scale, out_expert_x_scale=expert_x_scale,
dp_x=a1q, dp_x=a1q,
dp_x_scale=a1q_scale, dp_x_scale=a1q_scale,
indices=topk_ids.view(dtype=torch.uint32), indices=topk_ids,
bound_m=bound_m,
do_send=True,
do_recv=False,
)
return lambda: self._receiver(
expert_num_tokens,
expert_x,
expert_x_scale,
a1q,
a1q_scale,
topk_ids,
bound_m,
orig_a_scale_block_shape,
)
def _receiver(
self,
expert_num_tokens: torch.Tensor,
expert_x: torch.Tensor,
expert_x_scale: Optional[torch.Tensor],
a1q: torch.Tensor,
a1q_scale: Optional[torch.Tensor],
topk_ids: torch.Tensor,
bound_m: Optional[torch.Tensor],
orig_a_scale_block_shape: Optional[int],
) -> mk.PrepareResultType:
self.a2a.dispatch(
out_expert_num_tokens=expert_num_tokens,
out_expert_x=expert_x,
out_expert_x_scale=expert_x_scale,
dp_x=a1q,
dp_x_scale=a1q_scale,
indices=topk_ids,
bound_m=bound_m, bound_m=bound_m,
do_send=False,
do_recv=True,
) )
if expert_x_scale is not None: if expert_x_scale is not None:
...@@ -218,6 +258,31 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): ...@@ -218,6 +258,31 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
return expert_x, expert_x_scale, expert_tokens_meta, None, None return expert_x, expert_x_scale, expert_tokens_meta, None, None
def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """
    Synchronous form of `prepare_async`.

    Calls `prepare_async` with exactly the arguments given here, then
    immediately invokes the receiver it returns and hands back that
    receiver's result.
    """
    # Forward every argument positionally, in signature order.
    forwarded = (
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    wait_for_results = self.prepare_async(*forwarded)
    return wait_for_results()
def finalize( def finalize(
self, self,
output: torch.Tensor, output: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment