Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
705f6a35
Commit
705f6a35
authored
Jul 16, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.2' into v0.5.2-dtk24.04.1
parents
af837396
4cf256ae
Changes
439
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1442 additions
and
583 deletions
+1442
-583
vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
+136
-46
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
+122
-50
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
+124
-52
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+136
-69
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+260
-0
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+70
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+63
-121
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+18
-8
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+2
-1
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+11
-1
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+140
-33
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+19
-7
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+13
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+8
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+144
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
...d_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
+0
-85
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+87
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+88
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
...d_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+0
-103
No files found.
Too many changes to show.
To preserve performance only
439 of 439+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
View file @
705f6a35
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
64
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
6
4
,
"BLOCK_SIZE_N"
:
1
6
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
1
6
,
"BLOCK_SIZE_M"
:
6
4
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"48"
:
{
"48"
:
{
"BLOCK_SIZE_M"
:
1
28
,
"BLOCK_SIZE_M"
:
1
6
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"96"
:
{
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
32
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"1536"
:
{
"1536"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"3072"
:
{
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
}
}
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
View file @
705f6a35
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
1
28
,
"BLOCK_SIZE_N"
:
1
6
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"48"
:
{
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"96"
:
{
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
32
,
"kpack"
:
2
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"1536"
:
{
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"3072"
:
{
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
}
}
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
View file @
705f6a35
{
{
"1"
:
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
6
4
,
"BLOCK_SIZE_N"
:
1
6
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"2"
:
{
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
2
56
,
"BLOCK_SIZE_K"
:
3
2
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"4"
:
{
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"8"
:
{
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"16"
:
{
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
6
4
,
"BLOCK_SIZE_N"
:
1
6
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"24"
:
{
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
1
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"32"
:
{
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"48"
:
{
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
2
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"64"
:
{
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
1
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"96"
:
{
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
4
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"128"
:
{
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"256"
:
{
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
4
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
32
,
"kpack"
:
2
},
},
"512"
:
{
"512"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
8
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"1024"
:
{
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"1536"
:
{
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"2048"
:
{
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
},
"3072"
:
{
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
},
"4096"
:
{
"4096"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
"num_warps"
:
8
,
"num_stages"
:
0
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
}
}
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
705f6a35
...
@@ -8,6 +8,7 @@ import torch
...
@@ -8,6 +8,7 @@ import torch
import
triton
import
triton
import
triton.language
as
tl
import
triton.language
as
tl
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -331,6 +332,31 @@ def get_default_config(
...
@@ -331,6 +332,31 @@ def get_default_config(
return
config
return
config
def
try_get_optimal_moe_config
(
w1_shape
:
Tuple
[
int
,
...],
w2_shape
:
Tuple
[
int
,
...],
top_k
:
int
,
dtype
:
Optional
[
str
],
M
:
int
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
):
if
override_config
:
config
=
override_config
else
:
# First try to load optimal config from the file
E
,
_
,
N
=
w2_shape
configs
=
get_moe_configs
(
E
,
N
,
dtype
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1_shape
[
2
],
top_k
,
dtype
)
return
config
def
fused_topk
(
def
fused_topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
...
@@ -368,14 +394,16 @@ def fused_topk(
...
@@ -368,14 +394,16 @@ def fused_topk(
# This is used by the Deepseek-V2 model
# This is used by the Deepseek-V2 model
def
grouped_topk
(
def
grouped_topk
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
topk
:
int
,
renormalize
:
bool
,
renormalize
:
bool
,
num_expert_group
:
int
=
0
,
num_expert_group
:
int
=
0
,
topk_group
:
int
=
0
):
topk_group
:
int
=
0
,
):
assert
hidden_states
.
shape
[
0
]
==
gating_output
.
shape
[
0
],
(
"Number of tokens mismatch"
)
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
scores
=
torch
.
softmax
(
gating_output
,
dim
=-
1
)
num_token
=
scores
.
shape
[
0
]
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
...
@@ -420,25 +448,23 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -420,25 +448,23 @@ def fused_experts(hidden_states: torch.Tensor,
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
]
M
,
_
=
hidden_states
.
shape
num_tokens
,
_
=
hidden_states
.
shape
E
,
N
,
_
=
w1
.
shape
E
,
N
,
_
=
w1
.
shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
get_config_func
=
functools
.
partial
(
try_get_optimal_moe_config
,
w1
.
shape
,
w2
.
shape
,
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
,
override_config
=
override_config
,
)
if
override_config
:
config
=
get_config_func
(
M
)
config
=
override_config
else
:
# First try to load optimal config from the file
configs
=
get_moe_configs
(
E
,
w2
.
shape
[
2
],
"float8"
if
use_fp8
else
None
)
if
configs
:
# If an optimal configuration map has been found, look up the
# optimal config
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Else use the default config
config
=
get_default_config
(
M
,
E
,
N
,
w1
.
shape
[
2
],
topk_ids
.
shape
[
1
],
"float8"
if
use_fp8
else
None
)
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
intermediate_cache1
=
torch
.
empty
((
M
,
topk_ids
.
shape
[
1
],
N
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
...
@@ -450,51 +476,78 @@ def fused_experts(hidden_states: torch.Tensor,
...
@@ -450,51 +476,78 @@ def fused_experts(hidden_states: torch.Tensor,
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
dtype
=
hidden_states
.
dtype
)
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
compute_type
=
(
tl
.
bfloat16
compute_type
=
(
tl
.
bfloat16
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
intermediate_cache1
,
a1_scale
,
w1_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
w2_scale
,
topk_weights
,
topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
if
inplace
:
if
inplace
:
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
out_hidden_states
=
hidden_states
dim
=
1
,
else
:
out
=
hidden_states
)
out_hidden_states
=
torch
.
empty_like
(
hidden_states
)
return
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
)
for
chunk
in
range
((
num_tokens
//
CHUNK_SIZE
)
+
1
):
begin_chunk_idx
,
end_chunk_idx
=
(
chunk
*
CHUNK_SIZE
,
min
((
chunk
+
1
)
*
CHUNK_SIZE
,
num_tokens
))
curr_hidden_states
=
hidden_states
[
begin_chunk_idx
:
end_chunk_idx
]
tokens_in_chunk
,
_
=
curr_hidden_states
.
shape
if
tokens_in_chunk
==
0
:
break
if
tokens_in_chunk
<
CHUNK_SIZE
and
chunk
>
0
:
# Adjust the intermediate cache size and config for the last
# chunk. Note that in most cases we only have one chunk
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
config
=
get_config_func
(
tokens_in_chunk
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
(
moe_align_block_size
(
curr_topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
))
invoke_fused_moe_kernel
(
curr_hidden_states
,
w1
,
intermediate_cache1
,
a1_scale
,
w1_scale
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
False
,
topk_ids
.
shape
[
1
],
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
intermediate_cache3
,
a2_scale
,
w2_scale
,
curr_topk_weights
,
curr_topk_ids
,
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
,
True
,
1
,
config
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
torch
.
sum
(
intermediate_cache3
.
view
(
*
intermediate_cache3
.
shape
),
dim
=
1
,
out
=
out_hidden_states
[
begin_chunk_idx
:
end_chunk_idx
])
return
out_hidden_states
def
fused_moe
(
def
fused_moe
(
...
@@ -506,6 +559,9 @@ def fused_moe(
...
@@ -506,6 +559,9 @@ def fused_moe(
renormalize
:
bool
,
renormalize
:
bool
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
override_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
use_fp8
:
bool
=
False
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -528,6 +584,10 @@ def fused_moe(
...
@@ -528,6 +584,10 @@ def fused_moe(
Defaults to False.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
...
@@ -541,8 +601,15 @@ def fused_moe(
...
@@ -541,8 +601,15 @@ def fused_moe(
# Check constraints.
# Check constraints.
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
assert
gating_output
.
shape
[
1
]
==
w1
.
shape
[
0
],
"Number of experts mismatch"
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
if
use_grouped_topk
:
renormalize
)
assert
num_expert_group
is
not
None
and
topk_group
is
not
None
topk_weights
,
topk_ids
=
grouped_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
,
num_expert_group
,
topk_group
)
else
:
topk_weights
,
topk_ids
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
renormalize
)
return
fused_experts
(
hidden_states
,
return
fused_experts
(
hidden_states
,
w1
,
w1
,
w2
,
w2
,
...
@@ -554,4 +621,4 @@ def fused_moe(
...
@@ -554,4 +621,4 @@ def fused_moe(
w1_scale
=
w1_scale
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
)
a2_scale
=
a2_scale
)
\ No newline at end of file
vllm/model_executor/layers/fused_moe/layer.py
0 → 100644
View file @
705f6a35
from
abc
import
abstractmethod
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_moe
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
logger
=
init_logger
(
__name__
)
class FusedMoEMethodBase(QuantizeMethodBase):
    """Abstract interface for the weight/compute strategy of a fused MoE layer.

    Concrete subclasses must implement both ``create_weights`` and ``apply``;
    each raises ``NotImplementedError`` here.
    """

    @abstractmethod
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the expert parameters for ``layer``.

        Args:
            layer: Module the parameters belong to.
            num_experts: Number of experts.
            hidden_size: Hidden size of the model.
            intermediate_size: Intermediate size of each expert.
            params_dtype: Data type of the created parameters.
            **extra_weight_attrs: Extra attributes for the created weights.
        """
        raise NotImplementedError

    @abstractmethod
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Run the MoE computation for ``x`` using the weights on ``layer``.

        Args:
            layer: Module holding the expert weights.
            x: Input hidden states.
            router_logits: Router output used to select experts.
            top_k: Number of experts selected for each token.
            renormalize: Renormalization flag for the routing weights.
            use_grouped_topk: Whether grouped top-k routing is requested.
            num_expert_group: Grouped top-k setting (optional).
            topk_group: Grouped top-k setting (optional).

        Returns:
            The output tensor of the MoE computation.
        """
        raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    """MoE method without quantization."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the expert weights and register them on ``layer``.

        Two non-trainable, uninitialized parameters are registered, in this
        order: ``w13_weight`` and ``w2_weight``. Each one receives
        ``extra_weight_attrs`` through ``set_weight_attrs``.
        """
        expert_weight_shapes = (
            # Fused gate_up_proj (column parallel)
            ("w13_weight", (num_experts, 2 * intermediate_size, hidden_size)),
            # down_proj (row parallel)
            ("w2_weight", (num_experts, hidden_size, intermediate_size)),
        )
        for attr_name, shape in expert_weight_shapes:
            weight = torch.nn.Parameter(
                torch.empty(*shape, dtype=params_dtype),
                requires_grad=False,
            )
            layer.register_parameter(attr_name, weight)
            set_weight_attrs(weight, extra_weight_attrs)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
    ) -> torch.Tensor:
        """Delegate to ``fused_moe`` with the weights registered on ``layer``.

        ``fused_moe`` is always called with ``inplace=True``; the routing
        options are forwarded unchanged.
        """
        routing_options = dict(
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
        return fused_moe(x, layer.w13_weight, layer.w2_weight, router_logits,
                         top_k, **routing_options)
class FusedMoE(torch.nn.Module):
    """FusedMoE layer for MoE models.

    This layer contains both MergedColumnParallel weights (gate_up_proj /
    w13) and RowParallelLinear weights (down_proj / w2).

    Note: Mixtral uses w1, w2, and w3 for gate, down, and up_proj. We
    copy that naming convention here and handle any remapping in the
    load_weights function in each model implementation.

    Args:
        num_experts: Number of experts in the model
        top_k: Number of experts selected for each token
        hidden_size: Input hidden state size of the transformer
        intermediate_size: Intermediate size of the experts
        params_dtype: Data type for the parameters. Defaults to
            ``torch.get_default_dtype()`` when None.
        reduce_results: Whether to all_reduce the output of the layer
        renormalize: Whether to renormalize the logits in the fused_moe kernel
        use_grouped_topk: Whether to request grouped top-k routing; requires
            ``num_expert_group`` and ``topk_group`` to be set.
        num_expert_group: Forwarded to the quant method's ``apply``.
        topk_group: Forwarded to the quant method's ``apply``.
        quant_config: Quantization config. When None, the unquantized
            method is used.
        tp_size: Tensor-parallel size. Defaults to
            ``get_tensor_model_parallel_world_size()`` when None.

    Raises:
        ValueError: If ``intermediate_size`` is not divisible by the
            tensor-parallel size.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        reduce_results: bool = False,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
        tp_size: Optional[int] = None,
    ):
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.tp_size = (tp_size if tp_size is not None else
                        get_tensor_model_parallel_world_size())
        # Floor division would silently drop the trailing rows of every
        # expert (weight_loader only copies tp_size * shard_size rows), so
        # reject sizes that cannot be split evenly.
        if intermediate_size % self.tp_size != 0:
            raise ValueError(
                f"intermediate_size ({intermediate_size}) must be divisible "
                f"by the tensor parallel size ({self.tp_size}).")
        self.top_k = top_k
        self.num_experts = num_experts
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                UnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self)
        assert self.quant_method is not None

        # The quant method allocates the parameters on this module and tags
        # them with this layer's weight_loader.
        self.quant_method.create_weights(
            layer=self,
            num_experts=num_experts,
            hidden_size=hidden_size,
            intermediate_size=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
            weight_loader=self.weight_loader)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, weight_name: str,
                      shard_id: int, expert_id: int):
        """Copy one checkpoint tensor into the slot of ``param`` it maps to.

        Args:
            param: Destination parameter (indexed by expert on dim 0).
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Checkpoint name; selects the input-scale,
                weight-scale or plain-weight path by substring.
            shard_id: 0 for gate_proj / w1, 1 for down_proj / w2,
                2 for up_proj / w3.
            expert_id: Index of the expert the tensor belongs to.

        Raises:
            ValueError: If a non-default input scale disagrees with the one
                already stored, or if ``shard_id`` is not 0, 1 or 2 on the
                weight-scale or weight paths.
        """
        param_data = param.data

        # Input scales can be loaded directly and should be equal.
        if "input_scale" in weight_name:
            # A stored value of 1 is treated as "not loaded yet".
            if param_data[expert_id] != 1 and (param_data[expert_id] -
                                               loaded_weight).abs() > 1e-5:
                raise ValueError(
                    "input_scales of w1 and w3 of a layer "
                    f"must be equal. But got {param_data[expert_id]} "
                    f"vs. {loaded_weight}")
            param_data[expert_id] = loaded_weight
        # Weight scales
        elif "weight_scale" in weight_name:
            # If we are in merged column case (gate_up_proj)
            #   shard_id 0 == gate_proj / w1
            #   shard_id 2 == up_proj / w3
            if shard_id in (0, 2):
                # We have to keep the weight scales of w1 and w3 because
                # we need to re-quantize w1/w3 weights after weight loading.
                idx = 0 if shard_id == 0 else 1
                param_data[expert_id][idx] = loaded_weight
            # If we are in the row parallel case (down_proj)
            #   shard_id 1 == down_proj / w2
            elif shard_id == 1:
                param_data[expert_id] = loaded_weight
            else:
                # Mirror the weight path instead of silently treating an
                # unknown shard as down_proj.
                raise ValueError(
                    f"Shard id must be in [0,1,2] but got {shard_id}")
        # Weights
        else:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = self.intermediate_size_per_partition
            shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

            # w1, gate_proj case: Load into first shard of w13.
            if shard_id == 0:
                param_data[expert_id,
                           0:shard_size, :] = loaded_weight[shard, :]
            # w3, up_proj case: Load into second shard of w13.
            elif shard_id == 2:
                param_data[expert_id, shard_size:2 *
                           shard_size, :] = loaded_weight[shard, :]
            # w2, down_proj case: Load into only shard of w2.
            elif shard_id == 1:
                param_data[expert_id, :, :] = loaded_weight[:, shard]
            else:
                raise ValueError(
                    f"Shard id must be in [0,1,2] but got {shard_id}")

    def forward(self, hidden_states: torch.Tensor,
                router_logits: torch.Tensor):
        """Apply the quant method and optionally all-reduce the result.

        The all-reduce only runs when ``reduce_results`` is set and
        ``tp_size`` is greater than 1.
        """
        assert self.quant_method is not None

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            num_expert_group=self.num_expert_group,
            topk_group=self.topk_group)

        if self.reduce_results and self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states

    @classmethod
    def make_expert_params_mapping(
            cls, ckpt_gate_proj_name: str, ckpt_down_proj_name: str,
            ckpt_up_proj_name: str,
            num_experts: int) -> List[Tuple[str, str, int, int]]:
        """Build the checkpoint-name to fused-parameter mapping.

        Returns:
            A list of ``(param_name, weight_name, expert_id, shard_id)``
            tuples: first all weight scales, then all weights, then all
            input scales. Within each group entries are ordered by expert
            and then by shard (0 = gate, 1 = down, 2 = up).
        """
        gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name]
        gate_down_up = [
            ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name
        ]

        # (fused gate/up param, down param, checkpoint suffix)
        param_kinds = (
            # Weight scales for the experts
            ("experts.w13_scale", "experts.w2_scale", "weight_scale"),
            # Weights for the experts
            ("experts.w13_weight", "experts.w2_weight", "weight"),
            # Input (activation) scales for the experts
            ("experts.a13_scale", "experts.a2_scale", "input_scale"),
        )

        return [
            # (param_name, weight_name, expert_id, shard_id)
            (w13_name if weight_name in gate_up else w2_name,
             f"experts.{expert_id}.{weight_name}.{suffix}", expert_id,
             shard_id)
            for w13_name, w2_name, suffix in param_kinds
            for expert_id in range(num_experts)
            for shard_id, weight_name in enumerate(gate_down_up)
        ]
vllm/model_executor/layers/layernorm.py
View file @
705f6a35
...
@@ -67,7 +67,77 @@ class RMSNorm(CustomOp):
...
@@ -67,7 +67,77 @@ class RMSNorm(CustomOp):
)
)
return
out
return
out
def forward_xpu(
    self,
    x: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Run RMS norm through the IPEX ops.

    Without a residual, a fresh tensor shaped like ``x`` is filled by
    ``rms_norm`` and returned. With a residual, ``fused_add_rms_norm`` is
    called and ``(x, residual)`` is returned.
    """
    # Imported lazily so the IPEX ops are only needed on this path.
    from vllm._ipex_ops import ipex_ops as ipex

    if residual is None:
        normed = torch.empty_like(x)
        ipex.rms_norm(
            normed,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return normed

    # NOTE(review): the same x and residual objects are returned, so the op
    # presumably updates them in place -- confirm against the kernel.
    ipex.fused_add_rms_norm(
        x,
        residual,
        self.weight.data,
        self.variance_epsilon,
    )
    return x, residual
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
s
+=
f
", eps=
{
self
.
variance_epsilon
}
"
s
+=
f
", eps=
{
self
.
variance_epsilon
}
"
return
s
return
s
class GemmaRMSNorm(CustomOp):
    """RMS normalization for Gemma.

    Two differences from the above RMSNorm:
        1. x * (1 + w) instead of x * w.
        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
    """

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # The weight starts at zero, so the initial scale (1 + w) is 1.
        self.weight = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = eps

    def forward_native(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward().

        Returns the normalized tensor in the dtype of the incoming ``x``.
        When ``residual`` is given, returns ``(normalized, x + residual)``.
        """
        input_dtype = x.dtype
        with_residual = residual is not None
        if with_residual:
            x = x + residual
            residual = x

        # The statistics and the scaling are computed in float32.
        hidden = x.float()
        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + self.variance_epsilon)
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        hidden = hidden * (1.0 + self.weight.float())
        hidden = hidden.to(input_dtype)

        if with_residual:
            return hidden, residual
        return hidden

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path; currently just calls ``forward_native``."""
        # TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
        return self.forward_native(x, residual)
vllm/model_executor/layers/linear.py
View file @
705f6a35
...
@@ -42,6 +42,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
...
@@ -42,6 +42,29 @@ def adjust_bitsandbytes_shard(param: Parameter,
return
quantized_size
,
quantized_offset
return
quantized_size
,
quantized_offset
def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
    """Select the slot of a fused scale array that one on-disk scale fills.

    For fused modules (QKV and MLP) we have an array of length
    N that holds 1 scale for each "logical" matrix. So the param
    is an array of length N. The loaded_weight corresponds to
    one of the shards on disk. Here, we slice the param based on
    the shard_id for loading.

    Args:
        param: Fused scale array; indexed as ``param[shard_id]``.
        loaded_weight: Scale read from disk, either shapeless or with a
            leading dimension of size 1.
        shard_id: Integer index, or one of "q", "k", "v" (mapped to 0, 1, 2).

    Returns:
        A ``(param[shard_id], loaded_weight)`` tuple, with the leading
        dimension of ``loaded_weight`` removed when it had one.

    Raises:
        ValueError: If ``shard_id`` is neither an int nor one of
            "q", "k", "v".
    """
    qkv_idxs = {"q": 0, "k": 1, "v": 2}

    if isinstance(shard_id, str):
        # Report unknown names with the same error as any other bad shard
        # id instead of leaking a bare KeyError from the lookup.
        if shard_id not in qkv_idxs:
            raise ValueError(f"Unknown Shard Id {shard_id}")
        shard_id = qkv_idxs[shard_id]
    elif not isinstance(shard_id, int):
        raise ValueError(f"Unknown Shard Id {shard_id}")

    # AutoFP8 scales do not have a shape
    # compressed-tensors scales do have a shape
    if len(loaded_weight.shape) != 0:
        assert loaded_weight.shape[0] == 1
        loaded_weight = loaded_weight[0]

    return param[shard_id], loaded_weight
def
pad_weight
(
weight
:
torch
.
Tensor
,
num_pad
:
int
,
pad_dim
:
int
=
0
):
def
pad_weight
(
weight
:
torch
.
Tensor
,
num_pad
:
int
,
pad_dim
:
int
=
0
):
if
weight
.
dim
()
==
1
:
if
weight
.
dim
()
==
1
:
padding
=
torch
.
zeros
(
num_pad
,
dtype
=
weight
.
dtype
,
device
=
weight
.
device
)
padding
=
torch
.
zeros
(
num_pad
,
dtype
=
weight
.
dtype
,
device
=
weight
.
device
)
...
@@ -105,17 +128,12 @@ class LinearMethodBase(QuantizeMethodBase):
...
@@ -105,17 +128,12 @@ class LinearMethodBase(QuantizeMethodBase):
class
UnquantizedLinearMethod
(
LinearMethodBase
):
class
UnquantizedLinearMethod
(
LinearMethodBase
):
"""Linear method without quantization.
"""Linear method without quantization."""
Args:
def
__init__
(
self
):
separate_bias_add: If true, add bias separately after matrix
multiplication.
"""
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
...
@@ -133,23 +151,18 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -133,23 +151,18 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
if
self
.
separate_bias_add
:
if
bias
is
not
None
:
return
F
.
linear
(
x
,
weight
)
+
bias
return
F
.
linear
(
x
,
weight
)
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
weight
=
weight
.
reshape
(
weight
.
shape
[
1
],
-
1
)
layer
.
weight
=
layer
.
weight
.
reshape
(
layer
.
weight
.
shape
[
1
],
-
1
)
if
bias
is
not
None
:
if
bias
is
not
None
:
return
torch
.
matmul
(
x
,
weight
)
+
bias
return
torch
.
matmul
(
x
,
layer
.
weight
)
+
bias
else
:
else
:
if
gemm_bank_conf
(
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
if
gemm_bank_conf
(
layer
.
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
return
torch
.
matmul
(
x
,
weight
[:,:
-
32
])
return
torch
.
matmul
(
x
,
layer
.
weight
[:,:
-
32
])
else
:
else
:
return
torch
.
matmul
(
x
,
weight
)
return
torch
.
matmul
(
x
,
layer
.
weight
)
else
:
else
:
return
F
.
linear
(
x
,
weight
,
bias
)
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
class
LinearBase
(
torch
.
nn
.
Module
):
class
LinearBase
(
torch
.
nn
.
Module
):
...
@@ -311,10 +324,6 @@ class ColumnParallelLinear(LinearBase):
...
@@ -311,10 +324,6 @@ class ColumnParallelLinear(LinearBase):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -323,11 +332,11 @@ class ColumnParallelLinear(LinearBase):
...
@@ -323,11 +332,11 @@ class ColumnParallelLinear(LinearBase):
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
# Special case for loading scales off disk, which often do not
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
# have a shape (such as in the case of AutoFP8).
loaded_weight
,
if
len
(
loaded_weight
.
shape
)
==
0
:
shard_id
=
0
)
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
...
@@ -409,38 +418,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -409,38 +418,21 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for per-tensor scale to load scalar into fused array.
param_shard_splitter
=
getattr
(
param
,
"shard_splitter"
,
None
)
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
if
output_dim
is
not
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if
loaded_shard_id
is
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already
packed
.
# Loaded weight is already
fused on disk (qkv/mlp)
.
if
output_dim
is
None
:
if
output_dim
is
None
:
if
needs_scalar_to_array
:
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
,
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
return
return
current_shard_offset
=
0
current_shard_offset
=
0
shard_offsets
=
[]
shard_offsets
:
List
[
Tuple
[
int
,
int
,
int
]]
=
[]
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
for
i
,
output_size
in
enumerate
(
self
.
output_sizes
):
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
shard_offsets
.
append
((
i
,
current_shard_offset
,
output_size
))
current_shard_offset
+=
output_size
current_shard_offset
+=
output_size
...
@@ -502,15 +494,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -502,15 +494,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
loaded_shard_id
*
shard_size
shard_offset
=
loaded_shard_id
*
shard_size
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
# If a param_shard_splitter is defined by the LinearMethod, use it.
# Special case for per-tensor scales in fused case.
elif
param_shard_splitter
is
not
None
:
elif
needs_scalar_to_array
:
logical_widths
=
getattr
(
param
,
"logical_widths"
,
None
)
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
=
param_shard_splitter
(
param_data
,
loaded_weight
,
loaded_shard_id
,
logical_widths
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_shard_id
)
param_data
,
loaded_weight
,
loaded_shard_id
)
else
:
else
:
...
@@ -520,13 +506,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -520,13 +506,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
"the same for all partitions."
)
if
fp8_scales_shard_indexer
is
None
:
if
len
(
param_data
.
shape
)
==
0
:
param_data
=
param_data
.
reshape
(
1
)
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
assert
param_data_
.
shape
==
loaded_weight
.
shape
...
@@ -539,7 +518,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -539,7 +518,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class
QKVParallelLinear
(
ColumnParallelLinear
):
class
QKVParallelLinear
(
ColumnParallelLinear
):
"""Linear layers for the attention's QKV transformation.
"""Linear layers for the attention's QKV transformation.
...
@@ -616,32 +594,16 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -616,32 +594,16 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for AQLM codebooks.
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
param_shard_splitter
=
getattr
(
param
,
"shard_splitter"
,
None
)
# Special case for per-tensor scales in fused case.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
if
output_dim
is
not
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support output_dim != None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# If a parameter has defined a shard_splitter to be used for
# the weight, it should be applied before the weight is
# loaded/copied to the parameter. The shard_splitter applies
# logic by using the loaded_shard_id to ensure that the loaded
# param is loaded to the correct location
# within the parameter defined by the linear method.
if
loaded_shard_id
is
None
and
param_shard_splitter
is
not
None
:
raise
NotImplementedError
(
"We do not currently support loaded_shard_id == None and "
"shard_splitter != None for a parameter. Please open an issue."
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already
packed
.
# Loaded weight is already
fused on disk (qkv/mlp)
.
if
output_dim
is
None
:
if
output_dim
is
None
:
if
needs_scalar_to_array
:
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
,
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
return
return
...
@@ -735,15 +697,9 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -735,15 +697,9 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
shard_size
)
shard_size
)
# If a param_shard_splitter is defined by the LinearMethod, use it.
# Special case for per-tensor scales in fused case.
elif
param_shard_splitter
is
not
None
:
elif
needs_scalar_to_array
:
logical_widths
=
getattr
(
param
,
"logical_widths"
,
None
)
param_data
,
loaded_weight
=
adjust_scalar_to_fused_array
(
param_data
,
loaded_weight
=
param_shard_splitter
(
param_data
,
loaded_weight
,
loaded_shard_id
,
logical_widths
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_shard_id
)
param_data
,
loaded_weight
,
loaded_shard_id
)
else
:
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
...
@@ -752,11 +708,6 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -752,11 +708,6 @@ class QKVParallelLinear(ColumnParallelLinear):
"Loading a weight without `output_dim` attribute in "
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
"for all partitions."
)
if
len
(
param_data
.
shape
)
==
0
:
param_data
=
param_data
.
reshape
(
1
)
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
assert
param_data_
.
shape
==
loaded_weight
.
shape
assert
param_data_
.
shape
==
loaded_weight
.
shape
...
@@ -843,10 +794,6 @@ class RowParallelLinear(LinearBase):
...
@@ -843,10 +794,6 @@ class RowParallelLinear(LinearBase):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -856,13 +803,9 @@ class RowParallelLinear(LinearBase):
...
@@ -856,13 +803,9 @@ class RowParallelLinear(LinearBase):
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for Fp8 scales.
# Special case for loading scales off disk, which often do not
elif
fp8_scales_shard_indexer
is
not
None
:
# have a shape (such as in the case of AutoFP8).
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
,
shard_id
=
0
)
if
fp8_scales_shard_indexer
is
None
and
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
...
@@ -880,7 +823,6 @@ class RowParallelLinear(LinearBase):
...
@@ -880,7 +823,6 @@ class RowParallelLinear(LinearBase):
param
.
data
=
param
.
data
.
reshape
(
param
.
data
.
shape
[
1
],
-
1
)
param
.
data
=
param
.
data
.
reshape
(
param
.
data
.
shape
[
1
],
-
1
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
# Set up backprop all-reduce.
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
input_parallel
=
input_
else
:
else
:
...
...
vllm/model_executor/layers/logits_processor.py
View file @
705f6a35
...
@@ -6,6 +6,8 @@ import torch
...
@@ -6,6 +6,8 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.distributed
import
tensor_model_parallel_gather
from
vllm.distributed
import
tensor_model_parallel_gather
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
...
@@ -22,7 +24,8 @@ class LogitsProcessor(nn.Module):
...
@@ -22,7 +24,8 @@ class LogitsProcessor(nn.Module):
vocab_size
:
int
,
vocab_size
:
int
,
org_vocab_size
:
Optional
[
int
]
=
None
,
org_vocab_size
:
Optional
[
int
]
=
None
,
scale
:
float
=
1.0
,
scale
:
float
=
1.0
,
logits_as_input
:
bool
=
False
)
->
None
:
logits_as_input
:
bool
=
False
,
soft_cap
:
Optional
[
float
]
=
None
)
->
None
:
"""
"""
Args:
Args:
scale: A scaling factor to apply to the logits.
scale: A scaling factor to apply to the logits.
...
@@ -34,10 +37,12 @@ class LogitsProcessor(nn.Module):
...
@@ -34,10 +37,12 @@ class LogitsProcessor(nn.Module):
self
.
logits_as_input
=
logits_as_input
self
.
logits_as_input
=
logits_as_input
# original vocabulary size (without LoRA).
# original vocabulary size (without LoRA).
self
.
org_vocab_size
=
org_vocab_size
or
vocab_size
self
.
org_vocab_size
=
org_vocab_size
or
vocab_size
# Soft cap the logits. Used in Gemma 2.
self
.
soft_cap
=
soft_cap
def
forward
(
def
forward
(
self
,
self
,
embedding
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -49,9 +54,13 @@ class LogitsProcessor(nn.Module):
...
@@ -49,9 +54,13 @@ class LogitsProcessor(nn.Module):
sampling_metadata
)
sampling_metadata
)
# Get the logits for the next tokens.
# Get the logits for the next tokens.
logits
=
self
.
_get_logits
(
hidden_states
,
embedding
,
embedding_bias
)
logits
=
self
.
_get_logits
(
hidden_states
,
lm_head
,
embedding_bias
)
if
logits
is
not
None
:
if
logits
is
not
None
:
if
self
.
soft_cap
is
not
None
:
logits
=
logits
/
self
.
soft_cap
logits
=
torch
.
tanh
(
logits
)
logits
=
logits
*
self
.
soft_cap
if
self
.
scale
!=
1.0
:
if
self
.
scale
!=
1.0
:
logits
*=
self
.
scale
logits
*=
self
.
scale
...
@@ -60,12 +69,13 @@ class LogitsProcessor(nn.Module):
...
@@ -60,12 +69,13 @@ class LogitsProcessor(nn.Module):
return
logits
return
logits
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
embedding
:
torch
.
Tensor
,
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
embedding_bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
embedding_bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# Get the logits for the next tokens.
# Get the logits for the next tokens.
logits
=
torch
.
matmul
(
hidden_states
,
embedding
.
t
())
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
if
embedding_bias
is
not
None
:
hidden_states
,
logits
+=
embedding_bias
bias
=
embedding_bias
)
logits
=
tensor_model_parallel_gather
(
logits
)
logits
=
tensor_model_parallel_gather
(
logits
)
# Remove paddings in vocab (if any).
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
if
logits
is
not
None
:
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
705f6a35
...
@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
...
@@ -43,7 +43,8 @@ class AWQConfig(QuantizationConfig):
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
return
[
torch
.
half
]
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# The AWQ kernel only supports Turing or newer GPUs.
# The AWQ kernel only supports Turing or newer GPUs.
return
75
return
75
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
705f6a35
...
@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
...
@@ -44,8 +44,9 @@ class QuantizationConfig(ABC):
"""List of supported activation dtypes."""
"""List of supported activation dtypes."""
raise
NotImplementedError
raise
NotImplementedError
@
classmethod
@
abstractmethod
@
abstractmethod
def
get_min_capability
(
self
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
"""Minimum GPU capability to support the quantization method.
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
...
@@ -86,6 +87,15 @@ class QuantizationConfig(ABC):
...
@@ -86,6 +87,15 @@ class QuantizationConfig(ABC):
raise
ValueError
(
f
"Cannot find any of
{
keys
}
in the model's "
raise
ValueError
(
f
"Cannot find any of
{
keys
}
in the model's "
"quantization config."
)
"quantization config."
)
@staticmethod
def get_from_keys_or(config: Dict[str, Any], keys: List[str],
                     default: Any) -> Any:
    """Get an optional value from the model's quantization config.

    Looks the value up with ``QuantizationConfig.get_from_keys`` and
    returns ``default`` when that lookup raises ``ValueError``.
    """
    try:
        value = QuantizationConfig.get_from_keys(config, keys)
    except ValueError:
        value = default
    return value
@
abstractmethod
@
abstractmethod
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
...
...
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
705f6a35
...
@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
...
@@ -38,7 +38,7 @@ class BitsAndBytesConfig(QuantizationConfig):
return
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
return
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
@
classmethod
def
get_min_capability
(
self
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
70
return
70
@
staticmethod
@
staticmethod
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
705f6a35
...
@@ -7,17 +7,23 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
...
@@ -7,17 +7,23 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
,
CompressedTensorsW8A8DynamicToken
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsW8A8StaticTensor
)
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationArgs
,
QuantizationStrategy
,
find_first_name_or_class_match
)
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_first_name_or_class_match
)
from
vllm.platforms
import
current_platform
class
CompressedTensorsConfig
(
QuantizationConfig
):
class
CompressedTensorsConfig
(
QuantizationConfig
):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
]):
def
__init__
(
self
,
layer_quant_details
:
Dict
[
str
,
Any
],
ignore
:
List
[
str
],
quant_format
:
str
):
self
.
ignore
=
ignore
self
.
ignore
=
ignore
self
.
layer_quant_details
=
layer_quant_details
self
.
layer_quant_details
=
layer_quant_details
self
.
quant_format
=
quant_format
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
def
get_linear_method
(
self
)
->
"CompressedTensorsLinearMethod"
:
return
CompressedTensorsLinearMethod
(
self
)
return
CompressedTensorsLinearMethod
(
self
)
...
@@ -26,11 +32,11 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -26,11 +32,11 @@ class CompressedTensorsConfig(QuantizationConfig):
return
[]
return
[]
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
]
return
[
torch
.
float16
,
torch
.
bfloat16
]
# Need to figure it out
@
classmethod
def
get_min_capability
(
self
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
return
60
return
75
def
get_name
(
self
)
->
str
:
def
get_name
(
self
)
->
str
:
return
"compressed_tensors"
return
"compressed_tensors"
...
@@ -46,29 +52,54 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -46,29 +52,54 @@ class CompressedTensorsConfig(QuantizationConfig):
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"CompressedTensorsConfig"
:
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
layer_quant_details
:
Dict
[
str
,
Any
]
=
dict
()
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
ignore
:
List
[
str
]
=
config
.
get
(
"ignore"
,
None
)
quant_format
:
str
=
config
.
get
(
"format"
,
None
)
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for
key
,
quant_config
in
config
[
"config_groups"
].
items
():
for
key
,
quant_config
in
config
[
"config_groups"
].
items
():
targets
=
quant_config
.
get
(
"targets"
)
targets
=
quant_config
.
get
(
"targets"
)
for
target
in
targets
:
for
target
in
targets
:
layer_quant_details
[
target
]
=
{}
layer_quant_details
[
target
]
=
{}
layer_quant_details
[
target
][
layer_quant_details
[
target
][
"weight"
]
=
QuantizationArgs
.
parse_obj
(
"weight
s
"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"weights"
))
quant_config
.
get
(
"weights"
))
layer_quant_details
[
target
][
try
:
"input"
]
=
QuantizationArgs
.
parse_obj
(
layer_quant_details
[
target
][
quant_config
.
get
(
"input_activations"
))
"input_activations"
]
=
QuantizationArgs
.
parse_obj
(
quant_config
.
get
(
"input_activations"
))
except
Exception
:
layer_quant_details
[
target
][
"input_activations"
]
=
None
return
cls
(
layer_quant_details
=
layer_quant_details
,
ignore
=
ignore
)
return
cls
(
layer_quant_details
=
layer_quant_details
,
ignore
=
ignore
,
quant_format
=
quant_format
)
@
classmethod
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
return
[]
def _check_gptq_and_marlin_can_run(self):
    """Reject devices whose capability is below 80.

    The value returned by ``current_platform.get_device_capability()`` is
    folded into one integer as ``first * 10 + second`` before it is
    compared against the minimum.

    Raises:
        RuntimeError: if the folded capability is lower than 80.
    """
    capability = current_platform.get_device_capability()
    capability = capability[0] * 10 + capability[1]
    if capability < 80:
        # Adjacent string literals are concatenated into ONE message.
        # Separating them with commas would pass three positional
        # arguments to RuntimeError, and the error would render as a
        # tuple of fragments instead of a readable sentence.
        raise RuntimeError(
            "The quantization config is not supported for "
            "the current GPU. Minimum capability: 80. "
            f"Current capability: {capability}.")
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
is_tensor
=
(
weight_quant
.
strategy
==
input_quant
.
strategy
==
weight_strategy
=
(
QuantizationStrategy
.
TENSOR
.
value
)
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
is_tensor
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
)
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_static
=
not
weight_quant
.
dynamic
and
not
input_quant
.
dynamic
is_static
=
not
weight_quant
.
dynamic
and
not
input_quant
.
dynamic
...
@@ -77,24 +108,98 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -77,24 +108,98 @@ class CompressedTensorsConfig(QuantizationConfig):
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_dynamic_token_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
is_8_bits
=
weight_quant
.
num_bits
==
input_quant
.
num_bits
==
8
is_token_tensor
=
(
weight_quant
.
strategy
weight_strategy
=
(
==
QuantizationStrategy
.
TENSOR
.
value
)
and
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
.
value
input_quant
.
strategy
or
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
.
value
)
==
QuantizationStrategy
.
TOKEN
.
value
)
is_token
=
(
weight_strategy
and
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
.
value
)
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_symmetric
=
weight_quant
.
symmetric
and
input_quant
.
symmetric
is_dynamic
=
not
weight_quant
.
dynamic
and
input_quant
.
dynamic
is_dynamic
=
not
weight_quant
.
dynamic
and
input_quant
.
dynamic
return
is_8_bits
and
is_token_tensor
and
is_symmetric
and
is_dynamic
return
is_8_bits
and
is_token
and
is_symmetric
and
is_dynamic
def
_is_fp8_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
# Confirm weights and activations quantized.
if
weight_quant
is
None
or
input_quant
is
None
:
return
False
# Confirm we have floating points.
if
not
(
weight_quant
.
type
==
QuantizationType
.
FLOAT
and
input_quant
.
type
==
QuantizationType
.
FLOAT
):
return
False
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_weight
=
(
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
if
not
(
is_symmetric_weight
and
is_static_weight
and
is_per_tensor_weight
):
return
False
# Dynamic quantization is always supported if weights supported.
if
input_quant
.
dynamic
:
return
True
# Confirm activation scheme is supported.
is_symmetric_activation
=
input_quant
.
symmetric
is_per_tensor_activation
=
(
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
)
if
not
(
is_symmetric_activation
and
is_per_tensor_activation
):
return
False
# All conditions satisfied.
return
True
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
                            input_quant: BaseModel) -> bool:
    """Tell whether the args describe weight-only quantization with a
    channel or group strategy.

    All of the following must hold: the weight strategy is ``CHANNEL`` or
    ``GROUP``, there are no activation args (``input_quant is None``), and
    the weights are symmetric and not dynamic.
    """
    no_activation_quant = input_quant is None
    symmetric = weight_quant.symmetric
    channel_or_group = (
        weight_quant.strategy == QuantizationStrategy.CHANNEL.value
        or weight_quant.strategy == QuantizationStrategy.GROUP.value)
    static = not weight_quant.dynamic

    # Same operand order as the checks above so the returned object is the
    # first falsy operand (or the last one) exactly as before.
    return (channel_or_group and no_activation_quant and symmetric
            and static)
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
def
_get_schema
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
input_quant
:
BaseModel
)
->
"CompressedTensorsScheme"
:
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8StaticTensor
()
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8DynamicToken
()
self
.
_check_gptq_and_marlin_can_run
()
if
(
self
.
quant_format
==
CompressionFormat
.
marlin_24
.
value
raise
NotImplementedError
(
"Scheme not supported."
)
and
weight_quant
.
num_bits
in
W4A16SPARSE24_SUPPORTED_BITS
):
return
CompressedTensorsW4A16Sparse24
(
strategy
=
weight_quant
.
strategy
,
num_bits
=
weight_quant
.
num_bits
,
group_size
=
weight_quant
.
group_size
)
if
(
self
.
quant_format
==
CompressionFormat
.
pack_quantized
.
value
and
weight_quant
.
num_bits
in
WNA16_SUPPORTED_BITS
):
return
CompressedTensorsWNA16
(
num_bits
=
weight_quant
.
num_bits
,
strategy
=
weight_quant
.
strategy
,
group_size
=
weight_quant
.
group_size
)
if
(
self
.
quant_format
==
CompressionFormat
.
int_quantized
.
value
or
self
.
quant_format
==
CompressionFormat
.
float_quantized
.
value
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8
(
input_dynamic
=
input_quant
.
dynamic
)
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
True
)
if
self
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
False
)
raise
NotImplementedError
(
"No compressed-tensors compatible scheme was found."
)
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsScheme"
:
def
get_scheme
(
self
,
layer
:
torch
.
nn
.
Module
)
->
"CompressedTensorsScheme"
:
...
@@ -113,8 +218,9 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -113,8 +218,9 @@ class CompressedTensorsConfig(QuantizationConfig):
raise
ValueError
(
raise
ValueError
(
f
"Could not find quantization details for
{
layer
}
."
)
f
"Could not find quantization details for
{
layer
}
."
)
return
self
.
_get_schema
(
weight_quant
=
layer_quant_details
[
"weight"
],
return
self
.
_get_schema
(
input_quant
=
layer_quant_details
[
"input"
])
weight_quant
=
layer_quant_details
[
"weights"
],
input_quant
=
layer_quant_details
[
"input_activations"
])
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
class
CompressedTensorsLinearMethod
(
LinearMethodBase
):
...
@@ -122,6 +228,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -122,6 +228,9 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quantization_config
:
CompressedTensorsConfig
):
def
__init__
(
self
,
quantization_config
:
CompressedTensorsConfig
):
self
.
quantization_config
=
quantization_config
self
.
quantization_config
=
quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Forward the post-load hook to the scheme stored on ``layer``."""
    scheme = layer.scheme
    scheme.process_weights_after_loading(layer)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
...
@@ -138,6 +247,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -138,6 +247,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
=
layer
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
=
layer
)
scheme
.
create_weights
(
scheme
.
create_weights
(
layer
=
layer
,
layer
=
layer
,
input_size
=
input_size
,
input_size_per_partition
=
input_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
output_partition_sizes
=
output_partition_sizes
,
output_partition_sizes
=
output_partition_sizes
,
output_size
=
output_size
,
output_size
=
output_size
,
...
@@ -157,10 +267,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
...
@@ -157,10 +267,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
"""
"""
if
bias
is
not
None
:
raise
ValueError
(
"bias is not supported for this linear method"
)
scheme
=
layer
.
scheme
scheme
=
layer
.
scheme
if
scheme
is
None
:
if
scheme
is
None
:
raise
ValueError
(
"A scheme must be defined for each layer"
)
raise
ValueError
(
"A scheme must be defined for each layer"
)
return
scheme
.
apply_weights
(
layer
,
x
)
return
scheme
.
apply_weights
(
layer
,
x
,
bias
=
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
705f6a35
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
# noqa: F401
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_unquantized
import
(
# noqa: F401
from
.compressed_tensors_unquantized
import
CompressedTensorsUnquantized
CompressedTensorsUnquantized
)
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
from
.compressed_tensors_w8a8_dynamictoken
import
(
# noqa: F401, E501
CompressedTensorsW4A16Sparse24
)
CompressedTensorsW8A8DynamicToken
)
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
from
.compressed_tensors_w8a8_statictensor
import
(
# noqa: F401, E501
from
.compressed_tensors_w8a8_int8
import
CompressedTensorsW8A8Int8
CompressedTensorsW8A8StaticTensor
)
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
CompressedTensorsWNA16
)
__all__
=
[
"CompressedTensorsScheme"
,
"CompressedTensorsUnquantized"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"WNA16_SUPPORTED_BITS"
,
"W4A16SPARSE24_SUPPORTED_BITS"
,
]
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
705f6a35
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Optional
import
torch
import
torch
...
@@ -20,14 +21,24 @@ class CompressedTensorsScheme(ABC):
...
@@ -20,14 +21,24 @@ class CompressedTensorsScheme(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]):
"""
"""
Run the forward pass for the particular scheme. This is where
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: toch.nn.Module with the registered weights and
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
other parameters relevant to the particular scheme.
:param x: input to the layer
:param x: input to the layer
:param bias: bias parameter
"""
"""
raise
NotImplementedError
raise
NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
    """
    Hook invoked once weight loading has finished, for whatever cleanup
    the scheme needs to perform on the layer.
    """
    raise NotImplementedError
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
705f6a35
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
...
@@ -18,6 +18,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -18,6 +18,9 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
in a linear transformation.
"""
"""
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Do nothing: this scheme has no post-load work to perform."""
    return None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
output_partition_sizes
:
List
[
int
],
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
input_size_per_partition
:
int
,
...
@@ -34,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -34,6 +37,7 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"weight_loader"
:
weight_loader
})
set_weight_attrs
(
weight
,
{
"weight_loader"
:
weight_loader
})
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
):
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
weight
=
layer
.
weight
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
weight
)
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
0 → 100644
View file @
705f6a35
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQ_MARLIN_24_MAX_PARALLEL
,
GPTQ_MARLIN_24_MIN_THREAD_N
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW4A16Sparse24"
]
W4A16SPARSE24_SUPPORTED_BITS
=
[
4
]
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
    """Scheme whose forward pass calls ``ops.gptq_marlin_24_gemm`` on the
    packed weight, metadata, scale and workspace tensors it creates.
    """

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        """Store the quantization settings.

        Raises:
            ValueError: if ``strategy`` is ``"group"`` and no
                ``group_size`` was supplied.
        """
        self.strategy = strategy
        self.group_size = group_size
        self.num_bits = num_bits
        self.tile_size = 16

        grouped = self.strategy == "group"
        if grouped and self.group_size is None:
            raise ValueError(
                "group_size must be given when using strategy group")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Nothing to do after loading for this scheme."""
        return None

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Allocate the tensors used by ``apply_weights`` and attach them
        to ``layer``.

        ``weight_packed``, ``scale_packed``, ``weight_shape`` and ``meta``
        are registered as parameters (in that order); a zero-filled
        ``workspace`` is assigned on the layer afterwards.
        """
        packing = 32 // self.num_bits
        out_total = sum(output_partition_sizes)

        # Packed weight (int32).
        packed_rows = input_size_per_partition // self.tile_size // 2
        packed_cols = out_total * self.tile_size // packing
        packed_weight = Parameter(
            torch.empty(packed_rows, packed_cols, dtype=torch.int32),
            requires_grad=False,
        )
        packed_weight_attrs = {
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": packing,
            "marlin_tile_size": self.tile_size,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(packed_weight, packed_weight_attrs)
        layer.register_parameter("weight_packed", packed_weight)

        # Scales: a single row when no group size is configured.
        if self.group_size is None:
            n_groups = 1
        else:
            n_groups = input_size_per_partition // self.group_size
        scale_param = Parameter(
            torch.empty(n_groups, out_total, dtype=params_dtype),
            requires_grad=False,
        )
        scale_attrs = {
            "output_dim": 1,
            "input_dim": None if n_groups == 1 else 0,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(scale_param, scale_attrs)
        layer.register_parameter("scale_packed", scale_param)

        # Two-element int64 tensor loaded under the name "weight_shape".
        shape_param = Parameter(torch.empty(2, dtype=torch.int64),
                                requires_grad=False)
        layer.register_parameter("weight_shape", shape_param)
        set_weight_attrs(shape_param, {"weight_loader": weight_loader})

        # Metadata tensor (int16).
        meta_rows = input_size_per_partition // 8 // 2 // 2
        meta_param = Parameter(
            torch.empty(meta_rows, out_total * 2, dtype=torch.int16),
            requires_grad=False,
        )
        meta_attrs = {
            "input_dim": 0,
            "packed_dim": 1,
            "pack_factor": 1,
            "output_dim": 1,
            "marlin_tile_size": 2,
            "weight_loader": weight_loader,
        }
        set_weight_attrs(meta_param, meta_attrs)
        layer.register_parameter("meta", meta_param)

        # Zero-initialised workspace handed to the GEMM call.
        workspace_len = (out_total //
                         GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
        layer.workspace = Parameter(torch.zeros(workspace_len,
                                                dtype=torch.int),
                                    requires_grad=False)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the GEMM on ``x`` flattened to 2-D and restore the leading
        dimensions of ``x`` on the result.

        When ``bias`` is given it is added to the output in place.
        """
        packed_weight = layer.weight_packed
        meta_param = layer.meta
        scale_param = layer.scale_packed
        scratch = layer.workspace

        flat_in = x.view(-1, x.shape[-1])
        m_dim, k_dim = flat_in.shape[0], flat_in.shape[1]
        n_dim = scale_param.shape[1]

        flat_out = ops.gptq_marlin_24_gemm(flat_in, packed_weight,
                                           meta_param, scale_param, scratch,
                                           self.num_bits, m_dim, n_dim,
                                           k_dim)

        out_shape = x.shape[:-1] + (flat_out.shape[1], )
        result = flat_out.view(out_shape)

        if bias is not None:
            result.add_(bias)  # In-place add
        return result
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_dynamictoken.py
deleted
100644 → 0
View file @
af837396
from
typing
import
Callable
,
List
,
Tuple
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
custom_ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A8DynamicToken"
]
class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
    """Scheme that stores int8 weights with float32 weight scales and
    quantizes the input via ``custom_ops.scaled_int8_quant`` at call time.
    """

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """Map a shard id to its integer index.

        Integers are returned unchanged; the strings ``"q"``, ``"k"`` and
        ``"v"`` map to 0, 1 and 2. Anything else trips an assertion.
        """
        if isinstance(shard_id, int):
            return shard_id

        assert isinstance(shard_id, str)
        index_by_name = {"q": 0, "k": 1, "v": 2}
        assert shard_id in index_by_name
        return index_by_name[shard_id]

    def scales_shard_splitter(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int],
            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the slice of ``param`` belonging to ``shard_id`` together
        with ``loaded_weight`` repeated to that slice's width.
        """
        shard_index = self._shard_id_as_int(shard_id)
        start = sum(logical_widths[:shard_index])
        width = logical_widths[shard_index]
        # Broadcast the loaded value by repeating it across the shard.
        repeated = loaded_weight.repeat(width)
        return param[start:start + width], repeated

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Create and register ``weight``, ``weight_scale`` and
        ``weight_zero_point`` on ``layer``.
        """
        # When the scales have a single value, it is required that they be
        # on the CPU for performance and CUDA Graphs compatibility. Please
        # refer to the comment in
        # CompressedTensorsW8A8StaticTensor::create_weights for further
        # information.
        total_out = sum(output_partition_sizes)
        if len(output_partition_sizes) != 1:
            scale_len = total_out
        else:
            scale_len = 1

        # int8 weight.
        weight_param = Parameter(torch.empty(total_out,
                                             input_size_per_partition,
                                             dtype=torch.int8),
                                 requires_grad=False)
        layer.register_parameter("weight", weight_param)
        set_weight_attrs(weight_param, {"input_dim": 1, "output_dim": 0})
        set_weight_attrs(weight_param, {"weight_loader": weight_loader})
        set_weight_attrs(weight_param,
                         {"logical_widths": output_partition_sizes})

        # float32 weight scale.
        scale_param = Parameter(torch.empty(scale_len, dtype=torch.float32),
                                requires_grad=False)
        layer.register_parameter("weight_scale", scale_param)
        set_weight_attrs(scale_param, {"weight_loader": weight_loader})
        set_weight_attrs(
            scale_param, {
                "shard_splitter": self.scales_shard_splitter,
                "logical_widths": output_partition_sizes
            })

        # int8 weight zero point.
        zero_point_param = Parameter(torch.empty(1, dtype=torch.int8),
                                     requires_grad=False)
        layer.register_parameter("weight_zero_point", zero_point_param)
        set_weight_attrs(zero_point_param, {"weight_loader": weight_loader})

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Quantize ``x`` and multiply it with the transposed layer weight
        through ``custom_ops.cutlass_scaled_mm_dq``, requesting ``x.dtype``
        for the output.
        """
        layer_weight = layer.weight
        layer_weight_scale = layer.weight_scale

        quantized_x, x_scales = custom_ops.scaled_int8_quant(x)

        return custom_ops.cutlass_scaled_mm_dq(quantized_x, layer_weight.t(),
                                               x_scales, layer_weight_scale,
                                               x.dtype)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
0 → 100644
View file @
705f6a35
from
typing
import
Callable
,
List
,
Optional
import
torch
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A8Fp8"
]
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
    """Scheme that keeps ``float8_e4m3fn`` weights and runs the layer
    through ``apply_fp8_linear``.
    """

    def __init__(self, input_dynamic: bool):
        """Remember whether the input is dynamically quantized and cache
        the result of ``cutlass_fp8_supported()``.
        """
        self.input_dynamic = input_dynamic
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def process_weights_after_loading(self, layer) -> None:
        """Requantize the weight with a single scale and finalise the
        input scale.

        W8A8-Fp8 kernels support only per-tensor and per-channel cases.
        So if we have a fused module (QKV, MLP) with per tensor scales
        (thus N scales being passed to the kernel), we requantize with a
        single scale.
        """
        # Dequant -> Quant with max scale.
        single_scale, requantized = requantize_with_max_scale(
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            logical_widths=layer.logical_widths,
        )

        # Replace the layer's tensors; the weight is stored transposed.
        layer.weight = torch.nn.Parameter(requantized.t(),
                                          requires_grad=False)
        layer.weight_scale = torch.nn.Parameter(single_scale,
                                                requires_grad=False)

        # Static inputs collapse to the maximum loaded scale; dynamic
        # inputs carry no stored scale at all.
        layer.input_scale = (None if self.input_dynamic else
                             torch.nn.Parameter(layer.input_scale.max(),
                                                requires_grad=False))

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register ``weight`` and ``weight_scale`` on ``layer``, plus
        ``input_scale`` when the input is not dynamically quantized.

        ``output_partition_sizes`` is also stored on the layer as
        ``logical_widths`` for use after loading.
        """
        del params_dtype  # not used by this scheme

        layer.logical_widths = output_partition_sizes
        total_out = sum(output_partition_sizes)

        # WEIGHT
        fp8_weight = torch.nn.Parameter(torch.empty(
            total_out, input_size_per_partition, dtype=torch.float8_e4m3fn),
                                        requires_grad=False)
        layer.register_parameter("weight", fp8_weight)
        set_weight_attrs(fp8_weight, {
            "input_dim": 1,
            "output_dim": 0,
            "weight_loader": weight_loader,
        })

        # WEIGHT SCALE
        layer.register_parameter(
            "weight_scale",
            create_per_tensor_scale_param(output_partition_sizes,
                                          weight_loader=weight_loader))

        # INPUT SCALE
        if self.input_dynamic:
            return
        layer.register_parameter(
            "input_scale",
            create_per_tensor_scale_param(output_partition_sizes,
                                          weight_loader=weight_loader))

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the layer to ``x`` via ``apply_fp8_linear`` using the
        weight, scales and optional ``bias``.
        """
        linear_kwargs = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
        return apply_fp8_linear(**linear_kwargs)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
0 → 100644
View file @
705f6a35
from
typing
import
Callable
,
List
,
Optional
import
torch
from
torch.nn
import
Parameter
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_int8_linear
,
convert_to_channelwise
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
    """Scheme that keeps int8 weights and runs the layer through
    ``apply_int8_linear``, with either a stored or no input scale.
    """

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        """Store the weight strategy and whether the input scale is
        static.
        """
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Transpose the weight, widen fused per-tensor scales to
        per-channel, and finalise the input scale.
        """
        # WEIGHT
        # Cutlass kernels need transposed weight.
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)

        # WEIGHT SCALE
        # Cutlass kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        fused = len(self.logical_widths) > 1
        if fused and self.strategy == QuantizationStrategy.TENSOR:
            per_channel = convert_to_channelwise(layer.weight_scale,
                                                 self.logical_widths)
            layer.weight_scale = Parameter(per_channel, requires_grad=False)

        # INPUT SCALE
        if not self.is_static_input_scheme:
            layer.input_scale = None
        else:
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register ``weight`` and ``weight_scale`` on ``layer``, plus
        ``input_scale`` for a static input scheme.

        ``output_partition_sizes`` is kept on the scheme as
        ``logical_widths`` for use after loading.
        """
        self.logical_widths = output_partition_sizes
        loader_kwargs = {"weight_loader": weight_loader}

        # WEIGHT
        int8_weight = Parameter(torch.empty(sum(output_partition_sizes),
                                            input_size_per_partition,
                                            dtype=torch.int8),
                                requires_grad=False)
        layer.register_parameter("weight", int8_weight)
        set_weight_attrs(int8_weight, {
            "input_dim": 1,
            "output_dim": 0,
            "weight_loader": weight_loader,
        })

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = create_per_channel_scale_param(
                output_partition_sizes, **loader_kwargs)
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = create_per_tensor_scale_param(
                output_partition_sizes, **loader_kwargs)
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = create_per_tensor_scale_param(
                output_partition_sizes, **loader_kwargs)
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply the layer to ``x`` via ``apply_int8_linear`` using the
        weight, scales and optional ``bias``.
        """
        linear_kwargs = dict(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
        return apply_int8_linear(**linear_kwargs)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
deleted
100644 → 0
View file @
af837396
from
typing
import
Callable
,
List
,
Tuple
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
custom_ops
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A8StaticTensor"
]
class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
    """Scheme with int8 weights and a single stored input scale.

    ``create_weights`` registers the int8 weight together with scale and
    zero-point parameters; ``apply_weights`` quantizes the input with the
    stored input scale and runs the scaled matmul custom op.
    """

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        """Return ``shard_id`` as an integer index.

        Integers are returned unchanged; the strings ``"q"``, ``"k"`` and
        ``"v"`` map to 0, 1 and 2.  Anything else trips an assertion.
        """
        if not isinstance(shard_id, int):
            assert isinstance(shard_id, str)
            index_by_name = {"q": 0, "k": 1, "v": 2}
            assert shard_id in index_by_name
            return index_by_name[shard_id]
        return shard_id

    def scales_shard_splitter(
            self, param: torch.Tensor, loaded_weight: torch.Tensor,
            shard_id: Union[str, int],
            logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Select the slice of ``param`` for one shard and expand its scale.

        The slice starts at the sum of the widths of all preceding shards
        and spans ``logical_widths[shard]`` entries.  ``loaded_weight`` is
        repeated that many times so it can be copied into the slice.

        Returns:
            ``(param_slice, repeated_loaded_weight)``.
        """
        shard_idx = self._shard_id_as_int(shard_id)
        start = sum(logical_widths[:shard_idx])
        width = logical_widths[shard_idx]
        # One copy of the loaded scale per entry of the shard.
        expanded = loaded_weight.repeat(width)
        stop = start + width
        return param[start:stop], expanded

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Allocate and register this scheme's parameters on ``layer``.

        Parameters are registered in the order ``weight``, ``input_scale``,
        ``input_zero_point``, ``weight_scale``, ``weight_zero_point``; each
        one is uninitialised, has ``requires_grad=False`` and carries the
        ``weight_loader`` attribute.  ``weight_scale`` has one entry per
        output row when there is more than one output partition and a
        single entry otherwise.  ``params_dtype`` and ``**kwargs`` are not
        read by this method.
        """

        # TODO: remove zero_point parameters once the configs given remove them

        def _add_param(name, shape, dtype, extra_attrs):
            # Allocate, register on the layer, then tag with loader metadata.
            tensor = Parameter(torch.empty(*shape, dtype=dtype),
                               requires_grad=False)
            layer.register_parameter(name, tensor)
            set_weight_attrs(tensor, {
                "weight_loader": weight_loader,
                **extra_attrs,
            })

        total_output_size = sum(output_partition_sizes)
        if len(output_partition_sizes) == 1:
            weight_scale_dim = 1
        else:
            weight_scale_dim = total_output_size

        _add_param("weight", (total_output_size, input_size_per_partition),
                   torch.int8, {
                       "input_dim": 1,
                       "output_dim": 0,
                   })
        _add_param("input_scale", (1, ), torch.float32,
                   {"ignore_warning": True})
        _add_param("input_zero_point", (1, ), torch.int8,
                   {"ignore_warning": True})
        _add_param(
            "weight_scale", (weight_scale_dim, ), torch.float32, {
                "shard_splitter": self.scales_shard_splitter,
                "logical_widths": output_partition_sizes,
                "ignore_warning": True,
            })
        _add_param("weight_zero_point", (1, ), torch.int8,
                   {"ignore_warning": True})

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Quantize ``x`` with the layer's input scale and apply the weight.

        ``x`` is passed through ``custom_ops.scaled_int8_quant`` using
        ``layer.input_scale``; the result is then given to
        ``custom_ops.cutlass_scaled_mm_dq`` together with the transposed
        weight, both scales and ``x.dtype``.
        """
        in_scale = layer.input_scale
        quantized_x, _unused = custom_ops.scaled_int8_quant(x, in_scale)
        return custom_ops.cutlass_scaled_mm_dq(quantized_x, layer.weight.t(),
                                               in_scale, layer.weight_scale,
                                               x.dtype)
Prev
1
…
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment