Commit 89d1dd57 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Models]support blas and moe nn layout of deepseek-v3

parent 53076d70
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 3,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 4,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 3,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"num_ldmatrixes": 1
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"num_ldmatrixes": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}
...@@ -109,14 +109,28 @@ def fused_moe_kernel_gptq_awq( ...@@ -109,14 +109,28 @@ def fused_moe_kernel_gptq_awq(
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse. # This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) # num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n # num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group # group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M # first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) # pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m # pid_n = (pid % num_pid_in_group) // group_size_m
if GROUP_SIZE_M ==1:
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
else:
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
...@@ -681,7 +695,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -681,7 +695,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -712,7 +727,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -712,7 +727,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = min(sorted_token_ids.shape[0], EM = min(sorted_token_ids.shape[0],
A.shape[0] * top_k * config['BLOCK_SIZE_M']) A.shape[0] * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), ) B.shape[1] if not use_nn_moe else B.shape[2], META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \ if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0: block_shape is not None and block_shape[1] > 0:
...@@ -799,15 +814,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -799,15 +814,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
B.shape[1], B.shape[1] if not use_nn_moe else B.shape[2],
A.shape[1], A.shape[1],
EM, EM,
topk_ids.numel(), topk_ids.numel(),
A.stride(0), A.stride(0),
A.stride(1), A.stride(1),
B.stride(0), B.stride(0),
B.stride(2), B.stride(2) if not use_nn_moe else B.stride(1),
B.stride(1), B.stride(1) if not use_nn_moe else B.stride(2),
C.stride(1), C.stride(1),
C.stride(2), C.stride(2),
A_scale.stride(0) A_scale.stride(0)
...@@ -836,13 +851,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -836,13 +851,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int, def get_config_file_name(E: int,
N: int, N: int,
dtype: Optional[str], dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str: block_shape: Optional[List[int]] = None, use_nn_moe: Optional[bool] = False) -> str:
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}" dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else block_shape_selector = ("" if not block_shape or not all(block_shape) else
f",block_shape={block_shape}").replace(" ", "") f",block_shape={block_shape}").replace(" ", "")
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501 if not use_nn_moe:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
else:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}_nn.json"
# Adapted from: https://github.com/sgl-project/sglang/pull/2628 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache @functools.lru_cache
...@@ -852,6 +869,7 @@ def get_moe_configs( ...@@ -852,6 +869,7 @@ def get_moe_configs(
dtype: Optional[str], dtype: Optional[str],
block_n: Optional[int] = None, block_n: Optional[int] = None,
block_k: Optional[int] = None, block_k: Optional[int] = None,
use_nn_moe: Optional[bool] = False,
) -> Optional[Dict[int, Any]]: ) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -865,7 +883,7 @@ def get_moe_configs( ...@@ -865,7 +883,7 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
block_shape = [block_n, block_k] if block_n and block_k else None block_shape = [block_n, block_k] if block_n and block_k else None
json_file_name = get_config_file_name(E, N, dtype, block_shape) json_file_name = get_config_file_name(E, N, dtype, block_shape, use_nn_moe=use_nn_moe)
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
...@@ -957,6 +975,7 @@ def get_default_config( ...@@ -957,6 +975,7 @@ def get_default_config(
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool, is_marlin: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False,
) -> Dict[str, int]: ) -> Dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
...@@ -999,6 +1018,8 @@ def get_default_config( ...@@ -999,6 +1018,8 @@ def get_default_config(
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
} }
if use_nn_moe:
config["num_ldmatrixes"] = 1
return config return config
...@@ -1010,6 +1031,7 @@ def try_get_optimal_moe_config( ...@@ -1010,6 +1031,7 @@ def try_get_optimal_moe_config(
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
): ):
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config() override_config = get_config()
...@@ -1017,12 +1039,15 @@ def try_get_optimal_moe_config( ...@@ -1017,12 +1039,15 @@ def try_get_optimal_moe_config(
config = override_config config = override_config
else: else:
# First try to load optimal config from the file # First try to load optimal config from the file
E, _, N = w2_shape if not use_nn_moe:
E, _, N = w2_shape
else:
E, N, _ = w2_shape
if dtype == "int4_w4a16": if dtype == "int4_w4a16":
N = N * 2 N = N * 2
block_n = block_shape[0] if block_shape else 0 block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0 block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k) configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe)
if configs: if configs:
# If an optimal configuration map has been found, look up the # If an optimal configuration map has been found, look up the
...@@ -1030,8 +1055,8 @@ def try_get_optimal_moe_config( ...@@ -1030,8 +1055,8 @@ def try_get_optimal_moe_config(
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Else use the default config # Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, config = get_default_config(M, E, N, w1_shape[2] if not use_nn_moe else w1_shape[1], top_k, dtype,
is_marlin, block_shape) is_marlin, block_shape, use_nn_moe=use_nn_moe)
return config return config
...@@ -1165,12 +1190,13 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -1165,12 +1190,13 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a16, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape) block_shape, use_nn_moe)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -1191,7 +1217,8 @@ def inplace_fused_experts_fake( ...@@ -1191,7 +1217,8 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,) -> None:
pass pass
...@@ -1221,12 +1248,13 @@ def outplace_fused_experts( ...@@ -1221,12 +1248,13 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8, use_int8_w8a16, False, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map, use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape) a2_scale, block_shape, use_nn_moe)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -1247,7 +1275,8 @@ def outplace_fused_experts_fake( ...@@ -1247,7 +1275,8 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -1277,21 +1306,22 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -1277,21 +1306,22 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts( torch.ops.vllm.inplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape) block_shape, use_nn_moe)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, activation, hidden_states, w1, w2, topk_weights, topk_ids, activation,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape) block_shape, use_nn_moe)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
...@@ -1312,11 +1342,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1312,11 +1342,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None): block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[ assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch" 2], "Hidden size mismatch"
elif use_nn_moe:
assert hidden_states.shape[1] == w1.shape[1], "Hidden size mismatch"
else: else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
...@@ -1329,7 +1362,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1329,7 +1362,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
] ]
num_tokens, _ = hidden_states.shape num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape if use_nn_moe:
E, _, N = w1.shape
else:
E, N, _ = w1.shape
if global_num_experts == -1: if global_num_experts == -1:
global_num_experts = E global_num_experts = E
top_k_num = topk_ids.shape[1] top_k_num = topk_ids.shape[1]
...@@ -1349,6 +1385,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1349,6 +1385,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
top_k_num, top_k_num,
config_dtype, config_dtype,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe,
) )
config = get_config_func(M) config = get_config_func(M)
...@@ -1360,7 +1397,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1360,7 +1397,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view( intermediate_cache1 = cache13[:M * top_k_num * N].view(
(M, topk_ids.shape[1], N)) (M, topk_ids.shape[1], N))
intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view( intermediate_cache3 = cache13[:M * top_k_num * (w2.shape[1] if not use_nn_moe else w2.shape[2])].view(
(M, topk_ids.shape[1], w2.shape[1])) (M, topk_ids.shape[1], w2.shape[1]))
# This needs separate memory since it's used concurrently with cache1 # This needs separate memory since it's used concurrently with cache1
...@@ -1428,7 +1465,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1428,7 +1465,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape) block_shape=block_shape,
use_nn_moe=use_nn_moe)
if activation == "silu": if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
...@@ -1457,7 +1495,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1457,7 +1495,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape) block_shape=block_shape,
use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
...@@ -1489,6 +1528,7 @@ def fused_moe( ...@@ -1489,6 +1528,7 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1568,4 +1608,5 @@ def fused_moe( ...@@ -1568,4 +1608,5 @@ def fused_moe(
w2_zp=w2_zp, w2_zp=w2_zp,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape) block_shape=block_shape,
use_nn_moe=use_nn_moe)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
...@@ -75,24 +76,38 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -75,24 +76,38 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, use_nn_moe: bool,
**extra_weight_attrs):
# Fused gate_up_proj (column parallel) # Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty( if not use_nn_moe:
num_experts, w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size_per_partition, 2 * intermediate_size_per_partition,
hidden_size, hidden_size,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
else:
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight) layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs) set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel) # down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty( if not use_nn_moe:
num_experts, w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size, hidden_size,
intermediate_size_per_partition, intermediate_size_per_partition,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
else:
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
...@@ -126,6 +141,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -126,6 +141,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -140,7 +156,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -140,7 +156,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
activation=activation) activation=activation,
use_nn_moe=use_nn_moe)
def forward_cuda( def forward_cuda(
self, self,
...@@ -158,6 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -158,6 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -179,7 +197,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -179,7 +197,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace=True, inplace=True,
activation=activation, activation=activation,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
expert_map=expert_map) expert_map=expert_map,
use_nn_moe=use_nn_moe)
def forward_cpu( def forward_cpu(
self, self,
...@@ -197,6 +216,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -197,6 +216,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu", activation: str = "silu",
use_nn_moe: Optional[bool] = False,
**kwargs, **kwargs,
): ):
assert activation == "silu", f"{activation} is not supported." assert activation == "silu", f"{activation} is not supported."
...@@ -228,6 +248,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -228,6 +248,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
activation: str = "silu", activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert not use_grouped_topk assert not use_grouped_topk
...@@ -421,6 +442,13 @@ class FusedMoE(torch.nn.Module): ...@@ -421,6 +442,13 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix) self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None assert self.quant_method is not None
if quant_config is None:
# Not considering quant for now, temporarily
# self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
self.use_nn_moe = os.environ.get('MOE_NN') == '1'
else:
self.use_nn_moe = False
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
"hidden_size": hidden_size, "hidden_size": hidden_size,
...@@ -428,6 +456,7 @@ class FusedMoE(torch.nn.Module): ...@@ -428,6 +456,7 @@ class FusedMoE(torch.nn.Module):
self.intermediate_size_per_partition, self.intermediate_size_per_partition,
"params_dtype": params_dtype, "params_dtype": params_dtype,
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
"use_nn_moe": self.use_nn_moe,
} }
# need full intermediate size pre-sharding for WNA16 act order # need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ if (self.quant_method.__class__.__name__
...@@ -502,7 +531,8 @@ class FusedMoE(torch.nn.Module): ...@@ -502,7 +531,8 @@ class FusedMoE(torch.nn.Module):
# Index the loaded weight for tp sharding. # Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2 shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim,
shard_size * tp_rank,
shard_size) shard_size)
# Narrow parameter and load. # Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13. # w1, gate_proj: Load into first logical weight of w13.
...@@ -512,7 +542,10 @@ class FusedMoE(torch.nn.Module): ...@@ -512,7 +542,10 @@ class FusedMoE(torch.nn.Module):
else: else:
assert shard_id == "w3" assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight) if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_w2(self, def _load_w2(self,
expert_data: torch.Tensor, expert_data: torch.Tensor,
...@@ -526,18 +559,24 @@ class FusedMoE(torch.nn.Module): ...@@ -526,18 +559,24 @@ class FusedMoE(torch.nn.Module):
# Narrow parameter and load. # Narrow parameter and load.
shard_size = expert_data.shape[shard_dim] shard_size = expert_data.shape[shard_dim]
if not load_full: if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim, loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim,
shard_size * tp_rank, shard_size * tp_rank,
shard_size) shard_size)
# w2, down_proj: Load into only logical weight of w2. # w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight) if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_single_value(self, param: torch.nn.Parameter, def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int): loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data param_data = param.data
# Input scales can be loaded directly and should be equal. # Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight if not self.use_nn_moe:
param_data[expert_id] = loaded_weight
else:
param_data[expert_id] = loaded_weight.T
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
...@@ -549,7 +588,10 @@ class FusedMoE(torch.nn.Module): ...@@ -549,7 +588,10 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank) tp_rank=tp_rank)
else: else:
assert shard_id in ("w1", "w3") assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight) if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int: def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self.expert_map is None: if self.expert_map is None:
...@@ -593,7 +635,7 @@ class FusedMoE(torch.nn.Module): ...@@ -593,7 +635,7 @@ class FusedMoE(torch.nn.Module):
# is_transposed: if the dim to shard the weight # is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors # should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is # should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False) s_transposed = getattr(param, "is_transposed", False) or self.use_nn_moe
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed: if is_transposed:
shard_dim = int(not shard_dim) shard_dim = int(not shard_dim)
...@@ -787,6 +829,7 @@ class FusedMoE(torch.nn.Module): ...@@ -787,6 +829,7 @@ class FusedMoE(torch.nn.Module):
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation, activation=self.activation,
use_nn_moe=self.use_nn_moe,
) )
if self.dp_size > 1: if self.dp_size > 1:
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
"""Inference-only DeepseekV2/DeepseekV3 model.""" """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union
import os
import re
import torch import torch
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -56,6 +58,7 @@ from .interfaces import SupportsPP ...@@ -56,6 +58,7 @@ from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
...@@ -675,6 +678,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -675,6 +678,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -807,6 +815,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -807,6 +815,37 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment