Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ec5e299c
Commit
ec5e299c
authored
Feb 21, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.7.3' into v0.7.3-dev
parents
47bd229c
ed6e9075
Changes
521
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3952 additions
and
45 deletions
+3952
-45
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json
...e/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json
+200
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
...=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
+164
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json
...e/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json
+200
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
...e/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
+1
-1
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
...=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
+164
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json
...e/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json
+200
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
...=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
+164
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json
...e/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json
+200
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+10
-6
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+2
-2
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+31
-23
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+18
-12
vllm/model_executor/layers/mamba/mamba_mixer2.py
vllm/model_executor/layers/mamba/mamba_mixer2.py
+534
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+1
-1
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
+261
-0
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
+619
-0
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
+750
-0
vllm/model_executor/layers/mamba/ops/ssd_combined.py
vllm/model_executor/layers/mamba/ops/ssd_combined.py
+223
-0
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
+207
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
No files found.
Too many changes to show.
To preserve performance only
521 of 521+
files are displayed.
Plain diff
Email patch
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
32
,
"kpack"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
View file @
ec5e299c
...
...
@@ -128,7 +128,7 @@
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
32
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"512"
:
{
...
...
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
32
,
"kpack"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"16"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"48"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"96"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"256"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
},
"4096"
:
{
"BLOCK_SIZE_M"
:
256
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
}
}
vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json
0 → 100644
View file @
ec5e299c
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
16
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"8"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"16"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"24"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"32"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
2
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
1
},
"48"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"64"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"96"
:
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
4
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
4
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"512"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
2
,
"waves_per_eu"
:
0
,
"matrix_instr_nonkdim"
:
16
,
"kpack"
:
2
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
ec5e299c
...
...
@@ -609,7 +609,7 @@ def moe_align_block_size(
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
if
num_experts
>=
224
:
if
envs
.
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
:
if
envs
.
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON
or
num_experts
!=
256
:
moe_align_block_size_triton
(
topk_ids
,
num_experts
,
...
...
@@ -619,6 +619,7 @@ def moe_align_block_size(
num_tokens_post_pad
,
)
else
:
# Currently requires num_experts=256
ops
.
sgl_moe_align_block_size
(
topk_ids
,
num_experts
,
...
...
@@ -1023,15 +1024,17 @@ def grouped_topk(hidden_states: torch.Tensor,
else
:
raise
ValueError
(
f
"Unsupported scoring function:
{
scoring_func
}
"
)
num_token
=
scores
.
shape
[
0
]
if
e_score_correction_bias
is
not
None
:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_scores
=
scores
scores
=
scores
+
e_score_correction_bias
.
unsqueeze
(
0
)
num_token
=
scores
.
shape
[
0
]
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
group_scores
=
(
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
topk
(
2
,
dim
=-
1
)[
0
].
sum
(
dim
=-
1
))
else
:
group_scores
=
scores
.
view
(
num_token
,
num_expert_group
,
-
1
).
max
(
dim
=-
1
).
values
# [n, n_group]
group_idx
=
torch
.
topk
(
group_scores
,
k
=
topk_group
,
dim
=-
1
,
sorted
=
False
)[
1
]
# [n, top_k_group]
group_mask
=
torch
.
zeros_like
(
group_scores
)
# [n, n_group]
...
...
@@ -1039,7 +1042,8 @@ def grouped_topk(hidden_states: torch.Tensor,
score_mask
=
group_mask
.
unsqueeze
(
-
1
).
expand
(
num_token
,
num_expert_group
,
scores
.
shape
[
-
1
]
//
num_expert_group
).
reshape
(
num_token
,
-
1
)
# [n, e]
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
0.0
)
# [n, e]
tmp_scores
=
scores
.
masked_fill
(
~
score_mask
.
bool
(),
float
(
"-inf"
))
# [n, e]
if
e_score_correction_bias
is
not
None
:
topk_ids
=
torch
.
topk
(
tmp_scores
,
k
=
topk
,
dim
=-
1
,
sorted
=
False
)[
1
]
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
ec5e299c
...
...
@@ -360,8 +360,8 @@ class FusedMoE(torch.nn.Module):
"use_nn_moe"
:
self
.
use_nn_moe
,
}
# need full intermediate size pre-sharding for WNA16 act order
if
(
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsWNA16MoEMethod"
):
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"GPTQMarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
)
)
:
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
self
.
quant_method
.
create_weights
(
layer
=
self
,
**
moe_quant_params
)
...
...
vllm/model_executor/layers/linear.py
View file @
ec5e299c
...
...
@@ -309,29 +309,30 @@ class ColumnParallelLinear(LinearBase):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
list
[
int
]]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
assert
self
.
quant_method
is
not
Non
e
self
.
output_size_per_partition
=
divide
(
self
.
output_size
,
tp_size
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
input_siz
e
self
.
output_size_per_partition
=
divide
(
output_size
,
self
.
tp_size
)
self
.
output_partition_sizes
=
[
self
.
output_size_per_partition
]
# If QKV or MergedColumn, use output size of each partition.
if
hasattr
(
self
,
"output_sizes"
):
self
.
output_partition_sizes
=
[
divide
(
output_size
,
tp_size
)
divide
(
output_size
,
self
.
tp_size
)
for
output_size
in
self
.
output_sizes
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
layer
=
self
,
input_size_per_partition
=
self
.
input_size
,
input_size_per_partition
=
self
.
input_size
_per_partition
,
output_partition_sizes
=
self
.
output_partition_sizes
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
...
...
@@ -354,6 +355,12 @@ class ColumnParallelLinear(LinearBase):
tp_rank
=
get_tensor_model_parallel_rank
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight
=
is_sharded_weight
or
use_bitsandbytes_4bit
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
is_gguf_weight_type
=
getattr
(
param
,
"is_gguf_weight_type"
,
False
)
...
...
@@ -362,13 +369,12 @@ class ColumnParallelLinear(LinearBase):
# Materialize GGUF UninitializedParameter
if
is_gguf_weight
and
isinstance
(
param
,
UninitializedParameter
):
param
.
materialize
(
loaded_weight
.
shape
,
dtype
=
loaded_weight
.
dtype
)
use_bitsandbytes_4bit
=
getattr
(
param
,
"use_bitsandbytes_4bit"
,
False
)
is_sharded_weight
=
getattr
(
param
,
"is_sharded_weight"
,
False
)
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight
=
is_sharded_weight
or
use_bitsandbytes_4bit
final_shape
=
list
(
loaded_weight
.
shape
)
if
output_dim
is
not
None
:
tp_size
=
get_tensor_model_parallel_world_size
()
assert
final_shape
[
output_dim
]
%
tp_size
==
0
final_shape
[
output_dim
]
=
final_shape
[
output_dim
]
//
tp_size
param
.
materialize
(
final_shape
,
dtype
=
loaded_weight
.
dtype
)
param_data
=
param
.
data
if
output_dim
is
not
None
and
not
is_sharded_weight
:
...
...
@@ -1058,22 +1064,24 @@ class RowParallelLinear(LinearBase):
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
# Divide the weight matrix along the first dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
output_size_per_partition
=
output_size
self
.
output_partition_sizes
=
[
output_size
]
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
)
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
# Divide the weight matrix along the last dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
layer
=
self
,
input_size_per_partition
=
self
.
input_size_per_partition
,
output_partition_sizes
=
[
self
.
output_size
]
,
output_partition_sizes
=
self
.
output_
partition_
size
s
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
...
...
vllm/model_executor/layers/logits_processor.py
View file @
ec5e299c
...
...
@@ -51,7 +51,6 @@ class LogitsProcessor(nn.Module):
# Soft cap the logits. Used in Gemma 2.
self
.
soft_cap
=
soft_cap
# Whether to use gather or all-gather to gather the logits.
parallel_config
=
get_current_vllm_config
().
parallel_config
self
.
use_all_gather
=
current_platform
.
is_tpu
()
\
or
envs
.
VLLM_USE_V1
\
...
...
@@ -88,17 +87,8 @@ class LogitsProcessor(nn.Module):
return
logits
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
def
_gather_logits
(
self
,
logits
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""gather/all-gather the logits tensor across model parallel group."""
if
self
.
use_all_gather
:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
...
...
@@ -109,6 +99,22 @@ class LogitsProcessor(nn.Module):
else
:
# None may be returned for rank > 0
logits
=
tensor_model_parallel_gather
(
logits
)
return
logits
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
lm_head
:
VocabParallelEmbedding
,
embedding_bias
:
Optional
[
torch
.
Tensor
],
)
->
Optional
[
torch
.
Tensor
]:
# Get the logits for the next tokens.
logits
=
lm_head
.
quant_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
# Gather logits for TP
logits
=
self
.
_gather_logits
(
logits
)
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
logits
=
logits
[...,
:
self
.
org_vocab_size
]
...
...
vllm/model_executor/layers/mamba/mamba_mixer2.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.placeholder_attn
import
(
PlaceholderAttentionMetadata
)
from
vllm.attention.backends.xformers
import
XFormersMetadata
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.mamba.ops.causal_conv1d
import
(
causal_conv1d_fn
,
causal_conv1d_update
)
from
vllm.model_executor.layers.mamba.ops.mamba_ssm
import
(
selective_state_update
)
from
vllm.model_executor.layers.mamba.ops.ssd_combined
import
(
mamba_chunk_scan_combined
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
(
LoaderFunction
,
composed_weight_loader
,
sharded_weight_loader
)
from
vllm.model_executor.models.mamba_cache
import
MambaCacheParams
from
vllm.model_executor.utils
import
set_weight_attrs
# Added by the IBM Team, 2024
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
    """Gated RMS normalisation whose weight is sharded across TP ranks.

    The input is multiplied by ``silu(gate)`` and then RMS-normalised, either
    over the whole hidden dimension (single group) or independently inside
    each group of ``group_size`` channels, and finally scaled by the
    per-rank ``weight``.
    """

    def __init__(self, full_hidden_size, full_n_groups, eps=1e-6):
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()

        # Sizes are expressed relative to the full (un-sharded) hidden size.
        channels_per_group = full_hidden_size // full_n_groups
        channels_per_rank = full_hidden_size // world_size

        self.tp_size = world_size
        self.tp_rank = rank
        self.full_hidden_size = full_hidden_size
        self.group_size = channels_per_group
        self.per_rank_hidden_size = channels_per_rank
        self.n_groups = full_hidden_size // channels_per_group

        self.variance_epsilon = eps
        # Each rank only owns its slice of the scale vector; the loader
        # shards the checkpoint tensor along dim 0.
        self.weight = nn.Parameter(torch.ones(channels_per_rank))
        set_weight_attrs(self.weight,
                         {"weight_loader": sharded_weight_loader(0)})
        assert self.full_hidden_size % self.tp_size == 0, \
            "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference implementation built from plain torch ops.

        Three tensor-parallel situations are handled:
          1. ``n_groups == 1``: the reduction dim itself is sharded, so each
             rank sums its local squares and the sums are all-reduced.
          2. ``tp_size`` divides ``n_groups``: every rank normalises its own
             group(s); no collective is issued.
          3. otherwise: the input is all-gathered, normalised redundantly on
             every rank, and the local slice is cut back out.
        """
        out_dtype = x.dtype
        eps = self.variance_epsilon
        # The gate is up-cast to float32 before the SiLU.
        gated = x * nn.functional.silu(gate.to(torch.float32))

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Local sum of squares, reduced to a global sum, then
                # divided by the total number of channels.
                partial_sq = gated.pow(2).sum(dim=-1, keepdim=True)
                total_sq = tensor_model_parallel_all_reduce(partial_sq)
                n_channels = self.tp_size * gated.shape[-1]
                mean_sq = (total_sq / n_channels)
            else:
                mean_sq = gated.pow(2).mean(-1, keepdim=True)
            normed = gated * torch.rsqrt(mean_sq + eps)
            return self.weight * normed.to(out_dtype)

        # Grouped normalisation.
        needs_gather: bool = self.n_groups % self.tp_size != 0
        if needs_gather:
            # General case: gather the full hidden dim and redo the work on
            # every rank.
            gated = tensor_model_parallel_all_gather(gated, -1)

        *lead_dims, width = gated.shape
        n_local_groups = width // self.group_size
        by_group = gated.view(*lead_dims, n_local_groups, self.group_size)
        mean_sq = by_group.pow(2).mean(-1, keepdim=True)
        by_group = by_group * torch.rsqrt(mean_sq + eps)
        normed = by_group.view(*lead_dims, width)

        if needs_gather:
            # Keep only the channels owned by this rank.
            lo = self.per_rank_hidden_size * self.tp_rank
            hi = lo + self.per_rank_hidden_size
            normed = normed[..., lo:hi]

        return self.weight * normed.to(out_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Use the fused ``rms_norm`` custom op when no TP/grouping applies.

        Any tensor-parallel or multi-group configuration is delegated to
        :meth:`forward_native`.
        """
        if self.tp_size > 1 or self.n_groups != 1:
            return self.forward_native(x, gate)

        from vllm import _custom_ops as ops

        # Only the gate is cast to float32 before the SiLU; the product is
        # cast back to x's dtype for the kernel.
        gated = x * nn.functional.silu(gate.to(torch.float32))
        out = torch.empty_like(x)
        ops.rms_norm(
            out,
            gated.to(x.dtype),
            self.weight.data,
            self.variance_epsilon,
        )
        return out
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    """Return how many groups must be added so tp_size divides the total.

    The result is 0 when ``ngroups`` is already a multiple of ``tp_size``;
    otherwise it is the distance from ``ngroups`` up to the next multiple.
    The extra groups are replicas that accompany the head shards.
    """
    leftover = ngroups % tp_size
    return tp_size - leftover if leftover != 0 else 0
def mamba_v2_sharded_weight_loader(
    shard_spec: List[Tuple[int, int, float]],
    tp_size: int,
    tp_rank: int,
) -> LoaderFunction:
    """Build a weight loader for the concatenated Mamba v2 projections.

    The checkpoint tensor is a concatenation of several segments along
    dim 0. Each entry of ``shard_spec`` describes one segment as
    ``(full_dim, extra, ratio)``:

    - ``full_dim``: segment size of the model before tensor parallelism
      (including any replicated dims).
    - ``extra``: how many of those dims exist only because of replication,
      i.e. are absent from the checkpoint.
    - ``ratio``: divisor mapping ``tp_rank`` to the shard rank to read
      from; with replication several TP ranks read the same shard.

    The returned loader copies, for every segment, this rank's slice of
    the checkpoint into the matching region of the sharded parameter.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # Running offsets into the sharded param and the checkpoint tensor.
        dst_offset = 0
        src_offset = 0

        for segment in shard_spec:
            full_dim, extra, ratio = segment

            # Rows each rank holds for this segment.
            rows_per_rank = full_dim // tp_size
            # Shard to read from; replicated ranks collapse onto one.
            src_rank = tp_rank // ratio

            skipped = src_rank * rows_per_rank
            src_begin = src_offset + skipped

            # Rows actually present in the checkpoint for this segment.
            rows_in_ckpt = full_dim - extra
            take = min(rows_per_rank, rows_in_ckpt - skipped)

            # Sharding is always along dim 0.
            dst_end = dst_offset + take
            src_end = src_begin + take
            param.data[dst_offset:dst_end,  # type: ignore[misc]
                       ...] = loaded_weight[
                           src_begin:src_end]  # type: ignore[misc]

            # Advance to the next segment.
            dst_offset += rows_per_rank
            src_offset += rows_in_ckpt

    return loader
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer2")
class MambaMixer2(CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)
    """

    def __init__(self,
                 hidden_size: int,
                 ssm_state_size: int,
                 conv_kernel_size: int,
                 intermediate_size: int,
                 use_conv_bias: bool,
                 use_bias: bool,
                 n_groups: int = 1,
                 num_heads: int = 128,
                 head_dim: int = 64,
                 rms_norm_eps: float = 1e-5,
                 activation="silu",
                 chunk_size: int = 256,
                 quant_config: Optional[QuantizationConfig] = None):
        """Create the projections, conv weights and SSM parameters.

        Args:
            hidden_size: input size of ``in_proj`` / output size of
                ``out_proj``.
            ssm_state_size: state size used in ``conv_dim`` and in the
                B/C split sizes.
            conv_kernel_size: input size of the linear layer that stores
                the conv1d weights.
            intermediate_size: size of the gated branch; input size of
                ``out_proj``.
            use_conv_bias: whether the conv layer carries a bias.
            use_bias: whether ``in_proj`` / ``out_proj`` carry a bias.
            n_groups: number of B/C groups before any TP replication.
            num_heads: number of heads; must be divisible by the TP size.
            head_dim: per-head dimension.
            rms_norm_eps: epsilon passed to the gated RMS norm.
            activation: activation name forwarded to the conv kernels.
            chunk_size: chunk size forwarded to the chunked scan.
            quant_config: quantization config for ``in_proj`` and
                ``out_proj`` (the conv layer is never quantized).
        """
        super().__init__()

        # For TP, the sharding plan is as follows:
        # - for the conv modules, since
        #   conv_dim = intermediate_size + 2 * n_groups * ssm_state_size,
        #   we shard intermediate_size and n_groups
        # - since intermediate_size = n_heads * head_dim, sharding on
        #   intermediate_size is achieved by sharding on n_heads.
        # - IF, world_size divides groups, then sharding
        #   (n_groups / world_size, n_heads / world_size)
        #   also maintains the invariant n_heads % n_groups == 0
        # - HOWEVER IF, world_size DOES NOT divide groups, then we need
        #   to allocate extra space in the shard, such that groups
        #   may be replicated to follow the head shard.
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        assert num_heads % self.tp_size == 0, \
            "Tensor parallel world size must divide num heads."

        self.ssm_state_size = ssm_state_size
        self.activation = activation

        self.chunk_size = chunk_size
        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_heads = num_heads

        self.n_groups = n_groups
        if n_groups % self.tp_size != 0:
            # - for TP we shard conv_dim by sharding on n_groups,
            # - but if tp_size does not divide n_groups, we need to
            #   extend some extra groups
            self.n_groups = n_groups + extra_groups_for_head_shards(
                n_groups, self.tp_size)

        self.conv_dim = (intermediate_size +
                         2 * self.n_groups * ssm_state_size)
        self.conv1d = ColumnParallelLinear(
            input_size=conv_kernel_size,
            output_size=self.conv_dim,
            bias=use_conv_bias,
            quant_config=None,
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        self.in_proj = ColumnParallelLinear(input_size=hidden_size,
                                            output_size=intermediate_size +
                                            self.conv_dim + self.num_heads,
                                            bias=use_bias,
                                            quant_config=quant_config)

        # - because in_proj is a concatenation of 3 weights, we
        #   need to interleave them before sharding
        # - use the custom weight loader mamba_v2_sharded_weight_loader
        #   for conv1d.bias, conv1d.weight and in_proj.weight
        # - need to set these settings, to assign the groups to the head shards
        group_shard_settings = (
            self.n_groups * self.ssm_state_size,  # expected model size
            (self.n_groups - n_groups) *
            self.ssm_state_size,  # extra dims assigned
            self.num_heads //
            n_groups,  # ratio for mapping back to original group
        )
        intermediate_settings = (intermediate_size, 0, 1)
        head_settings = (self.num_heads, 0, 1)

        # - the weight already has a "weight_loader" attribute
        #   which set_weight_attrs will raise if we do not
        #   delete before trying to override it
        # - ditto for the other two weights below
        delattr(self.conv1d.bias, "weight_loader")
        set_weight_attrs(
            self.conv1d.bias,
            {
                "weight_loader":
                mamba_v2_sharded_weight_loader(
                    [
                        intermediate_settings,
                        group_shard_settings,
                        group_shard_settings,
                    ],
                    self.tp_size,
                    tp_rank,
                )
            })

        delattr(self.conv1d.weight, "weight_loader")
        set_weight_attrs(
            self.conv1d.weight, {
                "weight_loader":
                mamba_v2_sharded_weight_loader([
                    intermediate_settings,
                    group_shard_settings,
                    group_shard_settings,
                ], self.tp_size, tp_rank)
            })

        delattr(self.in_proj.weight, "weight_loader")
        set_weight_attrs(
            self.in_proj.weight,
            {
                "weight_loader":
                mamba_v2_sharded_weight_loader(
                    [
                        intermediate_settings,  # for gate
                        intermediate_settings,
                        group_shard_settings,
                        group_shard_settings,
                        head_settings,  # for dt
                    ],
                    self.tp_size,
                    tp_rank)
            })

        # - these are TPed by heads to reduce the size of the
        #   temporal shape
        self.A = nn.Parameter(
            torch.empty(
                divide(num_heads, self.tp_size),
                dtype=torch.float32,
            ))
        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))

        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        # A is stored as -exp(checkpoint value) in float32.
        a_weight_loader = composed_weight_loader(
            sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
        set_weight_attrs(self.dt_bias,
                         {"weight_loader": sharded_weight_loader(0)})

        self.out_proj = RowParallelLinear(intermediate_size,
                                          hidden_size,
                                          bias=use_bias,
                                          input_is_parallel=True,
                                          quant_config=quant_config)

        # NOTE: the norm receives the original (un-extended) n_groups.
        self.norm = Mixer2RMSNormGated(intermediate_size,
                                       n_groups,
                                       eps=rms_norm_eps)

    def forward_native(self, hidden_states: torch.Tensor,
                       attn_metadata: AttentionMetadata,
                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
        """Placeholder: no native implementation exists; returns None."""
        pass

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: AttentionMetadata,
        mamba_cache_params: MambaCacheParams,
        sequence_idx: Optional[torch.Tensor] = None,
    ):
        """Run the mixer on a flattened batch of tokens.

        Args:
            hidden_states: 2-D tensor; its first dim is used as ``seq_len``.
            attn_metadata: supplies ``num_prefills``, ``query_start_loc``
                and (for the supported metadata types)
                ``context_lens_tensor``.
            mamba_cache_params: holds ``conv_state``, ``ssm_state`` and
                ``state_indices_tensor`` selecting the cache slots to use.
            sequence_idx: optional tensor forwarded as ``seq_idx`` to the
                chunked scan.

        Returns:
            The first output of ``out_proj``.
        """
        seq_len, _ = hidden_states.shape
        groups_time_state_size = self.n_groups * self.ssm_state_size

        # detect if there are prefills
        has_prefill = attn_metadata.num_prefills > 0

        # - also need flags to indicate if there are initial states
        # - currently we really only support the FlashAttention backend
        has_initial_states = None
        if (isinstance(attn_metadata,
                       (FlashAttentionMetadata, XFormersMetadata,
                        PlaceholderAttentionMetadata))
                and attn_metadata.context_lens_tensor is not None):
            has_initial_states = attn_metadata.context_lens_tensor > 0

        # 1. Gated MLP's linear projection
        projected_states, _ = self.in_proj(hidden_states)
        gate, hidden_states_B_C, dt = torch.split(
            projected_states,
            [
                self.intermediate_size // self.tp_size,
                self.conv_dim // self.tp_size,
                self.num_heads // self.tp_size,
            ],
            dim=-1,
        )

        # 2. Convolution sequence transformation
        # Drop the singleton dim that __init__ added with unsqueeze(1).
        conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                               self.conv1d.weight.size(2))

        if has_prefill:
            # |---------- N-1 iteration --------|
            # |---------------- N iteration ---------------------|
            # |- tokenA -|......................|-- newTokens ---|
            # |---------- context_len ----------|
            # |-------------------- seq_len ---------------------|
            #                                   |-- query_len ---|

            # - "cache_indices" updates the conv_state cache in positions
            #   pointed to by "mamba_cache_params.state_indices_tensor"
            hidden_states_B_C = causal_conv1d_fn(
                hidden_states_B_C.transpose(0, 1),
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=mamba_cache_params.conv_state,
                has_initial_state=has_initial_states,
                cache_indices=mamba_cache_params.state_indices_tensor,
                query_start_loc=attn_metadata.query_start_loc).transpose(
                    0, 1)[:seq_len]

            # TODO: Why is this needed?
            hidden_states_B_C = hidden_states_B_C.contiguous()
        else:
            hidden_states_B_C = causal_conv1d_update(
                hidden_states_B_C,
                mamba_cache_params.conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=mamba_cache_params.state_indices_tensor)

        # - get hidden_states, B and C after depthwise convolution.
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [
                self.intermediate_size // self.tp_size,
                groups_time_state_size // self.tp_size,
                groups_time_state_size // self.tp_size,
            ],
            dim=-1,
        )

        # 3. State Space Model sequence transformation
        if has_prefill:

            initial_states = None
            # torch.any reduces in a single op; the builtin any() would
            # iterate the tensor element by element in Python.
            if has_initial_states is not None and torch.any(
                    has_initial_states):
                # Clear the cache slots of sequences that start fresh, in
                # one indexed assignment instead of a per-slot loop.
                zero_init_indices = mamba_cache_params.state_indices_tensor[
                    ~has_initial_states]
                mamba_cache_params.ssm_state[zero_init_indices] = 0
                initial_states = mamba_cache_params.ssm_state[
                    mamba_cache_params.state_indices_tensor]

            scan_output, varlen_state = mamba_chunk_scan_combined(
                hidden_states.view(1, seq_len, self.num_heads // self.tp_size,
                                   self.head_dim),
                dt.unsqueeze(0),
                self.A,
                B.view(1, seq_len, self.n_groups // self.tp_size, -1),
                C.view(1, seq_len, self.n_groups // self.tp_size, -1),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                seq_idx=sequence_idx,
                cu_seqlens=attn_metadata.query_start_loc,
                initial_states=initial_states,
                return_varlen_states=True,
                return_final_states=False,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
            )

            # update ssm states
            # - varlen state is a (batch, nheads, headdim, dstate) tensor
            for i, idx in enumerate(mamba_cache_params.state_indices_tensor):
                mamba_cache_params.ssm_state[idx].copy_(varlen_state[i])

            # - reshape
            hidden_states = scan_output.view(seq_len, -1)
        else:

            n_groups = self.n_groups // self.tp_size
            # Broadcast the per-head parameters to the shapes expected by
            # selective_state_update.
            A = self.A[:, None, ...][:, :, None].expand(
                -1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D = self.D[:, None, ...].expand(-1, self.head_dim)
            B = B.view(-1, n_groups, B.shape[1] // n_groups)
            C = C.view(-1, n_groups, C.shape[1] // n_groups)
            hidden_states_reshaped = hidden_states.view(
                -1, self.num_heads // self.tp_size, self.head_dim)

            # - the hidden is reshaped into number of current batches
            # - in this case there is no more prefill, so the batches gen
            #   1 token at a time
            # - thus hidden will be (bs, num_heads, head_dim)
            # - mamba_cache_params.ssm_state's slots will be selected
            #   using "mamba_cache_params.state_indices_tensor", just as
            #   above in the prefill case

            hidden_states = selective_state_update(
                mamba_cache_params.ssm_state,
                hidden_states_reshaped,
                dt,
                A,
                B,
                C,
                D,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
                state_batch_indices=mamba_cache_params.state_indices_tensor,
            )
            hidden_states = hidden_states.view(
                -1, (self.num_heads // self.tp_size) * self.head_dim)

        # # 4. gated MLP
        hidden_states = self.norm(hidden_states, gate)

        # # 5. Final linear projection
        out, _ = self.out_proj(hidden_states)
        return out
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/
main
/mamba_ssm/ops/triton/selective_state_update.py
# Adapted from https://github.com/state-spaces/mamba/blob/
v2.2.4
/mamba_ssm/ops/triton/selective_state_update.py
import
torch
import
triton
...
...
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# ruff: noqa: E501,SIM102
import
math
import
torch
import
triton
import
triton.language
as
tl
# Triton kernel: for every (batch, chunk, group) it computes one
# BLOCK_SIZE_M x BLOCK_SIZE_N tile of the chunk-local product a @ b^T,
# reducing over K. Autotuned over the tile configurations below, keyed on
# chunk_size, K and IS_CAUSAL.
@triton.autotune(
    configs=[
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 64
            },
            num_stages=3,
            num_warps=8),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 256,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 128,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 128,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=4),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 32,
                'BLOCK_SIZE_K': 32
            },
            num_stages=5,
            num_warps=2),
        triton.Config(
            {
                'BLOCK_SIZE_M': 32,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32
            },
            num_stages=5,
            num_warps=2),
        triton.Config(
            {
                'BLOCK_SIZE_M': 64,
                'BLOCK_SIZE_N': 64,
                'BLOCK_SIZE_K': 32
            },
            num_stages=4,
            num_warps=2),
    ],
    key=['chunk_size', 'K', 'IS_CAUSAL'],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    out_ptr,
    seq_idx_ptr,
    # Matrix dimensions
    seqlen,
    chunk_size,
    K,
    ngroups,
    stride_a_batch,
    stride_a_seqlen,
    stride_a_head,
    stride_ak,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_bk,
    stride_out_batch,
    stride_out_chunk,
    stride_out_head,
    stride_outm,
    stride_outn,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Grid layout: axis 0 = flattened (tile row, tile col) inside the chunk,
    # axis 1 = batch, axis 2 = flattened (chunk, group).
    pid_b = tl.program_id(axis=1)
    pid_ch = tl.program_id(axis=2).to(tl.int64)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    if IS_CAUSAL:
        # Skip tiles whose first column index is past the tile's last row,
        # i.e. tiles lying entirely above the diagonal. Nothing is stored
        # for them, so those output entries are left unwritten.
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return
    # Advance the base pointers to this program's batch, chunk and group.
    a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # a is read as an (M, K) tile; b is read as a (K, N) tile, i.e. the
    # seqlen axis of b plays the role of the output column.
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
                      offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
                      offs_n[None, :] * stride_b_seqlen)
    # Number of valid positions in this chunk (the last chunk may be short).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    # Accumulate in float32 regardless of the input/dot dtype.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Out-of-range rows/cols and the K tail are loaded as 0.
        a = tl.load(a_ptrs,
                    mask=(offs_m[:, None] < chunk_size_limit) &
                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
                    other=0.0).to(dot_dtype)
        b = tl.load(b_ptrs,
                    mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
                    (offs_n[None, :] < chunk_size_limit),
                    other=0.0).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    if HAS_SEQ_IDX:
        chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
        # Different fill values (-1 vs -2) guarantee that two masked-out
        # positions never compare equal.
        seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
                            mask=offs_m < chunk_size_limit,
                            other=-1)
        seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
                            mask=offs_n < chunk_size_limit,
                            other=-2)
        # Zero every entry whose row and column belong to different
        # sequence ids.
        acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
    out = acc.to(out_ptr.dtype.element_ty)

    out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
                          offs_n[None, :] * stride_outn)
    # The store is bounded by chunk_size (not chunk_size_limit): padded
    # positions of a short chunk are written too.
    tl.store(out_ptrs,
             out,
             mask=(offs_m[:, None] < chunk_size) &
             (offs_n[None, :] < chunk_size))
def _bmm_chunk_fwd(a,
                   b,
                   chunk_size,
                   seq_idx=None,
                   causal=False,
                   output_dtype=None):
    """
    Chunk-wise batched matmul: for every chunk of ``chunk_size`` positions
    along seqlen, compute ``a_chunk @ b_chunk^T``.

    Argument:
        a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
        b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
        seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
        causal: if True, the kernel skips tiles lying entirely above the
            diagonal, so out[i, j] for j > i may be arbitrary (the output
            buffer is uninitialized); only out[i, j] for j <= i are
            guaranteed to be correct.
        output_dtype: dtype of the result; defaults to a.dtype.
    Return:
        out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
    """
    # Check constraints.
    has_groups = a.dim() == 4
    if not has_groups:
        batch, seqlen, k = a.shape
    else:
        batch, seqlen, ngroups, k = a.shape
    assert b.shape == a.shape
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)
    # Only copy when neither the last dim nor the seqlen dim is unit-stride.
    if a.stride(-1) != 1 and a.stride(1) != 1:
        a = a.contiguous()
    if b.stride(-1) != 1 and b.stride(1) != 1:
        b = b.contiguous()
    # Integer ceiling division; avoids the float round-trip of
    # math.ceil(seqlen / chunk_size).
    nchunks = triton.cdiv(seqlen, chunk_size)
    # Allocates output.
    out_dtype = a.dtype if output_dtype is None else output_dtype
    out = torch.empty(
        (batch, nchunks, chunk_size, chunk_size) if not has_groups else
        (batch, nchunks, ngroups, chunk_size, chunk_size),
        device=a.device,
        dtype=out_dtype)
    # Precision of the tl.dot operands: bf16 wins over fp16, which wins
    # over fp32.
    dot_dtype = (tl.bfloat16
                 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
                 (tl.float16 if a.dtype == torch.float16
                  or b.dtype == torch.float16 else tl.float32))
    # Grid: axis 0 = tiles inside a chunk, axis 1 = batch,
    # axis 2 = chunks (times groups when present).
    grid = lambda META: (triton.cdiv(
        chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
            chunk_size, META['BLOCK_SIZE_N']), batch, nchunks
                         if not has_groups else nchunks * ngroups)
    with torch.cuda.device(a.device.index):
        _bmm_chunk_fwd_kernel[grid](
            a,
            b,
            out,
            seq_idx,
            seqlen,
            chunk_size,
            k,
            ngroups if has_groups else 1,
            a.stride(0),
            a.stride(1),
            # A zero group stride is passed when there is no group dim.
            0 if not has_groups else a.stride(2),
            a.stride(-1),
            b.stride(0),
            b.stride(1),
            0 if not has_groups else b.stride(2),
            b.stride(-1),
            out.stride(0),
            out.stride(1),
            0 if not has_groups else out.stride(2),
            out.stride(-2),
            out.stride(-1),
            *((seq_idx.stride(0),
               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            causal,
            dot_dtype,
            HAS_SEQ_IDX=seq_idx is not None,
        )
    return out
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
# ruff: noqa: E501,SIM102
import
math
import
torch
import
triton
import
triton.language
as
tl
from
packaging
import
version
# True when the installed Triton version is at least 2.2.0.
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
)
@triton.jit
def _chunk_scan_fwd_kernel(
    # Pointers to matrices
    cb_ptr,
    x_ptr,
    z_ptr,
    out_ptr,
    out_x_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    C_ptr,
    states_ptr,
    D_ptr,
    initstates_ptr,
    chunk_indices_ptr,
    chunk_offsets_ptr,
    chunk_meta_num,
    # Matrix dimensions
    chunk_size,
    hdim,
    dstate,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_cb_batch,
    stride_cb_chunk,
    stride_cb_head,
    stride_cb_csize_m,
    stride_cb_csize_k,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_z_batch,
    stride_z_seqlen,
    stride_z_head,
    stride_z_hdim,
    stride_out_batch,
    stride_out_seqlen,
    stride_out_head,
    stride_out_hdim,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    stride_C_batch,
    stride_C_seqlen,
    stride_C_head,
    stride_C_dstate,
    stride_states_batch,
    stride_states_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_init_states_batch,
    stride_init_states_head,
    stride_init_states_hdim,
    stride_init_states_dstate,
    stride_D_head,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M x BLOCK_SIZE_N) output tile of the chunk scan.

    Program layout (as decoded below):
      * axis 0: fused (row-tile, hdim-tile) index, split with ``num_pid_n``.
      * axis 1: fused index decoded as ``pid_c = pid // batch`` and
        ``pid_b = pid - pid_c * batch``. With HAS_INITSTATES, ``pid_c`` is
        looked up in ``chunk_indices_ptr`` / ``chunk_offsets_ptr`` to obtain
        the chunk index ``c_idx`` and a starting row offset ``c_off``.
      * axis 2: head index.

    The tile accumulates, in float32:
      1. ``C @ prev_states`` scaled per row by ``scale_m`` (an exponential of
         the cumulative dA, optionally zeroed or re-based, see below);
      2. the sum over k-blocks of ``cb * exp(dA_cs_m - dA_cs_k) * dt_k`` times
         ``x`` (lower-triangular masked when IS_CAUSAL);
      3. ``x * D`` when HAS_D;
    and, when HAS_Z, stores the un-gated tile to ``out_x`` and multiplies the
    accumulator by ``z * sigmoid(z)`` before the final store to ``out``.
    """
    # Decode the fused batch/chunk program id (64-bit to keep the pointer
    # arithmetic below in int64).
    pid_bc = tl.program_id(axis=1).to(tl.int64)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    if not HAS_INITSTATES:
        # No chunk metadata: the program id is the chunk index, no row offset.
        c_idx = pid_c
        c_off = 0
    else:
        # NOTE(review): the mask `pid_c > -1` looks always-true for a program
        # id; presumably it only exists to make this a masked load - confirm.
        c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
        c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)

    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Advance every base pointer to this program's (batch, chunk, head).
    # cb and C are indexed by group (pid_h // nheads_ngroups_ratio), not head.
    cb_ptr += pid_b * stride_cb_batch + c_idx * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    x_ptr += pid_b * stride_x_batch + c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += pid_b * stride_C_batch + c_idx * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head

    # M-block offsets and prev states
    #  - logic in next block may override these if there is an active offset
    offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
    prev_states_ptr = states_ptr + pid_b * stride_states_batch + c_idx * stride_states_chunk + pid_h * stride_states_head
    prev_states_hdim = stride_states_hdim
    prev_states_dstate = stride_states_dstate

    # Number of valid rows in this chunk (the last chunk may be partial).
    chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen

        # - we only need seq_idx_prev to be aligned to chunk boundary
        #   (it is the seq_idx of the last position of the previous chunk,
        #   or 0 for the first chunk)
        seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0)

        if HAS_INITSTATES:
            # if there are init states, we only need seq_idx_m to point
            # what is the current seq_idx

            # get current seq idx
            if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
                seq_idx_m = tl.load(
                    seq_idx_ptr + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )

                # - recall that in ssd_state_passing, for the case c_off == 0
                # i.e., the very first sequence, we made states_ptr hold its initial state
                # so this edge case is taken care of
                if ((c_off == 0) and (seq_idx_prev != seq_idx_m)  # if a seq is changed exactly on boundary
                        or (c_off > 0)  # implies a new example (pseudo chunk)
                    ):

                    # - replace prev_states_ptr with init_states
                    #   (indexed by sequence id rather than by batch/chunk)
                    prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
                    prev_states_hdim = stride_init_states_hdim  # override strides
                    prev_states_dstate = stride_init_states_dstate

    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Cumulative dA for the rows of this tile (0.0 outside the chunk).
    dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32)

    # - handle chunk state limit
    if HAS_INITSTATES:

        # have to split this if otherwise compilation will have problems
        dA_cs_m_boundary = 0.0

        # get the c_idx for the next (logical) chunk
        c_idx_n = tl.load(
            chunk_indices_ptr + (pid_c + 1),
            mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
            other=-1  # to trigger different chunk
        )

        # - there are things to consider
        # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
        #    contribution of past states
        # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
        #    encroach into the next sequence, where c_off_n is the offset of the next
        #    (logical) chunk.
        # An equivalent check for B is c_idx == c_idx_n, where there is repetition in
        # (logical) chunk indices.

        if (c_idx == c_idx_n) or c_off > 0:

            # get the next offset
            c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
                              mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
                              other=chunk_size)

            # in this case, adjust down the chunk_size_limit
            if c_idx == c_idx_n:
                chunk_size_limit = min(c_off_n, chunk_size_limit)

            # get the cs at the offset boundary
            # - c_off == 0 is a passthrough
            #   (the mask is false for index -1, so 0.0 is loaded)
            dA_cs_m_boundary = tl.load(
                dA_cumsum_ptr + (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize,
                mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)),
                other=0.0).to(tl.float32)

    if HAS_SEQ_IDX:
        # - handle seq idx when HAS_INITSTATES==False
        #   (per-row sequence ids; -1 for rows past the valid range)
        if not HAS_INITSTATES:
            seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Without the if (pid_c > -1), with Triton 2.1.0, I get
    # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
    # With Triton 2.2.0, this works
    if IS_TRITON_22 or c_idx > -1:
        # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
        offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
        C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)

        # prev_states is addressed as (dstate, hdim) so it can be the RHS of
        # the dot product with C (rows, dstate).
        prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * prev_states_hdim + offs_k_dstate[:, None] * prev_states_dstate)
        if HAS_SEQ_IDX:

            if not HAS_INITSTATES:
                # - this is for continuous batching where there is no init states
                #   (rows whose sequence differs from the previous chunk's
                #   last sequence get zero contribution from prev_states)
                scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
            else:
                # - if there is initstates, we will rely on prev_states, no zeroing
                #   required.
                scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
        else:
            scale_m = tl.exp(dA_cs_m)
        if BLOCK_SIZE_DSTATE <= 128:
            C = tl.load(C_ptrs,
                        mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate),
                        other=0.0)

            prev_states = tl.load(prev_states_ptrs,
                                  mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
                                  other=0.0)
            # Cast to C's element type so both dot operands match.
            prev_states = prev_states.to(C_ptr.dtype.element_ty)
            acc = tl.dot(C, prev_states) * scale_m[:, None]
        else:
            for k in range(0, dstate, BLOCK_SIZE_K):
                C = tl.load(C_ptrs,
                            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k),
                            other=0.0)
                # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
                prev_states = tl.load(prev_states_ptrs,
                                      mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                                      other=0.0)
                prev_states = prev_states.to(C_ptr.dtype.element_ty)
                acc += tl.dot(C, prev_states)
                # NOTE(review): these advances are not multiplied by the
                # dstate strides - presumably assumes unit stride along
                # dstate for both C and the states; confirm.
                C_ptrs += BLOCK_SIZE_K
                prev_states_ptrs += BLOCK_SIZE_K
            acc *= scale_m[:, None]

    # Intra-chunk contribution: iterate over k-blocks of source positions.
    offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    # With causal masking, columns beyond this row-tile's last row contribute
    # nothing, so the loop can stop early.
    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        cb = tl.load(cb_ptrs,
                     mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
                     other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
        # So we don't need masking wrt seq_idx here.
        cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        cb *= dt_k
        if IS_CAUSAL:
            # keep only entries where the row index >= the column index
            mask = offs_m[:, None] >= k + offs_k[None, :]
            cb = tl.where(mask, cb, 0.0)
        cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(x_ptrs,
                    mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
                    other=0.0)
        acc += tl.dot(cb, x)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if HAS_D:
        # Skip-connection term x * D; D is either per-(head, hdim) or per-head.
        if D_HAS_HDIM:
            # NOTE(review): offs_n is added without an hdim stride for D -
            # presumably D is contiguous along hdim; confirm.
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
                             mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
                             other=0.0).to(tl.float32)
        acc += x_residual * D

    if HAS_Z:
        # Store the pre-gating result to out_x, then gate acc with z * sigmoid(z).
        out_x_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
        # NOTE(review): out_x reuses the `out` strides and adds offs_out_n
        # without stride_out_hdim - presumably relies on out_x having the same
        # strides as out with unit hdim stride; confirm against the launcher.
        out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :])
        tl.store(out_x_ptrs,
                 acc,
                 mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))

        z_ptr += pid_b * stride_z_batch + c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
        z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :])
        z = tl.load(z_ptrs,
                    mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
                    other=0.0).to(tl.float32)
        acc *= z * tl.sigmoid(z)

    # Final store of the tile; rows past chunk_size_limit and columns past
    # hdim are masked out.
    out_ptr += pid_b * stride_out_batch + c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim)
    tl.store(out_ptrs,
             acc,
             mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim))
def
_seq_idx_to_chunk_indices_offsets
(
seq_idx
,
chunk_size
:
int
):
# convert seq_idx to chunk indices and offsets
# - derive the cu_seqlens
_
,
cu_seqlens
=
torch
.
where
(
seq_idx
.
diff
())
cu_seqlens
+=
1
# outputs will have length expansion of chunks that do not divide
# chunk_size
N
=
math
.
ceil
(
seq_idx
.
shape
[
-
1
]
/
chunk_size
)
+
(
cu_seqlens
%
chunk_size
>
0
).
sum
()
chunk_indices
=
torch
.
arange
(
N
,
dtype
=
torch
.
int
,
device
=
seq_idx
.
device
)
chunk_offsets
=
torch
.
zeros
((
N
,
),
dtype
=
torch
.
int
,
device
=
seq_idx
.
device
)
cu_seqlens
=
cu_seqlens
.
tolist
()
+
[
seq_idx
.
shape
[
-
1
]]
p
=
0
# num of insertions
for
s
,
e
in
zip
(
cu_seqlens
[:
-
1
],
cu_seqlens
[
1
:]):
# if does not divide chunk_size, then there is one chunk insertion
p
+=
(
s
%
chunk_size
>
0
)
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s
,
_e
=
s
//
chunk_size
+
p
,
e
//
chunk_size
+
p
+
(
e
%
chunk_size
>
0
)
# adjust inidces and offsets
chunk_indices
[
_s
:
_e
]
-=
p
chunk_offsets
[
_s
]
=
s
%
chunk_size
return
chunk_indices
,
chunk_offsets
def _chunk_scan_fwd(
    cb,
    x,
    dt,
    dA_cumsum,
    C,
    states,
    D=None,
    z=None,
    seq_idx=None,
    initial_states=None,
):
    """Check shapes, allocate outputs and launch ``_chunk_scan_fwd_kernel``.

    Shapes enforced by the assertions below:
        x:          (batch, seqlen, nheads, headdim)
        dt:         (batch, nheads, nchunks, chunk_size)
        dA_cumsum:  (batch, nheads, nchunks, chunk_size)
        C:          (batch, seqlen, ngroups, dstate), nheads % ngroups == 0
        cb:         (batch, nchunks, ngroups, chunk_size, chunk_size)
        states:     (batch, nchunks, nheads, headdim, dstate)
        D:          (nheads, headdim) or (nheads,), optional
        z:          same shape as x, optional
        seq_idx:    (batch, seqlen), optional
        initial_states: (seq_idx[0].max() + 1, nheads, headdim, dstate),
            optional; only validated when seq_idx is given, and then
            batch must be 1.

    Returns:
        ``(out, out_x)``: ``out`` has x's shape and dtype; ``out_x`` is a
        second tensor of the same shape when ``z`` is given, else ``None``.

    NOTE(review): when ``initial_states`` is given without ``seq_idx`` the
    kernel is still launched with HAS_INITSTATES=True while the chunk
    index/offset tables are ``None`` - confirm callers never do this.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = C.shape

    has_d = D is not None
    has_z = z is not None
    has_seq_idx = seq_idx is not None

    assert nheads % ngroups == 0
    assert C.shape == (batch, seqlen, ngroups, dstate)
    assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
    if has_z:
        assert z.shape == x.shape
    if has_d:
        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
    assert states.shape == (batch, nchunks, nheads, headdim, dstate)

    chunk_indices = None
    chunk_offsets = None
    if has_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

        if initial_states is not None:
            # Sequences may start mid-chunk, so the kernel needs the
            # logical-chunk tables to know where initial states apply.
            assert batch == 1, "chunk scan only supports initial states with batch 1"
            assert initial_states.shape == (seq_idx[0].max() + 1, nheads,
                                            headdim, dstate)

            if initial_states.shape[0] != 1:
                chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets(
                    seq_idx, chunk_size)
            else:
                # A single sequence: the initial-state path is not used.
                initial_states = None

    has_initstates = initial_states is not None

    # Output buffers (out_x only when z is supplied).
    out = torch.empty(batch,
                      seqlen,
                      nheads,
                      headdim,
                      device=x.device,
                      dtype=x.dtype)
    out_x = None
    if has_z:
        out_x = torch.empty(batch,
                            seqlen,
                            nheads,
                            headdim,
                            device=x.device,
                            dtype=x.dtype)
        # The kernel addresses out_x with out's strides.
        assert out_x.stride() == out.stride()

    def grid(META):
        # axis 0: row tiles x headdim tiles; axis 1: (batch, chunk) pairs or
        # logical chunks; axis 2: heads.
        num_tiles = triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
            headdim, META['BLOCK_SIZE_N'])
        if chunk_offsets is None:
            num_chunk_programs = batch * nchunks
        else:
            num_chunk_programs = len(chunk_offsets)
        return (num_tiles, num_chunk_programs, nheads)

    # Stride groups, in the positional order the kernel signature expects.
    cb_strides = [cb.stride(i) for i in range(5)]
    x_strides = [x.stride(i) for i in range(4)]
    if has_z:
        z_strides = (z.stride(0), z.stride(1), z.stride(2), z.stride(3))
    else:
        z_strides = (0, 0, 0, 0)
    out_strides = [out.stride(i) for i in range(4)]
    # dt / dA_cumsum are stored (batch, head, chunk, csize) but the kernel
    # takes (batch, chunk, head, csize) strides, hence the 0, 2, 1, 3 order.
    dt_strides = (dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3))
    dA_cs_strides = (dA_cumsum.stride(0), dA_cumsum.stride(2),
                     dA_cumsum.stride(1), dA_cumsum.stride(3))
    if has_seq_idx:
        seq_idx_strides = (seq_idx.stride(0), seq_idx.stride(1))
    else:
        seq_idx_strides = (0, 0)
    C_strides = [C.stride(i) for i in range(4)]
    states_strides = [states.stride(i) for i in range(5)]
    if has_initstates:
        init_strides = (initial_states.stride(0), initial_states.stride(1),
                        initial_states.stride(2), initial_states.stride(3))
    else:
        init_strides = (0, 0, 0, 0)
    d_stride = D.stride(0) if has_d else 0
    d_has_hdim = D.dim() == 2 if has_d else True
    chunk_meta_num = len(chunk_indices) if chunk_indices is not None else 0

    _chunk_scan_fwd_kernel[grid](
        cb,
        x,
        z,
        out,
        out_x,
        dt,
        dA_cumsum,
        seq_idx,
        C,
        states,
        D,
        initial_states,
        chunk_indices,
        chunk_offsets,
        chunk_meta_num,
        chunk_size,
        headdim,
        dstate,
        batch,
        seqlen,
        nheads // ngroups,
        *cb_strides,
        *x_strides,
        *z_strides,
        *out_strides,
        *dt_strides,
        *dA_cs_strides,
        *seq_idx_strides,
        *C_strides,
        *states_strides,
        *init_strides,
        d_stride,
        True,  # IS_CAUSAL
        has_d,  # HAS_D
        d_has_hdim,  # D_HAS_HDIM
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        HAS_Z=has_z,
        HAS_SEQ_IDX=has_seq_idx,
        IS_TRITON_22=TRITON_22,
        HAS_INITSTATES=has_initstates,
    )
    return out, out_x
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
# ruff: noqa: E501
import
math
import
torch
import
triton
import
triton.language
as
tl
from
.mamba_ssm
import
softplus
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_H': 1}),
        triton.Config({'BLOCK_SIZE_H': 2}),
        triton.Config({'BLOCK_SIZE_H': 4}),
        triton.Config({'BLOCK_SIZE_H': 8}),
        triton.Config({'BLOCK_SIZE_H': 16}),
        triton.Config({'BLOCK_SIZE_H': 32}),
        triton.Config({'BLOCK_SIZE_H': 64}),
    ],
    key=['chunk_size', 'nheads'],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
    # Pointers to matrices
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    dt_out_ptr,
    dA_cumsum_ptr,
    # Matrix dimension
    batch,
    seqlen,
    nheads,
    chunk_size,
    dt_min,
    dt_max,
    # Strides
    stride_dt_batch,
    stride_dt_seqlen,
    stride_dt_head,
    stride_A_head,
    stride_dt_bias_head,
    stride_dt_out_batch,
    stride_dt_out_chunk,
    stride_dt_out_head,
    stride_dt_out_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_CHUNK: tl.constexpr,
):
    """Post-process dt for one chunk and write dt_out and cumsum(dt * A).

    Each program handles one (batch, chunk) pair and a block of
    BLOCK_SIZE_H heads (axes 0, 1 and 2 of the grid respectively). For every
    head/position it: loads dt, optionally adds the per-head bias, optionally
    applies softplus (only where dt <= 20.0), clamps to [dt_min, dt_max],
    zeroes positions outside the valid range, stores the result to dt_out,
    and stores the cumulative sum along the chunk of ``dt * A`` to dA_cumsum.
    """
    pid_b = tl.program_id(axis=0)

    # if dt is long, may cause problems, so use 64 bit
    # https://github.com/triton-lang/triton/issues/1058
    pid_c = tl.program_id(axis=1).to(tl.int64)
    pid_h = tl.program_id(axis=2)
    # Move base pointers to this (batch, chunk); the input dt is indexed by
    # sequence position, the outputs by chunk.
    dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
    dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk

    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
    A_ptrs = A_ptr + offs_h * stride_A_head
    dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
    dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)
    # Number of real sequence positions in this chunk (last chunk may be short).
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)

    dt = tl.load(dt_ptrs,
                 mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
                 other=0.0).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        # Values above 20.0 are passed through unchanged instead of going
        # through softplus.
        dt = tl.where(dt <= 20.0, softplus(dt), dt)
    # As of Triton 2.2.0, tl.clamp is not available yet
    # dt = tl.clamp(dt, dt_min, dt_max)
    dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
    # Re-zero padded entries: bias/softplus/clamp above may have made the
    # masked-out (loaded as 0.0) positions non-zero.
    dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
    # Stores cover the full chunk width (chunk_size, not chunk_size_limit), so
    # positions past the sequence end are written as 0.0.
    tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    dA = dt * A[:, None]
    dA_cs = tl.cumsum(dA, axis=1)
    tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    states_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    batch,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_batch,
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_batch,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_states_batch,
    stride_states_chunk,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_dt_batch,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    # Meta-parameters
    HAS_SEQ_IDX: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (hdim-tile x dstate-tile) block of a chunk's state.

    Grid: axis 0 is a fused (hdim-tile, dstate-tile) index, axis 1 a fused
    index decoded as ``pid_c = pid // batch``, ``pid_b = pid - pid_c * batch``,
    axis 2 the head. The result is the sum over the chunk's valid positions k
    of ``x[k] (outer) b[k] * exp(dA_cs_last - dA_cs[k]) * dt[k]``; when
    HAS_SEQ_IDX, positions whose sequence id differs from the chunk's last
    valid position contribute zero.
    """
    pid_bc = tl.program_id(axis=1).to(tl.int64)
    pid_c = pid_bc // batch
    pid_b = pid_bc - pid_c * batch
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Advance base pointers to this (batch, chunk, head); b is indexed by
    # group (pid_h // nheads_ngroups_ratio) rather than by head.
    b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # x is addressed as (hdim, position) and b as (position, dstate), so
    # tl.dot(x, b) contracts over the position axis.
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # Cumulative dA at the last slot of the chunk (index chunk_size - 1, not
    # chunk_size_limit - 1).
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    if HAS_SEQ_IDX:
        seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen

    # Number of real sequence positions in this chunk.
    chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
    if HAS_SEQ_IDX:
        # Sequence id at the last valid position of the chunk.
        seq_idx_last = tl.load(seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(x_ptrs,
                    mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
                    other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        if HAS_SEQ_IDX:
            # -1 for out-of-range positions.
            seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        if not HAS_SEQ_IDX:
            scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
        else:
            # Only positions in the same sequence as the chunk's last valid
            # position contribute.
            scale = tl.where(seq_idx_k == seq_idx_last, tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
        b *= scale[:, None]
        # Cast back to x's element type so both dot operands match.
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
        if HAS_SEQ_IDX:
            seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
    states = acc.to(states_ptr.dtype.element_ty)

    # Write the tile, masking rows past hdim and columns past dstate.
    states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),
    ],
    key=['hdim', 'dstate', 'chunk_size'],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    initstates_ptr,
    # Matrix dimensions
    hdim,
    dstate,
    chunk_size,
    seqlen,
    nheads_ngroups_ratio,
    # Strides
    stride_x_seqlen,
    stride_x_head,
    stride_x_hdim,
    stride_b_seqlen,
    stride_b_head,
    stride_b_dstate,
    stride_dt_chunk,
    stride_dt_head,
    stride_dt_csize,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_dA_cs_csize,
    stride_chunk_states_chunk,
    stride_chunk_states_head,
    stride_chunk_states_hdim,
    stride_chunk_states_dstate,
    stride_states_batch,
    stride_states_head,
    stride_states_hdim,
    stride_states_dstate,
    stride_init_states_batch,
    stride_init_states_head,
    stride_init_states_hdim,
    stride_init_states_dstate,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    """Compute one (hdim-tile x dstate-tile) block of a sequence's final state.

    Grid: axis 0 is a fused (hdim-tile, dstate-tile) index, axis 1 indexes a
    sequence via ``cu_seqlens_ptr`` (start = cu_seqlens[pid_b],
    end = cu_seqlens[pid_b + 1]), axis 2 is the head. Only the chunk holding
    the sequence's last token (``pid_c``) is processed: the tile accumulates
    the contribution of this sequence's tokens inside that chunk and then
    adds a decayed carried-in state, taken either from ``chunk_states_ptr``
    or, with HAS_INITSTATES and a sequence starting inside this chunk, from
    ``initstates_ptr``.
    """
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # End (exclusive) of this sequence and the chunk holding its last token.
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    b_ptr += pid_c * chunk_size * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head

    if HAS_INITSTATES:
        # if there are init states provided, we differentiate between states (which
        # are boundary conditions at a chunk boundary) and initstates (which are boundary
        # conditions when a new example in a cont batch starts)
        initstates_ptr += pid_h * stride_init_states_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # Cumulative dA at the sequence's last token within this chunk.
    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Within-chunk window [start_idx_cur, chunk_size_limit) of this sequence.
    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(x_ptrs,
                    mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k) &
                    (offs_k[None, :] >= start_idx_cur - k),
                    other=0.0)
        b = tl.load(b_ptrs,
                    mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate) &
                    (offs_k[:, None] >= start_idx_cur - k),
                    other=0.0).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        # Zero the scale for positions outside this sequence's window.
        scale = tl.where((offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
                         tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
    # If HAS_INITSTATES==True need to consider two possibilities
    # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
    # - if start_idx >= pid_c * chunk_size, then we need to insert initstates
    if ((start_idx < pid_c * chunk_size)  # first chunk
            or (HAS_INITSTATES)):

        dA_cs_boundary = 0.0  # default

        if not HAS_INITSTATES:
            past_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim +
                                                   offs_n[None, :] * stride_chunk_states_dstate)
        else:

            # - this seems repetitive, but it is to help the compiler
            if start_idx < pid_c * chunk_size:
                past_states_ptrs = chunk_states_ptr + (offs_m[:, None] * stride_chunk_states_hdim +
                                                       offs_n[None, :] * stride_chunk_states_dstate)
            else:
                past_states_ptrs = initstates_ptr + (pid_b * stride_init_states_batch +
                                                     offs_m[:, None] * stride_init_states_hdim +
                                                     offs_n[None, :] * stride_init_states_dstate)

                # need to adjust the boundary
                # (decay is measured from the token just before the sequence
                # start rather than from the start of the chunk)
                if start_idx > pid_c * chunk_size:
                    dA_cs_boundary = tl.load(dA_cumsum_ptr +
                                             (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)

        past_states = tl.load(past_states_ptrs,
                              mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
                              other=0.0).to(tl.float32)

        scale = tl.exp(dA_cs_last - dA_cs_boundary)
        acc += past_states * scale

    states = acc.to(states_ptr.dtype.element_ty)

    # Write the tile for this (sequence, head).
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate)
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)
def _chunk_cumsum_fwd(dt,
                      A,
                      chunk_size,
                      dt_bias=None,
                      dt_softplus=False,
                      dt_limit=(0.0, float("inf"))):
    """Launch ``_chunk_cumsum_fwd_kernel`` over every (batch, chunk, head).

    Args:
        dt: tensor of shape (batch, seqlen, nheads).
        A: tensor of shape (nheads,).
        chunk_size: number of sequence positions per chunk.
        dt_bias: optional tensor of shape (nheads,).
        dt_softplus: forwarded to the kernel as DT_SOFTPLUS.
        dt_limit: ``(min, max)`` pair forwarded as the kernel's clamp bounds.

    Returns:
        ``(dA_cumsum, dt_out)``: two float32 tensors of shape
        (batch, nheads, nchunks, chunk_size) on ``dt.device``.
    """
    batch, seqlen, nheads = dt.shape
    has_bias = dt_bias is not None

    assert A.shape == (nheads, )
    if has_bias:
        assert dt_bias.shape == (nheads, )

    nchunks = math.ceil(seqlen / chunk_size)
    out_shape = (batch, nheads, nchunks, chunk_size)
    dt_out = torch.empty(*out_shape, device=dt.device, dtype=torch.float32)
    dA_cumsum = torch.empty(*out_shape, device=dt.device, dtype=torch.float32)

    def grid_chunk_cs(META):
        # One program per (batch, chunk, block of BLOCK_SIZE_H heads).
        return (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))

    dt_min, dt_max = dt_limit[0], dt_limit[1]
    bias_stride = dt_bias.stride(0) if has_bias else 0
    # Outputs are laid out (batch, head, chunk, csize) but the kernel takes
    # (batch, chunk, head, csize) strides, hence the 0, 2, 1, 3 order.
    dt_out_strides = (dt_out.stride(0), dt_out.stride(2), dt_out.stride(1),
                      dt_out.stride(3))
    dA_cs_strides = (dA_cumsum.stride(0), dA_cumsum.stride(2),
                     dA_cumsum.stride(1), dA_cumsum.stride(3))

    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt,
            A,
            dt_bias,
            dt_out,
            dA_cumsum,
            batch,
            seqlen,
            nheads,
            chunk_size,
            dt_min,
            dt_max,
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            A.stride(0),
            bias_stride,
            *dt_out_strides,
            *dA_cs_strides,
            dt_softplus,
            HAS_DT_BIAS=has_bias,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out
def _chunk_state_fwd(B,
                     x,
                     dt,
                     dA_cumsum,
                     seq_idx=None,
                     states=None,
                     states_in_fp32=True):
    """Launch ``_chunk_state_fwd_kernel`` to fill the per-chunk states.

    Shapes enforced by the assertions below:
        x:          (batch, seqlen, nheads, headdim)
        B:          (batch, seqlen, ngroups, dstate), nheads % ngroups == 0
        dt:         (batch, nheads, nchunks, chunk_size)
        dA_cumsum:  same shape as dt
        seq_idx:    (batch, seqlen), optional
        states:     (batch, nchunks, nheads, headdim, dstate), optional

    When ``states`` is not supplied a new tensor is allocated on ``x.device``
    with dtype float32 if ``states_in_fp32`` else ``B.dtype``.

    Returns:
        The ``states`` tensor (the caller's buffer or the new allocation).
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, nchunks, chunk_size = dt.shape
    _, _, ngroups, dstate = B.shape
    has_seq_idx = seq_idx is not None

    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert dt.shape == (batch, nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    if has_seq_idx:
        assert seq_idx.shape == (batch, seqlen)

    states_shape = (batch, nchunks, nheads, headdim, dstate)
    if states is None:
        if states_in_fp32:
            states_dtype = torch.float32
        else:
            states_dtype = B.dtype
        states = torch.empty(states_shape,
                             device=x.device,
                             dtype=states_dtype)
    else:
        assert states.shape == states_shape

    def grid(META):
        # axis 0: headdim tiles x dstate tiles; axis 1: (batch, chunk) pairs;
        # axis 2: heads.
        num_tiles = triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(
            dstate, META['BLOCK_SIZE_N'])
        return (num_tiles, batch * nchunks, nheads)

    x_strides = [x.stride(i) for i in range(4)]
    B_strides = (B.stride(0), B.stride(1), B.stride(2), B.stride(-1))
    states_strides = [states.stride(i) for i in range(5)]
    # dt / dA_cumsum are stored (batch, head, chunk, csize) but the kernel
    # takes (batch, chunk, head, csize) strides, hence the 0, 2, 1, 3 order.
    dt_strides = (dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3))
    dA_cs_strides = (dA_cumsum.stride(0), dA_cumsum.stride(2),
                     dA_cumsum.stride(1), dA_cumsum.stride(3))
    if has_seq_idx:
        seq_idx_strides = (seq_idx.stride(0), seq_idx.stride(1))
    else:
        seq_idx_strides = (0, 0)

    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x,
            B,
            states,
            dt,
            dA_cumsum,
            seq_idx,
            headdim,
            dstate,
            chunk_size,
            batch,
            seqlen,
            nheads // ngroups,
            *x_strides,
            *B_strides,
            *states_strides,
            *dt_strides,
            *dA_cs_strides,
            *seq_idx_strides,
            HAS_SEQ_IDX=has_seq_idx,
        )
    return states
def chunk_state_varlen(B,
                       x,
                       dt,
                       dA_cumsum,
                       cu_seqlens,
                       chunk_states,
                       initial_states=None):
    """Launch ``_chunk_state_varlen_kernel`` to produce one state per sequence.

    All inputs are un-batched (no leading batch axis); the sequences are
    delimited by ``cu_seqlens``.

    Args:
        B: tensor of shape (total_seqlen, ngroups, dstate).
        x: tensor of shape (total_seqlen, nheads, headdim).
        dt: tensor of shape (nheads, nchunks, chunk_size).
        dA_cumsum: tensor with the same shape as ``dt``.
        cu_seqlens: 1-D tensor with ``num_sequences + 1`` entries; it is made
            contiguous before being handed to the kernel.
        chunk_states: tensor of shape (nchunks, nheads, headdim, dstate).
        initial_states: optional tensor of shape
            (num_sequences, nheads, headdim, dstate).

    Returns:
        A new tensor of shape (num_sequences, nheads, headdim, dstate) with
        the dtype and device of ``chunk_states``.
    """
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    has_initstates = initial_states is not None
    per_seq_shape = (num_seqs, nheads, headdim, dstate)

    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    if has_initstates:
        assert initial_states.shape == per_seq_shape

    states = torch.empty(*per_seq_shape,
                         dtype=chunk_states.dtype,
                         device=chunk_states.device)

    def grid(meta):
        # Axis 0: (headdim tile, dstate tile) pairs; axis 1: sequences;
        # axis 2: heads.
        tiles_m = triton.cdiv(headdim, meta['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(dstate, meta['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n, num_seqs, nheads)

    if has_initstates:
        initstates_strides = tuple(
            initial_states.stride(axis) for axis in (0, 1, 2, 3))
    else:
        initstates_strides = (0, 0, 0, 0)

    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x,
            B,
            dt,
            dA_cumsum,
            chunk_states,
            cu_seqlens,
            states,
            initial_states,
            headdim,
            dstate,
            chunk_size,
            total_seqlen,
            nheads // ngroups,
            *(x.stride(axis) for axis in (0, 1, 2)),
            *(B.stride(axis) for axis in (0, 1, 2)),
            # dt / dA_cumsum strides are passed as (1, 0, 2): the chunk
            # stride comes before the head stride.
            *(dt.stride(axis) for axis in (1, 0, 2)),
            *(dA_cumsum.stride(axis) for axis in (1, 0, 2)),
            *(chunk_states.stride(axis) for axis in (0, 1, 2, 3)),
            *(states.stride(axis) for axis in (0, 1, 2, 3)),
            *initstates_strides,
            HAS_INITSTATES=has_initstates)
    return states
vllm/model_executor/layers/mamba/ops/ssd_combined.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
# ruff: noqa: E501
import
torch
import
triton
from
einops
import
rearrange
from
packaging
import
version
from
.ssd_bmm
import
_bmm_chunk_fwd
from
.ssd_chunk_scan
import
_chunk_scan_fwd
from
.ssd_chunk_state
import
(
_chunk_cumsum_fwd
,
_chunk_state_fwd
,
chunk_state_varlen
)
from
.ssd_state_passing
import
_state_passing_fwd
# True when the installed Triton release is at least 2.2.0. The comparison
# goes through packaging.version so that versions are ordered numerically
# rather than as plain strings.
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
def _mamba_chunk_scan_combined_fwd(x,
                                   dt,
                                   A,
                                   B,
                                   C,
                                   chunk_size,
                                   D=None,
                                   z=None,
                                   dt_bias=None,
                                   initial_states=None,
                                   seq_idx=None,
                                   cu_seqlens=None,
                                   dt_softplus=False,
                                   dt_limit=(0.0, float("inf"))):
    """Run the five-stage chunked scan and return all intermediate results.

    Args:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads,)
        B: (batch, seqlen, ngroups, dstate)
        C: same shape as ``B``
        chunk_size: chunk length forwarded to the sub-kernels.
        D: optional, (nheads, headdim) or (nheads,)
        z: optional, same shape as ``x``
        dt_bias: optional, forwarded to ``_chunk_cumsum_fwd``.
        initial_states: optional, (batch, nheads, headdim, dstate), or
            (len(cu_seqlens) - 1, nheads, headdim, dstate) when
            ``cu_seqlens`` is given.
        seq_idx: optional, (batch, seqlen)
        cu_seqlens: optional; when given, batch must be 1 and per-sequence
            (varlen) states are computed in addition.
        dt_softplus: forwarded to ``_chunk_cumsum_fwd``.
        dt_limit: forwarded to ``_chunk_cumsum_fwd``.

    Returns:
        ``(out, out_x, dt, dA_cumsum, states, final_states)`` and, when
        ``cu_seqlens`` is not None, a seventh element ``varlen_states``.
        ``dt`` here is the tensor returned by ``_chunk_cumsum_fwd``, not the
        input argument.
    """
    batch, seqlen, nheads, headdim = x.shape
    _, _, ngroups, dstate = B.shape
    has_init = initial_states is not None
    is_varlen = cu_seqlens is not None

    # ---- shape validation -------------------------------------------------
    assert nheads % ngroups == 0
    assert B.shape == (batch, seqlen, ngroups, dstate)
    assert x.shape == (batch, seqlen, nheads, headdim)
    assert dt.shape == (batch, seqlen, nheads)
    assert A.shape == (nheads, )
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads, )
    if seq_idx is not None:
        assert seq_idx.shape == (batch, seqlen)

    # ---- layout normalisation ---------------------------------------------
    def _last_dim_contiguous(t):
        # Copy only when the innermost stride is not already 1.
        return t if t.stride(-1) == 1 else t.contiguous()

    def _m_or_k_contiguous(t):
        # Either the M or the K dimension should be contiguous.
        if t.stride(-1) == 1 or t.stride(1) == 1:
            return t
        return t.contiguous()

    B = _last_dim_contiguous(B)
    C = _last_dim_contiguous(C)
    x = _m_or_k_contiguous(x)
    if z is not None:
        z = _m_or_k_contiguous(z)
    if D is not None:
        D = _last_dim_contiguous(D)

    if has_init:
        # One initial state per batch row, or per packed sequence when
        # cu_seqlens describes a continuous batch.
        expected_rows = len(cu_seqlens) - 1 if is_varlen else batch
        assert initial_states.shape == (expected_rows, nheads, headdim,
                                        dstate)

    # The computation below is split into 5 sub-functions.
    # - background reading: https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
    #   which contains a minimal reference implementation of these steps
    # - per that blog, mamba can be viewed as a special case of causal
    #   attention: the attention matrix is chunked and each sub-matrix is
    #   computed with its own optimisation.
    # - the comments below refer to the sub-matrices pictured in the blog
    #   and the paper.

    # 1. Chunked cumulative sum of A * dt
    #    - dt may be passed through a softplus activation here
    dA_cumsum, dt = _chunk_cumsum_fwd(dt,
                                      A,
                                      chunk_size,
                                      dt_bias=dt_bias,
                                      dt_softplus=dt_softplus,
                                      dt_limit=dt_limit)

    # 2. Intra-chunk states
    #    (right term of the low-rank factorization of the off-diagonal
    #    blocks; B terms)
    states = _chunk_state_fwd(B,
                              x,
                              dt,
                              dA_cumsum,
                              seq_idx=seq_idx,
                              states_in_fp32=True)

    # 3. Inter-chunk SSM recurrence; yields the correct SSM states at the
    #    chunk boundaries (middle term of the off-diagonal factorization;
    #    A terms)
    #    - chunked prefill needs i) initial_states, ii) seq_idx and
    #      iii) has_cu_seqlens all to be specified.
    #    - when a new seq_idx is detected the previous state is no longer
    #      passed on; the init_state of the new seq_idx is used instead.
    #    - states are therefore updated with the rightmost flushed seq_idx
    #      of the previous chunk, so the first chunk of states is either 0
    #      or the init_states of the first example.
    flat_pattern = "... p n -> ... (p n)"
    flat_initial_states = None
    if has_init:
        flat_initial_states = rearrange(initial_states, flat_pattern)
    states, final_states = _state_passing_fwd(
        rearrange(states, flat_pattern),
        dA_cumsum[:, :, :, -1],
        initial_states=flat_initial_states,
        seq_idx=seq_idx,
        chunk_size=chunk_size,
        out_dtype=C.dtype,
        is_cont_batched=is_varlen)
    # Restore the separate (headdim, dstate) axes.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)
    final_states = rearrange(final_states, "... (p n) -> ... p n", n=dstate)

    # 4. Batched matrix multiply for the C_j^T B_i terms
    CB = _bmm_chunk_fwd(C,
                        B,
                        chunk_size,
                        seq_idx=seq_idx,
                        output_dtype=torch.float32)

    # 5. Scan and compute the diagonal blocks, taking past causal states
    #    into account.
    #    - when initial states are given, the states information is
    #      augmented with initial_states.
    #    - doing this properly requires handling example changes inside the
    #      continuous batch, hence pseudo chunks: a chunk that is split each
    #      time the example changes.
    #    - in every (pseudo) chunk we detect whether the previous (pseudo)
    #      chunk had a seq_idx change, in which case the states information
    #      is taken from init_states.
    out, out_x = _chunk_scan_fwd(
        CB,
        x,
        dt,
        dA_cumsum,
        C,
        states,
        D=D,
        z=z,
        seq_idx=seq_idx,
        initial_states=initial_states,
    )

    fwd_outputs = (out, out_x, dt, dA_cumsum, states, final_states)
    if not is_varlen:
        return fwd_outputs

    assert batch == 1, "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
    varlen_states = chunk_state_varlen(
        B.squeeze(0),
        x.squeeze(0),
        dt.squeeze(0),
        dA_cumsum.squeeze(0),
        cu_seqlens,
        states.squeeze(0),
        initial_states=initial_states,
    )
    return fwd_outputs + (varlen_states, )
def mamba_chunk_scan_combined(x,
                              dt,
                              A,
                              B,
                              C,
                              chunk_size,
                              D=None,
                              z=None,
                              dt_bias=None,
                              initial_states=None,
                              seq_idx=None,
                              cu_seqlens=None,
                              dt_softplus=False,
                              dt_limit=(0.0, float("inf")),
                              return_final_states=False,
                              return_varlen_states=False):
    """Public entry point for the chunked scan.

    Argument:
        x: (batch, seqlen, nheads, headdim)
        dt: (batch, seqlen, nheads)
        A: (nheads)
        B: (batch, seqlen, ngroups, dstate)
        C: (batch, seqlen, ngroups, dstate)
        chunk_size: int
        D: (nheads, headdim) or (nheads,)
        z: (batch, seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        seq_idx: (batch, seqlen)
        cu_seqlens: (num_sequences + 1) or None; ignored (reset to None)
            unless return_varlen_states is True, in which case it is required
        dt_softplus: Whether to apply softplus to dt
        dt_limit: forwarded unchanged to the forward implementation
        return_final_states: also return final_states
        return_varlen_states: also return varlen_states
    Return:
        out: (batch, seqlen, nheads, headdim) when neither flag is set;
        otherwise a tuple starting with out, followed by final_states
        (if return_final_states) and then varlen_states
        (if return_varlen_states).
    """
    if return_varlen_states:
        assert cu_seqlens is not None, "cu_seqlens must be provided if return_varlen_states is True"
    else:
        # cu_seqlens only matters for the varlen states; drop it otherwise.
        cu_seqlens = None

    out, _, _, _, _, final_states, *extra = _mamba_chunk_scan_combined_fwd(
        x,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        seq_idx=seq_idx,
        cu_seqlens=cu_seqlens,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit)

    # Assemble the requested outputs in their fixed order.
    results = [out]
    if return_final_states:
        results.append(final_states)
    if return_varlen_states:
        results.append(extra[0])
    return out if len(results) == 1 else tuple(results)
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
0 → 100644
View file @
ec5e299c
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
# ruff: noqa: E501
import
torch
import
triton
import
triton.language
as
tl
# Triton kernel that carries SSM states across chunk boundaries.
#
# Each program handles one (batch, head) pair and one BLOCK_SIZE-wide slice of
# the flattened state dimension `dim`. It walks the chunks in order and applies
#     states = exp(dA_cs[c]) * states + new_states[c]
# writing the running state to `out` and the state after the last chunk to
# `final_states`.
#
# Output layout: `out[..., 0, ...]` receives the starting state (zeros, or the
# loaded initial state); for c < nchunks - 1 the state after chunk c is stored
# in chunk slot c + 1 of `out`.
#
# BLOCK_SIZE is autotuned over the configs below, keyed on `dim`.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE': 64}),
        triton.Config({'BLOCK_SIZE': 128}),
        triton.Config({'BLOCK_SIZE': 256}),
        triton.Config({'BLOCK_SIZE': 512}),
        triton.Config({'BLOCK_SIZE': 1024}),
        triton.Config({'BLOCK_SIZE': 2048}),
    ],
    key=['dim'],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr,
    out_ptr,
    final_states_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    # Matrix dimensions
    dim,
    nchunks,
    seqlen,
    chunk_size,
    # Strides
    stride_states_batch,
    stride_states_chunk,
    stride_states_head,
    stride_states_dim,
    stride_out_batch,
    stride_out_chunk,
    stride_out_head,
    stride_out_dim,
    stride_final_states_batch,
    stride_final_states_head,
    stride_final_states_dim,
    stride_dA_cs_batch,
    stride_dA_cs_chunk,
    stride_dA_cs_head,
    stride_initstates_batch,
    stride_initstates_head,
    stride_initstates_dim,
    stride_seq_idx_batch,
    stride_seq_idx_seqlen,
    # Meta-parameters (compile-time constants)
    HAS_INITSTATES: tl.constexpr,
    HAS_SEQ_IDX: tl.constexpr,
    IS_CONT_BATCHED: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Launch-grid mapping: axis 0 -> block along `dim`, axis 1 -> batch,
    # axis 2 -> head.
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    pid_m = tl.program_id(axis=0)
    # Advance every base pointer to this program's (batch, head) slice.
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    final_states_ptr += pid_b * stride_final_states_batch + pid_h * stride_final_states_head
    if HAS_INITSTATES:
        initstates_ptr += pid_h * stride_initstates_head
        # - in the continuous-batching case no batch offset is applied here;
        #   the row of initstates is selected by sequence index further below
        if not IS_CONT_BATCHED:
            initstates_ptr += pid_b * stride_initstates_batch

    if HAS_SEQ_IDX:
        seq_idx_ptr += pid_b * stride_seq_idx_batch

    # Offsets of this program's elements along `dim`; every load/store below
    # is masked with `offs_m < dim` to guard the partial last block.
    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim

    # - states will be the past state of the sequence that continues on the
    #   current chunk
    if not HAS_INITSTATES:
        states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
    else:
        initstates_ptr += offs_m * stride_initstates_dim
        initstates_ptrs = initstates_ptr
        # - for cont batches, for the first chunk this means it will be the
        #   first batch's init state
        states = tl.load(initstates_ptrs, mask=offs_m < dim,
                         other=0.0).to(tl.float32)

    # The starting state is what chunk 0 sees as its "previous" state.
    tl.store(out_ptrs, states, mask=offs_m < dim)
    out_ptrs += stride_out_chunk
    # Sequence index tracked across chunks; starts at 0.
    seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim,
                             other=0.0).to(tl.float32)
        # Single scalar per (batch, head, chunk); exp() of it is the decay
        # applied to the carried state.
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        scale = tl.exp(dA_cs)
        if HAS_SEQ_IDX:
            # - the seq to pass forward is the one that is flushed to the right
            #   boundary.
            # - that is given by seq_idx_new below (the sequence index at the
            #   last position of this chunk, clamped to seqlen - 1).
            seq_idx_new = tl.load(seq_idx_ptr +
                                  (min((c + 1) * chunk_size, seqlen) - 1) *
                                  stride_seq_idx_seqlen)

            if HAS_INITSTATES:
                if IS_CONT_BATCHED and seq_idx != seq_idx_new:
                    # this means in the current chunk the rightmost flushed seq
                    # has changed.
                    # - so we do not propagate the state from previous chunk
                    # - but rather we load that sequence's init state
                    initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch

                    # - update state with seq_idx_new's init state
                    states = tl.load(initstates_ptrs,
                                     mask=offs_m < dim,
                                     other=0.0).to(tl.float32)
            else:
                # No initial states: a change of sequence index zeroes the
                # decay so the previous state is dropped.
                scale = tl.where(seq_idx_new == seq_idx, scale, 0.0)

            seq_idx = seq_idx_new
        states = scale * states + new_states
        # All but the last chunk's result go to the next slot of `out`; the
        # last one goes to `final_states`.
        if c < nchunks - 1:
            tl.store(out_ptrs, states, mask=offs_m < dim)
        else:
            tl.store(final_states_ptrs, states, mask=offs_m < dim)
        # Step all per-chunk pointers to the next chunk.
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk
def _state_passing_fwd(
    states,
    dA_chunk_cumsum,
    initial_states=None,
    seq_idx=None,
    chunk_size=None,
    out_dtype=None,
    is_cont_batched=False,
):
    """Validate inputs, allocate outputs and launch ``_state_passing_fwd_kernel``.

    Args:
        states: tensor of shape (batch, nchunks, nheads, dim).
        dA_chunk_cumsum: tensor of shape (batch, nheads, nchunks).
        initial_states: optional tensor. Shape (batch, nheads, dim) in the
            regular case, or (seq_idx.max() + 1, nheads, dim) when
            ``is_cont_batched`` is true.
        seq_idx: optional tensor of shape (batch, seqlen). Required when
            ``initial_states`` is given together with ``is_cont_batched``.
        chunk_size: chunk length; required whenever ``seq_idx`` is given.
        out_dtype: dtype of ``out``; defaults to ``states.dtype``.
        is_cont_batched: whether ``initial_states`` holds one entry per
            packed sequence (continuous batching) instead of per batch row.

    Returns:
        ``(out, final_states)`` where ``out`` has shape
        (batch, nchunks, nheads, dim) and dtype ``out_dtype``, and
        ``final_states`` has shape (batch, nheads, dim) and dtype float32.

    Raises:
        AssertionError: if any of the shape / argument requirements above is
            violated.
    """
    batch, nchunks, nheads, dim = states.shape
    assert dA_chunk_cumsum.shape == (batch, nheads, nchunks)
    if initial_states is not None:
        if is_cont_batched:
            # - if cu_seqlens is provided, then the initial states
            #   are used for continuous batching. In which case we
            #   require seq_idx to be provided
            assert seq_idx is not None, (
                "seq_idx must be provided when initial_states are used "
                "with continuous batching (is_cont_batched=True)")
            # NOTE: .item() reads the maximum back as a Python number.
            assert initial_states.shape == (seq_idx.max().item() + 1, nheads,
                                            dim)
        else:
            # - this is the regular batching case, where initial
            #   states are used for each example of the batch.
            assert initial_states.shape == (batch, nheads, dim)

    # Values handed to the kernel when no seq_idx is supplied.
    seqlen = 0
    kernel_chunk_size = 0
    if seq_idx is not None:
        assert chunk_size is not None, (
            "chunk_size must be provided when seq_idx is given")
        seqlen = seq_idx.shape[-1]
        assert seq_idx.shape == (batch, seqlen)
        kernel_chunk_size = chunk_size

    out_dtype = states.dtype if out_dtype is None else out_dtype
    out = torch.empty((batch, nchunks, nheads, dim),
                      device=states.device,
                      dtype=out_dtype)
    final_states = torch.empty((batch, nheads, dim),
                               device=states.device,
                               dtype=torch.float32)

    # Grid: axis 0 tiles `dim`, axis 1 is the batch, axis 2 the heads.
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), batch, nheads)
    with torch.cuda.device(states.device.index):
        _state_passing_fwd_kernel[grid](
            states,
            out,
            final_states,
            dA_chunk_cumsum,
            initial_states,
            seq_idx,
            dim,
            nchunks,
            seqlen,
            kernel_chunk_size,
            states.stride(0),
            states.stride(1),
            states.stride(2),
            states.stride(3),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            out.stride(3),
            final_states.stride(0),
            final_states.stride(1),
            final_states.stride(2),
            # dA_chunk_cumsum is (batch, nheads, nchunks) but the kernel
            # takes its strides as (batch, chunk, head).
            dA_chunk_cumsum.stride(0),
            dA_chunk_cumsum.stride(2),
            dA_chunk_cumsum.stride(1),
            *((initial_states.stride(0), initial_states.stride(1),
               initial_states.stride(2)) if initial_states is not None else
              (0, 0, 0)),
            *((seq_idx.stride(0),
               seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
            HAS_INITSTATES=initial_states is not None,
            HAS_SEQ_IDX=seq_idx is not None,
            IS_CONT_BATCHED=is_cont_batched,
        )
    return out, final_states
vllm/model_executor/layers/quantization/__init__.py
View file @
ec5e299c
...
...
@@ -11,6 +11,7 @@ QUANTIZATION_METHODS: List[str] = [
"deepspeedfp"
,
"tpu_int8"
,
"fp8"
,
"ptpc_fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
# The order of gptq methods is important for config.py iteration over
...
...
@@ -99,6 +100,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
from
.modelopt
import
ModelOptFp8Config
from
.moe_wna16
import
MoeWNA16Config
from
.neuron_quant
import
NeuronQuantConfig
from
.ptpc_fp8
import
PTPCFp8Config
from
.qqq
import
QQQConfig
from
.tpu_int8
import
Int8TpuConfig
...
...
@@ -120,6 +122,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"gptq"
:
GPTQConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"ptpc_fp8"
:
PTPCFp8Config
,
"qqq"
:
QQQConfig
,
"hqq"
:
HQQMarlinConfig
,
"experts_int8"
:
ExpertsInt8Config
,
...
...
Prev
1
…
16
17
18
19
20
21
22
23
24
…
27
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment