Commit a810671a authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.14.0rc0' into v0.14.0rc0-ori

parents 86b5aefe 6a09612b
...@@ -77,11 +77,11 @@ if HAS_TRITON: ...@@ -77,11 +77,11 @@ if HAS_TRITON:
BatchedTritonExperts, BatchedTritonExperts,
) )
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
GroupedTopk,
TritonExperts, TritonExperts,
fused_experts, fused_experts,
fused_topk, fused_topk,
get_config_file_name, get_config_file_name,
grouped_topk,
) )
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts, TritonOrDeepGemmExperts,
...@@ -91,7 +91,7 @@ if HAS_TRITON: ...@@ -91,7 +91,7 @@ if HAS_TRITON:
"fused_topk", "fused_topk",
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
"grouped_topk", "GroupedTopk",
"cutlass_moe_fp8", "cutlass_moe_fp8",
"cutlass_moe_fp4", "cutlass_moe_fp4",
"cutlass_moe_w4a8_fp8", "cutlass_moe_w4a8_fp8",
......
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
{
"triton_version": "3.5.1",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
{
"triton_version": "3.5.1",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 5
}
}
{
"triton_version": "3.5.1",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 5
},
"4096": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
}
}
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
}
}
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref
from collections.abc import Callable from collections.abc import Callable
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
from vllm.model_executor.layers.quantization.utils.layer_utils import replace_parameter
from vllm.utils.torch_utils import direct_register_custom_op
_CPU_MOE_LAYER_CACHE = {}
_CPU_MOE_ACT = {
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def grouped_topk( def grouped_topk(
...@@ -174,8 +184,105 @@ class SGLFusedMOE: ...@@ -174,8 +184,105 @@ class SGLFusedMOE:
class CPUFusedMOE: class CPUFusedMOE:
def __init__(self, layer: torch.nn.Module) -> None: def __init__(self, layer: torch.nn.Module) -> None:
use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported() use_grouped_gemm, isa = self.check_grouped_gemm(layer)
self.isa = isa
if use_grouped_gemm:
self.forward_method = self.forward_grouped_gemm
self.init_moe_grouped_gemm(layer=layer)
else:
self.forward_method = self.forward_torch
self.init_moe_torch(layer=layer)
def __call__(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation in _CPU_MOE_ACT, f"{activation} is not supported."
assert not apply_router_weight_on_input
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
)
return self.forward_method(
layer,
x,
topk_weights,
topk_ids,
activation,
global_num_experts,
)
def check_grouped_gemm(
self,
layer: torch.nn.Module,
) -> tuple[bool, str]:
if not hasattr(torch.ops._C, "prepack_moe_weight"):
return False, "none"
dtype = layer.w13_weight.dtype
w13_input_size = layer.w13_weight.size(2)
w13_output_size = layer.w13_weight.size(1)
w2_input_size = layer.w2_weight.size(2)
w2_output_size = layer.w2_weight.size(1)
if not (w13_output_size % 32 == 0 and w2_output_size % 32 == 0):
return False, "none"
supports_amx = torch._C._cpu._is_amx_tile_supported()
if (
supports_amx
and dtype == torch.bfloat16
and w13_input_size % 32 == 0
and w2_input_size % 32 == 0
):
return True, "amx"
if supports_amx:
return False, "none"
return True, "vec"
def init_moe_grouped_gemm(
self,
layer: torch.nn.Module,
) -> None:
new_w13 = cpu_prepack_moe_weight(layer.w13_weight, self.isa)
replace_parameter(layer, "w13_weight", new_w13)
new_w2 = cpu_prepack_moe_weight(layer.w2_weight, self.isa)
replace_parameter(layer, "w2_weight", new_w2)
def init_moe_torch(
self,
layer: torch.nn.Module,
) -> None:
use_onednn_mm = ops._supports_onednn and ops.is_onednn_acl_supported()
num_experts = layer.w13_weight.size(0) num_experts = layer.w13_weight.size(0)
has_w13_bias = hasattr(layer, "w13_bias") has_w13_bias = hasattr(layer, "w13_bias")
has_w2_bias = hasattr(layer, "w2_bias") has_w2_bias = hasattr(layer, "w2_bias")
...@@ -208,85 +315,112 @@ class CPUFusedMOE: ...@@ -208,85 +315,112 @@ class CPUFusedMOE:
layer.down_linear.append( layer.down_linear.append(
lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b) lambda x, w=layer_w2_weight, b=layer_w2_bias: F.linear(x, w, b)
) )
if use_onednn_mm: # remove weight if use_onednn_mm: # remove weight
layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) layer.w13_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False) layer.w2_weight = torch.nn.Parameter(torch.empty(0), requires_grad=False)
self.act_to_impl = { _CPU_MOE_LAYER_CACHE[id(layer)] = weakref.ref(layer)
"silu": SiluAndMul(),
"swigluoai": SwigluOAIAndMul(),
}
def __call__( def forward_grouped_gemm(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, input: torch.Tensor,
use_grouped_topk: bool, topk_weights: torch.Tensor,
top_k: int, topk_ids: torch.Tensor,
router_logits: torch.Tensor, activation: str,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1, global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor: ) -> torch.Tensor:
assert activation in self.act_to_impl, f"{activation} is not supported." output = cpu_fused_moe(
assert not apply_router_weight_on_input input,
topk_weights, topk_ids = select_experts( layer.w13_weight,
hidden_states=x, layer.w2_weight,
router_logits=router_logits, getattr(layer, "w13_bias", None),
use_grouped_topk=use_grouped_topk, getattr(layer, "w2_bias", None),
top_k=top_k, topk_weights,
renormalize=renormalize, topk_ids,
topk_group=topk_group, activation,
num_expert_group=num_expert_group, self.isa,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
) )
return output
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53 def forward_torch(
len_experts = global_num_experts self,
layer: torch.nn.Module,
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) input: torch.Tensor,
cnts.scatter_(1, topk_ids.to(torch.int64), 1) topk_weights: torch.Tensor,
tokens_per_expert = cnts.sum(dim=0) topk_ids: torch.Tensor,
idxs = topk_ids.view(-1).argsort() activation: str,
global_num_experts: int = -1,
sorted_tokens = x[idxs // topk_ids.shape[1]] ) -> torch.Tensor:
tokens_per_expert = tokens_per_expert.cpu().numpy() output = torch.empty_like(input)
layer_id = id(layer)
outputs = [] torch.ops.vllm.cpu_fused_moe_torch(
start_idx = 0 layer_id,
output,
for i, num_tokens in enumerate(tokens_per_expert): input,
end_idx = start_idx + num_tokens topk_weights,
if num_tokens == 0: topk_ids,
continue activation,
tokens_for_this_expert = sorted_tokens[start_idx:end_idx] global_num_experts,
gate_up = layer.gate_up_linear[i](tokens_for_this_expert)
gate_up = self.act_to_impl[activation].forward_native(gate_up)
expert_out = layer.down_linear[i](gate_up)
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
) )
return final_out
return output
def cpu_fused_moe_torch(
layer_id: int,
output: torch.Tensor,
input: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int = -1,
) -> None:
layer = _CPU_MOE_LAYER_CACHE[layer_id]()
# Ref code from https://github.com/sgl-project/sglang/blob/716e682721397df103f347d22da8bd46c6016dab/python/sglang/srt/layers/moe/fused_moe_native.py#L53
len_experts = global_num_experts
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = input[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
gate_up = layer.gate_up_linear[i](tokens_for_this_expert) # type: ignore
gate_up = _CPU_MOE_ACT[activation].forward_native(gate_up)
expert_out = layer.down_linear[i](gate_up) # type: ignore
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
new_x.view(*topk_ids.shape, -1)
.type(topk_weights.dtype)
.mul_(topk_weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
output.copy_(final_out)
direct_register_custom_op(
op_name="cpu_fused_moe_torch",
op_func=cpu_fused_moe_torch,
mutates_args=["output"],
)
...@@ -16,6 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk ...@@ -16,6 +16,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.batch_invariant import ( from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant, vllm_is_batch_invariant,
) )
...@@ -1312,6 +1313,57 @@ def grouped_topk( ...@@ -1312,6 +1313,57 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32) return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
@CustomOp.register("grouped_topk")
class GroupedTopk(CustomOp):
"""GroupedTopk used by the Deepseek-V2 and Deepseek-V3 model."""
def __init__(
self,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
) -> None:
super().__init__()
self.native_impl = grouped_topk
self.topk = topk
self.renormalize = renormalize
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
def forward_native(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.native_impl(
hidden_states,
gating_output,
self.topk,
self.renormalize,
self.num_expert_group,
self.topk_group,
self.scoring_func,
self.routed_scaling_factor,
e_score_correction_bias,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
e_score_correction_bias: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.forward_native(
hidden_states, gating_output, e_score_correction_bias
)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def eplb_map_to_physical_and_record( def eplb_map_to_physical_and_record(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
...@@ -1673,7 +1725,6 @@ def fused_experts( ...@@ -1673,7 +1725,6 @@ def fused_experts(
and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2)) and (is_deep_gemm_e8m0_used() or _valid_deep_gemm(hidden_states, w1, w2))
): ):
assert quant_config is not None assert quant_config is not None
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8( return deep_gemm_moe_fp8(
hidden_states=hidden_states, hidden_states=hidden_states,
w1=w1, w1=w1,
......
...@@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"implementation based on the prepare_finalize" "implementation based on the prepare_finalize"
) )
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
)
@abstractmethod @abstractmethod
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
self, layer: torch.nn.Module self, layer: torch.nn.Module
......
...@@ -46,6 +46,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( ...@@ -46,6 +46,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf, is_flashinfer_supporting_global_sf,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
aux_stream, aux_stream,
...@@ -68,7 +69,7 @@ else: ...@@ -68,7 +69,7 @@ else:
return topk_ids return topk_ids
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
rocm_aiter_grouped_topk, rocm_aiter_grouped_topk,
) )
...@@ -1615,19 +1616,26 @@ class FusedMoE(CustomOp): ...@@ -1615,19 +1616,26 @@ class FusedMoE(CustomOp):
grouped_topk_impl = partial( grouped_topk_impl = partial(
rocm_aiter_grouped_topk, rocm_aiter_grouped_topk,
num_fused_shared_experts=self.num_fused_shared_experts, num_fused_shared_experts=self.num_fused_shared_experts,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
) )
else: else:
grouped_topk_impl = grouped_topk grouped_topk_impl = GroupedTopk(
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
)
topk_weights, topk_ids = grouped_topk_impl( topk_weights, topk_ids = grouped_topk_impl(
hidden_states=hidden_states, hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=self.top_k,
renormalize=self.renormalize,
num_expert_group=self.num_expert_group,
topk_group=self.topk_group,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
) )
elif self.e_score_correction_bias is not None: elif self.e_score_correction_bias is not None:
...@@ -1740,9 +1748,10 @@ class FusedMoE(CustomOp): ...@@ -1740,9 +1748,10 @@ class FusedMoE(CustomOp):
return states return states
if self.shared_experts is None: if self.shared_experts is None:
if current_platform.is_tpu(): if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we # TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op. # will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
fused_output = self.forward_impl(hidden_states, router_logits) fused_output = self.forward_impl(hidden_states, router_logits)
assert not isinstance(fused_output, tuple) assert not isinstance(fused_output, tuple)
else: else:
...@@ -1758,9 +1767,10 @@ class FusedMoE(CustomOp): ...@@ -1758,9 +1767,10 @@ class FusedMoE(CustomOp):
else: else:
return reduce_output(fused_output)[..., :og_hidden_states] return reduce_output(fused_output)[..., :og_hidden_states]
else: else:
if current_platform.is_tpu(): if current_platform.is_tpu() or current_platform.is_cpu():
# TODO: Once the OOM issue for the TPU backend is resolved, we # TODO: Once the OOM issue for the TPU backend is resolved, we
# will switch to using the moe_forward custom op. # will switch to using the moe_forward custom op.
# Note: CPU doesn't require wrapped forward_impl.
shared_output, fused_output = self.forward_impl( shared_output, fused_output = self.forward_impl(
hidden_states, router_logits hidden_states, router_logits
) )
...@@ -1955,10 +1965,46 @@ class FusedMoE(CustomOp): ...@@ -1955,10 +1965,46 @@ class FusedMoE(CustomOp):
) )
with sp_ctx: with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine: if do_naive_dispatch_combine:
hidden_states_combined, router_logits = get_ep_group().dispatch( # Avoid circular import
hidden_states, router_logits, self.is_sequence_parallel from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4FusedMoE,
) )
post_quant_allgather = (
has_flashinfer_trtllm_fused_moe()
and self.quant_method is not None
and self.dp_size > 1
and self.use_ep
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
self, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
hidden_states_combined, router_logits, extra_tensors_combined = (
dispatch_res
)
hidden_states_combined = (
hidden_states_combined,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
# Run shared experts before matrix multiply. # Run shared experts before matrix multiply.
# because matrix multiply maybe modify the hidden_states. # because matrix multiply maybe modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream: if has_separate_shared_experts and not use_shared_experts_stream:
......
...@@ -743,7 +743,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -743,7 +743,7 @@ class FusedMoEModularKernel(torch.nn.Module):
1, 1,
( (
M M
if not self.fused_experts.supports_chunking() if not self.fused_experts.enable_chunking()
else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE) else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
), ),
) )
...@@ -786,7 +786,7 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -786,7 +786,7 @@ class FusedMoEModularKernel(torch.nn.Module):
is_forward_context_available() is_forward_context_available()
and get_forward_context().attn_metadata is None and get_forward_context().attn_metadata is None
) )
if is_profile_run and self.fused_experts.supports_chunking() and self.is_dp_ep: if is_profile_run and self.fused_experts.enable_chunking() and self.is_dp_ep:
max_workspace_13, max_workspace_2, max_fused_out_shape = ( max_workspace_13, max_workspace_2, max_fused_out_shape = (
self.fused_experts.workspace_shapes( self.fused_experts.workspace_shapes(
envs.VLLM_FUSED_MOE_CHUNK_SIZE, envs.VLLM_FUSED_MOE_CHUNK_SIZE,
......
...@@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE): ...@@ -29,14 +29,14 @@ class SharedFusedMoE(FusedMoE):
self._shared_experts = shared_experts self._shared_experts = shared_experts
# Disable shared expert overlap if: # Disable shared expert overlap if:
# - we are using eplb, because of correctness issues # - we are using eplb with non-default backend, because of correctness issues
# - we are using flashinfer with DP, since there nothing to gain # - we are using flashinfer with DP, since there nothint to gain
# - we are using marlin kernels # - we are using marlin kernels
backend = self.moe_parallel_config.all2all_backend
self.use_overlapped = ( self.use_overlapped = (
use_overlapped use_overlapped
and not ( and not (
# TODO(wentao): find the root cause and remove this condition (self.enable_eplb and backend != "allgather_reducescatter")
self.enable_eplb
or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1) or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
) )
and self._shared_experts is not None and self._shared_experts is not None
......
...@@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig): ...@@ -121,7 +121,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 80 return 75
@classmethod @classmethod
def get_config_filenames(cls) -> list[str]: def get_config_filenames(cls) -> list[str]:
......
...@@ -96,6 +96,7 @@ from vllm.utils.deep_gemm import ( ...@@ -96,6 +96,7 @@ from vllm.utils.deep_gemm import (
get_col_major_tma_aligned_tensor, get_col_major_tma_aligned_tensor,
get_mk_alignment_for_contiguous_layout, get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used, is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
) )
from vllm.utils.import_utils import has_deep_gemm from vllm.utils.import_utils import has_deep_gemm
...@@ -469,16 +470,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -469,16 +470,14 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
) )
logger.debug_once("Finished shuffling weights for TRT-LLM MOE") logger.debug_once("Finished shuffling weights for TRT-LLM MOE")
layer.gemm1_weights_fp4_shuffled = Parameter( layer.w13_weight = Parameter(
gemm1_weights_fp4_shuffled, requires_grad=False gemm1_weights_fp4_shuffled, requires_grad=False
) )
layer.gemm2_weights_fp4_shuffled = Parameter( layer.w2_weight = Parameter(gemm2_weights_fp4_shuffled, requires_grad=False)
gemm2_weights_fp4_shuffled, requires_grad=False layer.w13_weight_scale = Parameter(
)
layer.gemm1_scales_fp4_shuffled = Parameter(
gemm1_scales_fp4_shuffled, requires_grad=False gemm1_scales_fp4_shuffled, requires_grad=False
) )
layer.gemm2_scales_fp4_shuffled = Parameter( layer.w2_weight_scale = Parameter(
gemm2_scales_fp4_shuffled, requires_grad=False gemm2_scales_fp4_shuffled, requires_grad=False
) )
...@@ -487,12 +486,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -487,12 +486,6 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
(layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
requires_grad=False, requires_grad=False,
) )
# Clean up weights that won't be used by TRT-LLM
del layer.w2_weight
del layer.w2_weight_scale
del layer.w13_weight
del layer.w13_weight_scale
else: else:
# swizzle weight scales # swizzle weight scales
layer.w13_weight_scale = torch.nn.Parameter( layer.w13_weight_scale = torch.nn.Parameter(
...@@ -634,17 +627,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -634,17 +627,11 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
) )
else: else:
# If no modular kernel is provided, use cutlass_moe_fp4 for TP case
# only (no EP).
from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4
assert layer.expert_map is None, (
"Expert Parallelism / expert_map "
"is currently not supported for "
"CompressedTensorsW4A4Nvfp4MoEMethod."
)
assert self.moe_quant_config is not None assert self.moe_quant_config is not None
# Cutlass moe takes in activations in BF16/Half precision
# and fp4 quantized weights loaded from the checkpoint
return cutlass_moe_fp4( return cutlass_moe_fp4(
a=x, a=x,
w1_fp4=layer.w13_weight, w1_fp4=layer.w13_weight,
...@@ -652,6 +639,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): ...@@ -652,6 +639,7 @@ class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input, apply_router_weight_on_input=layer.apply_router_weight_on_input,
# TODO(bnell): derive these from arguments # TODO(bnell): derive these from arguments
m=x.shape[0], m=x.shape[0],
...@@ -729,6 +717,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -729,6 +717,13 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
get_marlin_input_dtype(layer_name) if self.use_marlin else None get_marlin_input_dtype(layer_name) if self.use_marlin else None
) )
self.allow_deep_gemm = (
self.block_quant
and envs.VLLM_MOE_USE_DEEP_GEMM
and is_deep_gemm_supported()
and list(self.weight_block_size) == get_mk_alignment_for_contiguous_layout()
)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -1244,6 +1239,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1244,6 +1239,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
if self.disable_expert_map if self.disable_expert_map
else layer.expert_map, # ??? else layer.expert_map, # ???
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
) )
else: else:
from vllm.model_executor.layers.fused_moe.cutlass_moe import ( from vllm.model_executor.layers.fused_moe.cutlass_moe import (
...@@ -1285,6 +1281,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1285,6 +1281,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
global_num_experts=layer.global_num_experts, global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map, expert_map=layer.expert_map,
quant_config=self.moe_quant_config, quant_config=self.moe_quant_config,
allow_deep_gemm=self.allow_deep_gemm,
) )
@property @property
......
...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional ...@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.utils._python_dispatch import TorchDispatchMode
import vllm.envs as envs import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk import vllm.model_executor.layers.fused_moe.modular_kernel as mk
...@@ -253,7 +254,7 @@ class Fp8Config(QuantizationConfig): ...@@ -253,7 +254,7 @@ class Fp8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 80 return 75
@classmethod @classmethod
def get_config_filenames(cls) -> list[str]: def get_config_filenames(cls) -> list[str]:
...@@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig): ...@@ -363,6 +364,26 @@ class Fp8Config(QuantizationConfig):
return None return None
class CopyNumelCounter(TorchDispatchMode):
"""
Tracks total number of elements modified with `copy_`. Useful for keeping
track of weight loading where underlying weights can be arbitrarily
transformed (such as with `narrow`) before calling copy.
"""
def __init__(self):
super().__init__()
self.copied_numel = 0
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
out = func(*args, **kwargs)
if func == torch.ops.aten.copy_.default:
self.copied_numel += args[0].numel()
return out
class Fp8LinearMethod(LinearMethodBase): class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8. """Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and Supports loading FP8 checkpoints with static weight scale and
...@@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -469,13 +490,15 @@ class Fp8LinearMethod(LinearMethodBase):
else: else:
def patched_weight_loader(param, loaded_weight, *args, **kwargs): def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# track how many elements we have updated # track how many elements we have updated
if not hasattr(layer, "_loaded_numel"): if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0 layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call # if we have loaded all of the elements, call
# process_weights_after_loading # process_weights_after_loading
...@@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod): ...@@ -1348,13 +1371,15 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
new_extra_weight_attrs = extra_weight_attrs new_extra_weight_attrs = extra_weight_attrs
def patched_weight_loader(param, loaded_weight, *args, **kwargs): def patched_weight_loader(param, loaded_weight, *args, **kwargs):
# load the current weight chunk
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
# add a counter to track how many elements we have updated # add a counter to track how many elements we have updated
if not hasattr(layer, "_loaded_numel"): if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0 layer._loaded_numel = 0
layer._loaded_numel += loaded_weight.numel()
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
res = weight_loader(param, loaded_weight, *args, **kwargs) # type: ignore[misc]
layer._loaded_numel += copy_numel_counter.copied_numel
# if we have loaded all of the elements, call # if we have loaded all of the elements, call
# process_weights_after_loading # process_weights_after_loading
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment