Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2b47d4fa
Commit
2b47d4fa
authored
Mar 15, 2025
by
zhuwenwen
Browse files
支持fusemoe对int4的scale和zero合并读取操作
parent
ba1ff372
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
865 additions
and
94 deletions
+865
-94
vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16.json
...igs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16.json
+41
-41
vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16_120.json
...E=256,N=128,device_name=K100_AI,dtype=int4_w4a16_120.json
+182
-0
vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16.json
...figs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16.json
+164
-0
vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16_120.json
.../E=256,N=64,device_name=K100_AI,dtype=int4_w4a16_120.json
+173
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+264
-46
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+4
-3
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+33
-0
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+4
-4
No files found.
vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16.json
View file @
2b47d4fa
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"2"
:
{
"2"
:
{
...
@@ -20,7 +20,7 @@
...
@@ -20,7 +20,7 @@
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
...
@@ -29,7 +29,7 @@
...
@@ -29,7 +29,7 @@
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
...
@@ -37,17 +37,17 @@
...
@@ -37,17 +37,17 @@
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
...
@@ -56,8 +56,8 @@
...
@@ -56,8 +56,8 @@
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
...
@@ -65,52 +65,52 @@
...
@@ -65,52 +65,52 @@
"48"
:
{
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"96"
:
{
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
...
@@ -118,11 +118,11 @@
...
@@ -118,11 +118,11 @@
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"1536"
:
{
"1536"
:
{
...
@@ -130,13 +130,13 @@
...
@@ -130,13 +130,13 @@
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_warps"
:
4
,
...
@@ -144,21 +144,21 @@
...
@@ -144,21 +144,21 @@
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"3072"
:
{
"3072"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
"num_ldmatrixes"
:
0
}
}
}
}
vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=K100_AI,dtype=int4_w4a16_120.json
0 → 100644
View file @
2b47d4fa
{
"1"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
0
},
"6144"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16.json
0 → 100644
View file @
2b47d4fa
{
"1"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
4
,
"num_ldmatrixes"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=K100_AI,dtype=int4_w4a16_120.json
0 → 100644
View file @
2b47d4fa
{
"1"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"num_ldmatrixes"
:
0
},
"8192"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
1
,
"num_ldmatrixes"
:
0
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
2b47d4fa
...
@@ -23,6 +23,163 @@ from vllm.utils import direct_register_custom_op
...
@@ -23,6 +23,163 @@ from vllm.utils import direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@triton.jit
def fused_moe_kernel_awq(
        # Pointers to matrices
        a_ptr,  # [4, 7168]
        b_ptr,  # [256, 512, 3584]
        c_ptr,  # (8, 8, 512)
        b_scale_ptr,  # (256, 512, 56)
        b_zp_ptr,  # (256, 256, 56)
        topk_weights_ptr,
        sorted_token_ids_ptr,  # [0, 1, 2, 3, 4]
        expert_ids_ptr,
        num_tokens_post_padded_ptr,
        # Matrix dimensions
        N: tl.constexpr,
        K: tl.constexpr,
        EM,  # total length of the index array after padding
        num_valid_tokens,  # upper bound (exclusive) of valid token indices
        # The stride variables represent how much to increase the ptr by when
        # moving by 1 element in a particular dimension. E.g. `stride_am` is
        # how much to increase `a_ptr` by to get the element one row down
        # (A has M rows).
        stride_am,
        stride_ak,
        stride_be,
        stride_bk,  #1
        stride_bn,
        stride_cm,
        stride_cn,
        stride_bse,
        stride_bsk,  #1
        stride_bsn,
        stride_bze,
        stride_bzk,
        stride_bzn,
        block_k_diviable: tl.constexpr,
        group_size: tl.constexpr,  # 128
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        MUL_ROUTED_WEIGHT: tl.constexpr,
        top_k: tl.constexpr,
        compute_type: tl.constexpr,
        has_zp: tl.constexpr,
        use_int4_w4a16: tl.constexpr,
        use_int8_w8a16: tl.constexpr):
    """Fused MoE matmul for weight-only quantized experts with a combined
    scale/zero-point read.

    Each element loaded through `b_scale_ptr` is treated as one integer word:
    its low 16 bits are bit-cast to a float16 scale and the bits above
    position 16 are converted (by value) to a float16 zero point. The weight
    tile is dequantized as `(b - zero_point) * scale` and accumulated with
    `tl.dot` into a float32 accumulator.

    NOTE(review): `b_zp_ptr`, `stride_bze`, `stride_bzk` and `stride_bzn` are
    accepted but never read in this body; the zero points come only from the
    word loaded through `b_scale_ptr`.
    NOTE(review): the shape comments on the pointer arguments are the
    original author's examples, not constraints enforced by the kernel.
    """
    # Map the 1-D program id onto an (m, n) output tile, walking GROUP_SIZE_M
    # row-tiles at a time.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Row-tiles that start at or beyond the padded token count have no work.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)  # [block_m]
    # Entries >= num_valid_tokens are masked out of the A load, the routed
    # weight load and the final store.
    token_mask = offs_token < num_valid_tokens

    offs_bn = (pid_n * BLOCK_SIZE_N +
               tl.arange(0, BLOCK_SIZE_N)) % N  # [block_n]
    offs_k = tl.arange(0, BLOCK_SIZE_K)  # 0, 1, 2, ...... , 127 # # [block_k]
    # Half-length K offsets used to address the int4 weight bytes.
    offs_k2 = tl.arange(0, BLOCK_SIZE_K // 2)  # [block_k // 2]
    # Row of A for a sorted entry is `offs_token // top_k`.
    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                      offs_k[None, :] * stride_ak)  # [block_m, block_k]

    # One expert id is read per row-tile.
    off_experts = tl.load(expert_ids_ptr + pid_m)

    # NOTE(review): `b_ptrs` is only defined when use_int4_w4a16 or
    # use_int8_w8a16 is set; with both false the loads below have no pointer.
    if use_int4_w4a16:
        # [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
        # [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
        # b_ptrs = b_ptr + off_experts * stride_be + \
        # (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
        # Pointer tile is [block_n, block_k // 2]; it is widened and
        # transposed after the load inside the K loop.
        b_ptrs = b_ptr + off_experts * stride_be + \
            offs_bn[:, None] * stride_bn + (offs_k2[None, :]) * stride_bk
        # tl.device_print("stride_bn",stride_bsn)>1
        # tl.device_print("stride_bk",stride_bk)=1
        # Even k selects the low nibble (shift 0), odd k the high nibble
        # (shift 4).
        b_shifter = (offs_k[:, None] % 2) * 4  # 0, 4
    elif use_int8_w8a16:
        b_ptrs = b_ptr + off_experts * stride_be + \
            offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn

    # NOTE(review): `b_zp_num` and `b_zp_shifter` are assigned here but never
    # used later in this kernel.
    if not has_zp and use_int4_w4a16:
        b_zp_num = 8
    if not has_zp and use_int8_w8a16:
        b_zp_num = 128
    elif has_zp and use_int4_w4a16:
        b_zp_shifter = (offs_bn[None, :] % 2) * 4  # 0, 4

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # The K-tail mask is only built when K is not a multiple of
        # BLOCK_SIZE_K; it is applied to the scale/zero load below.
        if not block_k_diviable:
            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
            k_other = 0.0
        else:
            k_mask = None
            k_other = None

        a = tl.load(a_ptrs,
                    mask=token_mask[:, None] &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0)
        # NOTE(review): the weight load is unmasked along K, unlike the A and
        # scale loads.
        b = tl.load(b_ptrs)
        if use_int4_w4a16:
            # Duplicate every packed byte along K, transpose so K is the
            # leading axis, then shift/mask to pick each element's nibble.
            b = tl.interleave(b, b)
            b = b.trans()
            b = (b >> b_shifter) & 0xF

        # One combined scale/zero word per (n, k // group_size) position.
        b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + \
            offs_bn[None, :] * stride_bsk + \
            ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsn
        qzeros_scles = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
        # Low 16 bits: raw float16 bit pattern of the scale.
        scales_int16 = tl.cast(qzeros_scles, tl.uint16)
        b_scale = tl.cast(scales_int16, tl.float16, bitcast=True)
        # tl.device_print("b_scale dequant",b_scale)
        # Bits above position 16: integer zero point, converted by value.
        mid = qzeros_scles >> 16
        # b_zp = tl.cast(mid,tl.float16,bitcast=False)
        b_zp = tl.cast(mid, tl.float16)
        # b_zp = tl.cast(zeros_int16,tl.float16,bitcast=False)
        # tl.device_print("bzp",b_zp)
        # We accumulate along the K dimension.
        b = ((b - b_zp) * b_scale).to(tl.float16)
        accumulator = tl.dot(a, b, acc=accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        if use_int4_w4a16:
            # The int4 pointer tile spans only BLOCK_SIZE_K // 2 positions
            # along K.
            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
        else:
            b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        moe_weight = tl.load(topk_weights_ptr + offs_token,
                             mask=token_mask,
                             other=0)
        accumulator = accumulator * moe_weight[:, None]

    accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
        None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel_gptq_awq
(
def
fused_moe_kernel_gptq_awq
(
# Pointers to matrices
# Pointers to matrices
...
@@ -562,7 +719,7 @@ def moe_align_block_size_triton(
...
@@ -562,7 +719,7 @@ def moe_align_block_size_triton(
def
moe_align_block_size
(
def
moe_align_block_size
(
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
topk_ids
:
torch
.
Tensor
,
block_size
:
int
,
num_experts
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
num_experts
:
int
,
num_token
:
Optional
[
int
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Aligns the token distribution across experts to be compatible with block
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
size for matrix multiplication.
...
@@ -600,11 +757,18 @@ def moe_align_block_size(
...
@@ -600,11 +757,18 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
by block_size for proper block matrix operations.
"""
"""
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
if
num_token
:
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
if
num_token
<
block_size
:
dtype
=
torch
.
int32
,
max_num_tokens_padded
=
min
(
topk_ids
.
numel
()
*
block_size
,
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
))
device
=
topk_ids
.
device
)
else
:
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
full
((
max_num_tokens_padded
,),
fill_value
=
topk_ids
.
numel
(),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
else
:
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -754,7 +918,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -754,7 +918,9 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
B_scale
is
None
assert
B_scale
is
None
EM
=
sorted_token_ids
.
shape
[
0
]
EM
=
sorted_token_ids
.
shape
[
0
]
if
A
.
shape
[
0
]
<
config
[
"BLOCK_SIZE_M"
]:
if
use_int4_w4a16
:
EM
=
sorted_token_ids
.
shape
[
0
]
elif
A
.
shape
[
0
]
<
config
[
"BLOCK_SIZE_M"
]:
# optimize for small batch_size.
# optimize for small batch_size.
# We assume that top_ids of each token is unique, so
# We assume that top_ids of each token is unique, so
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
# so num_valid_experts <= batch_size <= BLOCK_SIZE_M,
...
@@ -769,43 +935,82 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
...
@@ -769,43 +935,82 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
assert
B_scale
is
not
None
and
B_scale
.
ndim
==
3
assert
B_scale
is
not
None
and
B_scale
.
ndim
==
3
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
assert
B_zp
is
None
or
B_zp
.
ndim
==
3
fused_moe_kernel_gptq_awq
[
grid
](
if
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
:
A
,
fused_moe_kernel_awq
[
grid
](
B
,
A
,
C
,
B
,
B_scale
,
C
,
B_zp
,
B_scale
,
topk_weights
,
B_zp
,
sorted_token_ids
,
topk_weights
,
expert_ids
,
sorted_token_ids
,
num_tokens_post_padded
,
expert_ids
,
B
.
shape
[
1
],
num_tokens_post_padded
,
A
.
shape
[
1
],
B
.
shape
[
1
],
EM
,
A
.
shape
[
1
],
topk_ids
.
numel
(),
EM
,
A
.
stride
(
0
),
topk_ids
.
numel
(),
A
.
stride
(
1
),
A
.
stride
(
0
),
B
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
2
),
B
.
stride
(
0
),
B
.
stride
(
1
),
B
.
stride
(
2
),
C
.
stride
(
1
),
B
.
stride
(
1
),
C
.
stride
(
2
),
C
.
stride
(
1
),
B_scale
.
stride
(
0
),
C
.
stride
(
2
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
1
),
B_scale
.
stride
(
2
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_scale
.
stride
(
1
),
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
group_size
=
block_shape
[
1
],
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
group_size
=
block_shape
[
1
],
top_k
=
top_k
,
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
compute_type
=
compute_type
,
top_k
=
top_k
,
has_zp
=
B_zp
is
not
None
,
compute_type
=
compute_type
,
use_int4_w4a16
=
use_int4_w4a16
,
has_zp
=
B_zp
is
not
None
,
use_int8_w8a16
=
use_int8_w8a16
,
use_int4_w4a16
=
use_int4_w4a16
,
**
config
,
use_int8_w8a16
=
use_int8_w8a16
,
)
**
config
,
)
else
:
fused_moe_kernel_gptq_awq
[
grid
](
A
,
B
,
C
,
B_scale
,
B_zp
,
topk_weights
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
B
.
shape
[
1
],
A
.
shape
[
1
],
EM
,
topk_ids
.
numel
(),
A
.
stride
(
0
),
A
.
stride
(
1
),
B
.
stride
(
0
),
B
.
stride
(
2
),
B
.
stride
(
1
),
C
.
stride
(
1
),
C
.
stride
(
2
),
B_scale
.
stride
(
0
),
B_scale
.
stride
(
2
),
B_scale
.
stride
(
1
),
B_zp
.
stride
(
0
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
2
)
if
B_zp
is
not
None
else
0
,
B_zp
.
stride
(
1
)
if
B_zp
is
not
None
else
0
,
block_k_diviable
=
A
.
shape
[
1
]
%
config
[
"BLOCK_SIZE_K"
]
==
0
,
group_size
=
block_shape
[
1
],
MUL_ROUTED_WEIGHT
=
mul_routed_weight
,
top_k
=
top_k
,
compute_type
=
compute_type
,
has_zp
=
B_zp
is
not
None
,
use_int4_w4a16
=
use_int4_w4a16
,
use_int8_w8a16
=
use_int8_w8a16
,
**
config
,
)
else
:
else
:
fused_moe_kernel
[
grid
](
fused_moe_kernel
[
grid
](
...
@@ -892,6 +1097,15 @@ def get_moe_configs(
...
@@ -892,6 +1097,15 @@ def get_moe_configs(
config_file_path
=
os
.
path
.
join
(
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
:
config_file_path_120
=
config_file_path
.
replace
(
".json"
,
"_120.json"
)
if
os
.
path
.
exists
(
config_file_path_120
):
with
open
(
config_file_path_120
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
config_file_path_120
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
if
os
.
path
.
exists
(
config_file_path
):
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for MoE layer."
,
logger
.
info
(
"Using configuration from %s for MoE layer."
,
...
@@ -1379,8 +1593,12 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1379,8 +1593,12 @@ def fused_experts_impl(hidden_states: torch.Tensor,
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
if
moe_ep_size
==
1
:
if
moe_ep_size
==
1
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
if
use_int4_w4a16
:
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
,
curr_hidden_states
.
shape
[
0
]))
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
else
:
else
:
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_ep_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
,
moe_ep_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
,
...
...
vllm/model_executor/model_loader/utils.py
View file @
2b47d4fa
...
@@ -105,10 +105,11 @@ def get_model_architecture(
...
@@ -105,10 +105,11 @@ def get_model_architecture(
os
.
environ
[
'GEMM_PAD'
]
=
'0'
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
os
.
environ
[
'FA_PAD'
]
=
'0'
# awq相关配置
try
:
try
:
if
os
.
getenv
(
'AWQ_
PAD'
)
==
'0'
or
((
torch
.
cuda
.
isCurrentDeviceEco
(
torch
.
cuda
.
current_device
()))
and
os
.
getenv
(
'AWQ_PAD
'
)
==
None
)
:
if
os
.
getenv
(
'AWQ_
MOE_SZ
'
)
==
None
:
os
.
environ
[
'AWQ_
PAD
'
]
=
'
0
'
os
.
environ
[
'AWQ_
MOE_SZ
'
]
=
'
1
'
else
:
if
os
.
getenv
(
'AWQ_PAD'
)
==
None
and
(
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
==
120
)
:
os
.
environ
[
'AWQ_PAD'
]
=
'1'
os
.
environ
[
'AWQ_PAD'
]
=
'1'
except
Exception
as
e
:
except
Exception
as
e
:
if
os
.
getenv
(
'AWQ_PAD'
)
!=
'0'
:
if
os
.
getenv
(
'AWQ_PAD'
)
!=
'0'
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
2b47d4fa
...
@@ -669,6 +669,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -669,6 +669,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LLAMA_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
os
.
environ
[
'LM_NN'
]
=
'0'
self
.
use_w4a16_moe_sz
=
os
.
environ
.
get
(
'AWQ_MOE_SZ'
)
==
'1'
self
.
config
=
config
self
.
config
=
config
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
parallel_config
=
vllm_config
.
parallel_config
...
@@ -734,6 +735,26 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -734,6 +735,26 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
dtype
=
dtype
,
dtype
=
dtype
,
device
=
device
),
device
=
device
),
})
})
def
restore_qzeros_tensor
(
self
,
qzeros
,
qscales
):
low_bits
=
qzeros
&
0x0F
high_bits
=
qzeros
>>
4
zeors_tensor
=
torch
.
stack
([
low_bits
,
high_bits
],
dim
=
2
).
view
(
qzeros
.
shape
[
0
],
-
1
,
qzeros
.
shape
[
-
1
])
zeors_int16
=
zeors_tensor
.
to
(
torch
.
int16
)
assert
zeors_int16
.
shape
==
qscales
.
shape
uint16_tensor1
=
zeors_int16
.
view
(
torch
.
uint16
)
uint16_tensor2
=
qscales
.
view
(
torch
.
uint16
)
uint32_tensor1
=
uint16_tensor1
.
to
(
torch
.
int32
)
<<
16
uint32_tensor2
=
uint16_tensor2
.
to
(
torch
.
int32
)
result_tensor
=
uint32_tensor1
+
uint32_tensor2
result_tensor
=
result_tensor
.
view
(
torch
.
uint32
)
result_tensor
=
result_tensor
.
transpose
(
1
,
2
).
contiguous
()
return
result_tensor
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
def
load_weights
(
self
,
weights
:
Iterable
[
Tuple
[
str
,
torch
.
Tensor
]])
->
Set
[
str
]:
torch
.
Tensor
]])
->
Set
[
str
]:
...
@@ -877,6 +898,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -877,6 +898,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
"mlp.shared_experts.down_proj.qweight"
"mlp.shared_experts.down_proj.qweight"
]
]
combined_words
=
"|"
.
join
(
lay_key_words
)
combined_words
=
"|"
.
join
(
lay_key_words
)
# moe_gather_sz
moe_key_words
=
[
"mlp.experts.w13_qweight"
,
"mlp.experts.w2_qweight"
]
moe_combined_words
=
"|"
.
join
(
moe_key_words
)
for
layername
in
loaded_params
:
for
layername
in
loaded_params
:
weight
=
params_dict
[
layername
]
weight
=
params_dict
[
layername
]
...
@@ -910,6 +934,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
...
@@ -910,6 +934,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
zeros_and_scalse
.
data
=
torch
.
cat
((
zeros_and_scalse
.
data
,
zeros_and_scalse_pad
),
dim
=
1
).
contiguous
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
if
self
.
use_w4a16_moe_sz
:
matches_moe
=
re
.
findall
(
moe_combined_words
,
layername
)
# sz.shape == s.shape.T
if
matches_moe
:
qzeros
=
params_dict
[
layername
.
replace
(
"qweight"
,
"qzeros"
)]
scales
=
params_dict
[
layername
.
replace
(
"qweight"
,
"scales"
)]
sz_tensor
=
self
.
restore_qzeros_tensor
(
qzeros
,
scales
)
scales
.
data
=
sz_tensor
return
loaded_params
return
loaded_params
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
2b47d4fa
...
@@ -86,10 +86,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
...
@@ -86,10 +86,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker":
draft_worker_config
=
copy
.
deepcopy
(
vllm_config
)
draft_worker_config
=
copy
.
deepcopy
(
vllm_config
)
draft_worker_config
.
model_config
=
speculative_config
.
draft_model_config
draft_worker_config
.
model_config
=
speculative_config
.
draft_model_config
draft_worker_config
.
quant_config
=
VllmConfig
.
_get_quantization_config
(
#
draft_worker_config.quant_config = VllmConfig._get_quantization_config(
draft_worker_config
.
model_config
,
#
draft_worker_config.model_config,
vllm_config
.
load_config
,
#
vllm_config.load_config,
)
#
)
speculative_config
.
draft_parallel_config
.
worker_cls
=
\
speculative_config
.
draft_parallel_config
.
worker_cls
=
\
draft_worker_config
.
parallel_config
.
sd_worker_cls
draft_worker_config
.
parallel_config
.
sd_worker_cls
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
draft_worker_config
.
parallel_config
=
speculative_config
.
draft_parallel_config
# noqa
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment