Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ca796e19
Commit
ca796e19
authored
Mar 21, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.1' into v0.8.1-ori
parents
e983c804
61c7a1b8
Changes
130
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1649 additions
and
577 deletions
+1649
-577
vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
...Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
+164
-0
vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
...tinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
+164
-0
vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
...Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
+164
-0
vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
...tinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
+164
-0
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+25
-6
vllm/model_executor/models/gemma3_mm.py
vllm/model_executor/models/gemma3_mm.py
+211
-36
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+18
-51
vllm/model_executor/models/mllama.py
vllm/model_executor/models/mllama.py
+4
-7
vllm/model_executor/models/molmo.py
vllm/model_executor/models/molmo.py
+4
-11
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+16
-57
vllm/model_executor/models/vision.py
vllm/model_executor/models/vision.py
+49
-1
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+111
-36
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+2
-2
vllm/transformers_utils/tokenizers/mistral.py
vllm/transformers_utils/tokenizers/mistral.py
+0
-4
vllm/triton_utils/__init__.py
vllm/triton_utils/__init__.py
+1
-8
vllm/triton_utils/custom_cache_manager.py
vllm/triton_utils/custom_cache_manager.py
+0
-55
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+15
-10
vllm/v1/outputs.py
vllm/v1/outputs.py
+1
-1
vllm/v1/sample/ops/utils.py
vllm/v1/sample/ops/utils.py
+30
-0
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+506
-292
No files found.
vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
0 → 100644
View file @
ca796e19
{
"1"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"8"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"16"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"24"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"32"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"48"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"64"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"96"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"128"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"256"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"512"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1024"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1536"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2048"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"3072"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4096"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
}
}
\ No newline at end of file
vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
0 → 100644
View file @
ca796e19
{
"1"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"8"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"16"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"24"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"32"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"48"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"64"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"96"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"128"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"256"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"512"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1024"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1536"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2048"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"3072"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4096"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
}
}
\ No newline at end of file
vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json
0 → 100644
View file @
ca796e19
{
"1"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"8"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"16"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"24"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"32"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"48"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"64"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"96"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"128"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"256"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"512"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1024"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1536"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2048"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"3072"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4096"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
}
}
\ No newline at end of file
vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json
0 → 100644
View file @
ca796e19
{
"1"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"8"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"16"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"24"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"32"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"48"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
8
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"64"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
16
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"96"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"128"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"256"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"512"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
32
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1024"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"1536"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"2048"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"3072"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
},
"4096"
:
{
"BLOCK_SIZE_K"
:
128
,
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"GROUP_SIZE_M"
:
32
,
"kpack"
:
1
,
"matrix_instr_nonkdim"
:
16
,
"num_warps"
:
4
}
}
\ No newline at end of file
vllm/model_executor/model_loader/loader.py
View file @
ca796e19
...
@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -762,7 +762,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
model_name_or_path
:
str
,
model_name_or_path
:
str
,
allowed_patterns
:
List
[
str
],
allowed_patterns
:
List
[
str
],
revision
:
Optional
[
str
]
=
None
,
revision
:
Optional
[
str
]
=
None
,
)
->
Tuple
[
List
[
str
],
str
]:
)
->
Tuple
[
str
,
List
[
str
],
str
]:
"""Retrieve weight files. Download the files if necessary.
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
Return the weight files and the file pattern."""
...
@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -773,7 +773,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
weight_files
=
glob
.
glob
(
weight_files
=
glob
.
glob
(
os
.
path
.
join
(
model_name_or_path
,
pattern
))
os
.
path
.
join
(
model_name_or_path
,
pattern
))
if
weight_files
:
if
weight_files
:
return
weight_files
,
pattern
return
model_name_or_path
,
weight_files
,
pattern
else
:
else
:
hf_api
=
HfApi
()
hf_api
=
HfApi
()
repo_files
=
hf_api
.
list_repo_files
(
repo_id
=
model_name_or_path
)
repo_files
=
hf_api
.
list_repo_files
(
repo_id
=
model_name_or_path
)
...
@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -787,7 +787,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
revision
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
)
return
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
)),
pattern
return
hf_folder
,
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
)),
pattern
raise
RuntimeError
(
raise
RuntimeError
(
f
"No model weights found in: `
{
model_name_or_path
}
`"
)
f
"No model weights found in: `
{
model_name_or_path
}
`"
)
...
@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -798,10 +799,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
allowed_patterns
=
[
"*.safetensors"
,
"*.bin"
,
"*.pt"
]
allowed_patterns
=
[
"*.safetensors"
,
"*.bin"
,
"*.pt"
]
hf_weights_files
,
matched_pattern
=
self
.
_get_weight_files
(
hf_folder
,
hf_weights_files
,
matched_pattern
=
self
.
_get_weight_files
(
model_name_or_path
,
allowed_patterns
,
revision
)
model_name_or_path
,
allowed_patterns
,
revision
)
if
matched_pattern
!=
"*.safetensors"
:
use_safetensors
=
matched_pattern
==
"*.safetensors"
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
index_file
=
SAFE_WEIGHTS_INDEX_NAME
if
use_safetensors
:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if
not
is_local
:
download_safetensors_index_file_from_hf
(
model_name_or_path
,
index_file
,
self
.
load_config
.
download_dir
,
revision
,
)
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
hf_weights_files
)
...
@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
...
@@ -809,7 +828,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise
RuntimeError
(
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_weights_files
,
matched_pattern
==
"*.
safetensors
"
return
hf_weights_files
,
use_
safetensors
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
def
_hf_weight_iter
(
self
,
hf_weights_files
,
use_safetensors
:
bool
):
if
use_safetensors
:
if
use_safetensors
:
...
...
vllm/model_executor/models/gemma3_mm.py
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
import
math
import
math
from
typing
import
(
Any
,
Iterable
,
Literal
,
Mapping
,
Optional
,
Sequence
,
Set
,
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
Tuple
,
TypedDict
,
Union
)
from
typing
import
Any
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
BatchFeature
,
Gemma3Config
,
Gemma3Processor
from
transformers
import
BatchFeature
,
Gemma3Config
,
Gemma3Processor
from
transformers.models.gemma3.processing_gemma3
import
Gemma3ProcessorKwargs
from
transformers.models.gemma3.processing_gemma3
import
Gemma3ProcessorKwargs
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
from
vllm.model_executor.layers.layernorm
import
GemmaRMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
# yapf: disable
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
BoundPromptUpdate
,
PromptUpdate
,
encode_tokens
)
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptTargetMatch
,
PromptUpdate
,
PromptUpdateDetails
,
encode_tokens
,
find_mm_placeholders
,
replace_token_matches
)
# yapf: enable
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
flatten_2d_lists
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
SupportsMultiModal
,
SupportsPP
,
SupportsV0Only
)
SupportsMultiModal
,
SupportsPP
)
from
.siglip
import
SiglipVisionModel
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
scatter_patch_features
,
select_patch_features
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
...
@@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
pixel_values
:
torch
.
Tensor
pixel_values
:
torch
.
Tensor
"""
"""
Shape: `(num_
crop
s_total, num_channels, height, width)`
Shape: `(num_
patche
s_total, num_channels, height, width)`
`num_
crop
s_total` is the total number of
crop
s
`num_
patche
s_total` is the total number of
patche
s
over each image over each prompt in the batch.
over each image over each prompt in the batch.
"""
"""
num_crops
:
torch
.
Tensor
"""Shape: `(batch_size * num_images,)`"""
num_patches
:
torch
.
Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
Gemma3ImageInputs
=
Gemma3ImagePixelInputs
...
@@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
...
@@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
class
Gemma3ProcessingInfo
(
BaseProcessingInfo
):
class
Gemma3ProcessingInfo
(
BaseProcessingInfo
):
def
get_hf_config
(
self
):
return
self
.
ctx
.
get_hf_config
(
Gemma3Config
)
def
get_hf_processor
(
self
,
**
kwargs
:
object
):
def
get_hf_processor
(
self
,
**
kwargs
:
object
):
return
self
.
ctx
.
get_hf_processor
(
Gemma3Processor
,
**
kwargs
)
return
self
.
ctx
.
get_hf_processor
(
Gemma3Processor
,
**
kwargs
)
...
@@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if
not
do_pan_and_scan
:
if
not
do_pan_and_scan
:
return
0
return
0
if
envs
.
VLLM_USE_V1
:
logger
.
warning_once
(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used."
)
# Based on Gemma3ImageProcessor.pan_and_scan
# Based on Gemma3ImageProcessor.pan_and_scan
if
image_width
>=
image_height
:
if
image_width
>=
image_height
:
if
image_width
/
image_height
<
pan_and_scan_min_ratio_to_activate
:
if
image_width
/
image_height
<
pan_and_scan_min_ratio_to_activate
:
...
@@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_width
:
int
,
image_width
:
int
,
image_height
:
int
,
image_height
:
int
,
processor
:
Optional
[
Gemma3Processor
],
processor
:
Optional
[
Gemma3Processor
],
)
->
str
:
)
->
PromptUpdateDetails
:
if
processor
is
None
:
if
processor
is
None
:
processor
=
self
.
get_hf_processor
()
processor
=
self
.
get_hf_processor
()
...
@@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
f
"Here is the original image
{
image_token
}
and here are some "
f
"Here is the original image
{
image_token
}
and here are some "
f
"crops to help you see better
{
crops_image_tokens
}
"
)
f
"crops to help you see better
{
crops_image_tokens
}
"
)
return
image_text
.
replace
(
image_token
,
processor
.
full_image_sequence
)
repl_full
=
image_text
.
replace
(
image_token
,
processor
.
full_image_sequence
)
repl_features
=
repl_full
.
strip
(
"
\n
"
)
return
PromptUpdateDetails
(
full
=
repl_full
,
features
=
repl_features
)
def
get_num_image_tokens
(
def
get_num_image_tokens
(
self
,
self
,
...
@@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
...
@@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_repl_tokens
=
encode_tokens
(
image_repl_tokens
=
encode_tokens
(
tokenizer
,
tokenizer
,
image_repl
,
image_repl
.
features
,
add_special_tokens
=
False
,
add_special_tokens
=
False
,
)
)
return
len
(
image_repl_tokens
)
return
len
(
image_repl_tokens
)
...
@@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
...
@@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
num_images
=
num_images
)
num_images
=
num_images
)
}
}
# NOTE: We need to separate the image tokens here because
# encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
# with the detection of prompt updates when the image tokens are
# right next to each other
return
ProcessorInputs
(
return
ProcessorInputs
(
prompt_text
=
" "
.
join
([
image_token
]
*
num_images
)
,
prompt_text
=
image_token
*
num_images
,
mm_data
=
mm_data
,
mm_data
=
mm_data
,
)
)
...
@@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
]
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
)
image_repl_features
=
[
self
.
info
.
get_image_repl
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
processor
=
hf_processor
).
features
for
size
in
image_sizes
]
tokenizer
=
self
.
info
.
get_tokenizer
()
image_repls_feature_tokens
=
[
tokenizer
.
encode
(
image_repl
,
add_special_tokens
=
False
)
for
image_repl
in
image_repl_features
]
num_embeds
=
[
len
(
image_repl_feature_tokens
)
for
image_repl_feature_tokens
in
image_repls_feature_tokens
]
processed_outputs
[
"num_embeds"
]
=
torch
.
tensor
(
num_embeds
)
vocab
=
tokenizer
.
get_vocab
()
image_token_id
=
vocab
[
tokenizer
.
image_token
]
embed_is_patch
=
[
torch
.
tensor
(
image_repl_tokens
)
==
image_token_id
for
image_repl_tokens
in
image_repls_feature_tokens
]
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
num_crops
=
[
num_crops
=
[
self
.
info
.
get_num_crops
(
image_width
=
size
.
width
,
self
.
info
.
get_num_crops
(
image_width
=
size
.
width
,
image_height
=
size
.
height
,
image_height
=
size
.
height
,
processor
=
hf_processor
)
processor
=
hf_processor
)
for
size
in
image_sizes
for
size
in
image_sizes
]
]
processed_outputs
[
"num_crops"
]
=
torch
.
tensor
(
num_crops
)
processed_outputs
[
"num_crops"
]
=
torch
.
tensor
(
num_crops
)
return
processed_outputs
return
processed_outputs
...
@@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
pixel_values
=
MultiModalFieldConfig
.
flat_from_sizes
(
"image"
,
num_crops
+
1
),
"image"
,
num_crops
+
1
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_crops
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
...
@@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
)
)
]
]
def
_apply_token_matches
(
self
,
prompt
:
list
[
int
],
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
int
]:
token_ids
=
super
().
_apply_token_matches
(
prompt
,
mm_matches
,
mm_item_counts
,
)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
# tokens, we have to combine them to be consistent with
# the output of the tokenizer
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
newline_1
=
vocab
[
"
\n
"
]
newline_2
=
vocab
[
"
\n\n
"
]
newline_3
=
vocab
[
"
\n\n\n
"
]
newline_4
=
vocab
[
"
\n\n\n\n
"
]
token_ids
=
replace_token_matches
(
token_ids
,
[
newline_1
,
newline_2
],
[
newline_3
],
)
token_ids
=
replace_token_matches
(
token_ids
,
[
newline_2
,
newline_1
],
[
newline_3
],
)
token_ids
=
replace_token_matches
(
token_ids
,
[
newline_2
,
newline_2
],
[
newline_4
],
)
return
token_ids
def
_find_mm_placeholders
(
self
,
mm_prompt_updates
:
Mapping
[
str
,
Sequence
[
BoundPromptUpdate
]],
new_token_ids
:
list
[
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer
=
self
.
info
.
get_tokenizer
()
vocab
=
tokenizer
.
get_vocab
()
newline_1
=
vocab
[
"
\n
"
]
newline_2
=
vocab
[
"
\n\n
"
]
newline_3
=
vocab
[
"
\n\n\n
"
]
newline_4
=
vocab
[
"
\n\n\n\n
"
]
def
get_repl_toks
(
tok
:
int
)
->
list
[
int
]:
if
tok
==
newline_3
:
return
[
newline_1
,
newline_2
]
if
tok
==
newline_4
:
return
[
newline_2
,
newline_2
]
return
[
tok
]
repl_token_ids
=
list
[
int
]()
repl_orig_idxs
=
list
[
int
]()
for
orig_idx
,
orig_tok
in
enumerate
(
new_token_ids
):
repl_toks
=
get_repl_toks
(
orig_tok
)
repl_token_ids
.
extend
(
repl_toks
)
repl_orig_idxs
.
extend
(
orig_idx
for
_
in
range
(
len
(
repl_toks
)))
repls
=
find_mm_placeholders
(
mm_prompt_updates
,
repl_token_ids
,
mm_item_counts
)
return
{
modality
:
[
PlaceholderFeaturesInfo
(
modality
=
p
.
modality
,
item_idx
=
p
.
item_idx
,
start_idx
=
repl_orig_idxs
[
p
.
start_idx
],
tokens
=
p
.
tokens
,
)
for
p
in
placeholders
]
for
modality
,
placeholders
in
repls
.
items
()
}
class
Gemma3MultiModalProjector
(
nn
.
Module
):
class
Gemma3MultiModalProjector
(
nn
.
Module
):
...
@@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
...
@@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
info
=
Gemma3ProcessingInfo
,
info
=
Gemma3ProcessingInfo
,
dummy_inputs
=
Gemma3DummyInputsBuilder
)
dummy_inputs
=
Gemma3DummyInputsBuilder
)
class
Gemma3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
class
Gemma3ForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
,
SupportsLoRA
,
SupportsV0Only
):
SupportsLoRA
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"qkv_proj"
:
[
"q_proj"
,
"q_proj"
,
...
@@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self
.
make_empty_intermediate_tensors
=
(
self
.
make_empty_intermediate_tensors
=
(
self
.
language_model
.
make_empty_intermediate_tensors
)
self
.
language_model
.
make_empty_intermediate_tensors
)
@
property
def
dtype
(
self
):
return
next
(
self
.
parameters
()).
dtype
@
property
@
property
def
sampler
(
self
):
def
sampler
(
self
):
return
self
.
language_model
.
sampler
return
self
.
language_model
.
sampler
...
@@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
Gemma3ImageInputs
]:
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
pixel_values
=
kwargs
.
pop
(
"pixel_values"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
num_crops
=
kwargs
.
pop
(
"num_crops"
,
None
)
embed_is_patch
=
kwargs
.
pop
(
"embed_is_patch"
,
None
)
num_embeds
=
kwargs
.
pop
(
"num_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
image_embeds
=
kwargs
.
pop
(
"image_embeds"
,
None
)
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
assert
image_embeds
is
None
,
"Gemma3 does not support image_embeds."
if
pixel_values
is
None
:
if
pixel_values
is
None
:
...
@@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
num_crops
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_crops
values
. "
raise
ValueError
(
"Incorrect type of num_crops. "
f
"Got type:
{
type
(
num_crops
)
}
"
)
f
"Got type:
{
type
(
num_crops
)
}
"
)
if
not
isinstance
(
embed_is_patch
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
if
not
isinstance
(
num_embeds
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_embeds. "
f
"Got type:
{
type
(
num_embeds
)
}
"
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
pixel_values
=
flatten_bn
(
pixel_values
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
num_crops
=
flatten_bn
(
num_crops
,
concat
=
True
)
return
Gemma3ImagePixelInputs
(
return
Gemma3ImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
pixel_values
=
self
.
_validate_pixel_values
(
pixel_values
),
num_crops
=
num_crops
,
num_patches
=
num_crops
+
1
,
embed_is_patch
=
embed_is_patch
,
num_embeds
=
num_embeds
,
)
)
def
_image_pixels_to_features
(
def
_image_pixels_to_features
(
...
@@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def
_process_image_input
(
def
_process_image_input
(
self
,
self
,
image_input
:
Gemma3ImageInputs
,
image_input
:
Gemma3ImageInputs
,
)
->
torch
.
Tensor
:
)
->
tuple
[
torch
.
Tensor
,
...]
:
assert
self
.
vision_tower
is
not
None
assert
self
.
vision_tower
is
not
None
pixel_values
=
image_input
[
"pixel_values"
]
pixel_values
=
image_input
[
"pixel_values"
]
vision_outputs
=
self
.
_image_pixels_to_features
(
num_patches
=
image_input
[
"num_patches"
]
image_features
=
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
self
.
vision_tower
,
pixel_values
,
pixel_values
,
)
)
return
self
.
multi_modal_projector
(
vision_outputs
)
image_embeds
=
self
.
multi_modal_projector
(
image_features
)
return
image_embeds
.
split
(
num_patches
.
tolist
())
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
None
:
if
image_input
is
None
:
return
None
return
None
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
return
vision_embeddings
image_features
=
self
.
_process_image_input
(
image_input
)
if
kwargs
.
get
(
"v0_path"
,
False
):
return
image_features
return
flatten_2d_lists
(
scatter_patch_features
(
*
args
)
for
args
in
zip
(
image_features
,
image_input
[
"num_embeds"
],
image_input
[
"embed_is_patch"
],
))
def
get_input_embeddings
(
def
get_input_embeddings
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
multimodal_embeddings
:
Optional
[
MultiModalEmbeddings
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
multimodal_embeddings
is
None
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
else
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
multimodal_embeddings
,
input_ids
,
self
.
config
.
image_token_index
)
inputs_embeds
,
select_patch_features
(
multimodal_embeddings
),
self
.
config
.
image_token_index
,
)
return
inputs_embeds
return
inputs_embeds
def
forward
(
self
,
def
forward
(
self
,
...
@@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
# condition is for v0 compatibility.
elif
inputs_embeds
is
None
:
elif
inputs_embeds
is
None
:
kwargs
.
update
({
"v0_path"
:
True
})
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
vision_embeddings
=
self
.
get_multimodal_embeddings
(
**
kwargs
)
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
,
...
@@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
kwargs
=
self
.
prepare_attn_masks
(
kwargs
=
self
.
prepare_attn_masks
(
input_ids
,
input_ids
,
positions
,
positions
,
mask_dtype
=
vision_embeddings
.
dtype
,
mask_dtype
=
self
.
dtype
,
**
kwargs
)
**
kwargs
,
)
input_ids
=
None
input_ids
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
...
...
vllm/model_executor/models/llava.py
View file @
ca796e19
...
@@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
...
@@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.jsontree
import
json_map_leaves
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
...
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
...
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
from
vllm.multimodal.inputs
import
(
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalInputs
,
MultiModalKwargs
)
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
...
@@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
...
@@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from
.siglip
import
SiglipVisionModel
from
.siglip
import
SiglipVisionModel
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
get_vision_encoder_info
from
.vision
import
(
get_vision_encoder_info
,
scatter_patch_features
,
select_patch_features
)
class
LlavaImagePixelInputs
(
TypedDict
):
class
LlavaImagePixelInputs
(
TypedDict
):
...
@@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
...
@@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size, num_images, num_embeds)`
"""
"""
num_
patche
s
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
num_
embed
s
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size, num_images)`"""
"""Shape: `(batch_size, num_images)`"""
...
@@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
...
@@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
image_height
=
pixel_value
.
shape
[
-
2
],
image_height
=
pixel_value
.
shape
[
-
2
],
)
for
pixel_value
in
processed_outputs
[
"pixel_values"
]
)
for
pixel_value
in
processed_outputs
[
"pixel_values"
]
]
]
num_
patche
s
=
torch
.
tensor
([(
ncols
+
1
)
*
nrows
num_
embed
s
=
torch
.
tensor
([(
ncols
+
1
)
*
nrows
for
ncols
,
nrows
in
tile_sizes
])
for
ncols
,
nrows
in
tile_sizes
])
# Each image may result to masks of different sizes, so we need to
# Each image may result to masks of different sizes, so we need to
# later use `num_
patche
s` to get per-image masks.
# later use `num_
embed
s` to get per-image masks.
embed_is_patch
=
[
embed_is_patch
=
[
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
torch
.
tensor
(([
True
]
*
ncols
+
[
False
])
*
nrows
)
for
ncols
,
nrows
in
tile_sizes
for
ncols
,
nrows
in
tile_sizes
]
]
processed_outputs
[
"num_
patche
s"
]
=
num_
patche
s
processed_outputs
[
"num_
embed
s"
]
=
num_
embed
s
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
processed_outputs
[
"embed_is_patch"
]
=
embed_is_patch
return
processed_outputs
return
processed_outputs
...
@@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
...
@@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
return
dict
(
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_
patche
s
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_
embed
s
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_embeds
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
...
@@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Incorrect type of embed_is_patch. "
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
num_
patche
s
=
kwargs
.
pop
(
"num_
patche
s"
)
num_
embed
s
=
kwargs
.
pop
(
"num_
embed
s"
)
if
not
isinstance
(
num_
patche
s
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
num_
embed
s
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_
patche
s. "
raise
ValueError
(
"Incorrect type of num_
embed
s. "
f
"Got type:
{
type
(
num_
patche
s
)
}
"
)
f
"Got type:
{
type
(
num_
embed
s
)
}
"
)
return
PixtralHFImagePixelInputs
(
return
PixtralHFImagePixelInputs
(
type
=
"pixel_values_pixtral"
,
type
=
"pixel_values_pixtral"
,
pixel_values
=
flatten_bn
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
),
embed_is_patch
=
embed_is_patch
,
embed_is_patch
=
embed_is_patch
,
num_
patches
=
num_patche
s
,
num_
embeds
=
num_embed
s
,
)
)
return
LlavaImagePixelInputs
(
return
LlavaImagePixelInputs
(
...
@@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
image_embeds
=
torch
.
split
(
image_embeds
,
feature_sizes
)
image_embeds
=
torch
.
split
(
image_embeds
,
feature_sizes
)
return
image_embeds
return
image_embeds
def
_get_mm_embeds
(
self
,
features
:
torch
.
Tensor
,
# Shape: (num_patch, d)
num_patches
:
torch
.
Tensor
,
# Shape: (num_images,)
embed_is_patch
:
torch
.
Tensor
,
# Shape: (num_images, num_embeds)
)
->
tuple
[
torch
.
Tensor
,
...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image
:
list
[
int
]
=
num_patches
.
tolist
()
embeds_flat
=
features
.
new_full
(
(
sum
(
num_patches_per_image
),
*
features
.
shape
[
1
:]),
fill_value
=
torch
.
nan
,
)
embeds_flat
[
embed_is_patch
.
view
(
-
1
)]
=
features
return
embeds_flat
.
split
(
num_patches_per_image
)
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
@@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
return
vision_embeddings
return
vision_embeddings
return
flatten_2d_lists
(
return
flatten_2d_lists
(
s
elf
.
_get_mm_embed
s
(
*
args
)
for
args
in
zip
(
s
catter_patch_feature
s
(
*
args
)
for
args
in
zip
(
vision_embeddings
,
vision_embeddings
,
image_input
[
"num_
patche
s"
],
image_input
[
"num_
embed
s"
],
image_input
[
"embed_is_patch"
],
image_input
[
"embed_is_patch"
],
))
))
...
@@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
# Extract the patch tokens
patch_embeddings
=
json_map_leaves
(
lambda
x
:
x
[
~
x
.
isnan
()].
view
(
-
1
,
*
x
.
shape
[
1
:]),
cast
(
JSONTree
[
torch
.
Tensor
],
multimodal_embeddings
),
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
cast
(
NestedTensors
,
patch
_embeddings
),
select_patch_features
(
multimodal
_embeddings
),
self
.
config
.
image_token_index
,
self
.
config
.
image_token_index
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/mllama.py
View file @
ca796e19
...
@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
...
@@ -1070,8 +1070,8 @@ class MllamaTextModel(nn.Module):
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
hidden_states
=
inputs_embeds
hidden_states
=
inputs_embeds
for
decoder_layer
in
self
.
layers
:
for
idx
,
decoder_layer
in
enumerate
(
self
.
layers
)
:
if
i
sinstance
(
decoder_layer
,
MllamaC
ross
A
ttention
DecoderL
ayer
)
:
if
i
dx
in
self
.
c
ross
_a
ttention
_l
ayer
s
:
if
not
skip_cross_attention
:
if
not
skip_cross_attention
:
hidden_states
=
decoder_layer
(
hidden_states
=
decoder_layer
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
...
@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
...
@@ -1081,16 +1081,13 @@ class MllamaTextModel(nn.Module):
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
=
full_text_row_masked_out_mask
,
full_text_row_masked_out_mask
,
)
)
el
if
isinstance
(
decoder_layer
,
LlamaDecoderLayer
)
:
el
se
:
hidden_states
,
residual
=
decoder_layer
(
hidden_states
,
residual
=
decoder_layer
(
positions
=
positions
,
positions
=
positions
,
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
residual
=
None
,
residual
=
None
,
)
)
hidden_states
=
hidden_states
+
residual
hidden_states
=
hidden_states
+
residual
else
:
raise
ValueError
(
f
"Unknown decoder layer type
{
type
(
decoder_layer
)
}
"
)
hidden_states
=
self
.
norm
(
hidden_states
)
hidden_states
=
self
.
norm
(
hidden_states
)
return
hidden_states
return
hidden_states
...
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
...
@@ -1551,4 +1548,4 @@ def convert_dense_cross_attention_mask_to_tensor(
full_text_mask
=
((
mask
!=
ninf
).
any
(
dim
=-
1
).
type_as
(
mask
)[...,
None
])
full_text_mask
=
((
mask
!=
ninf
).
any
(
dim
=-
1
).
type_as
(
mask
)[...,
None
])
mask
*=
full_text_mask
mask
*=
full_text_mask
# (num_prompt_tokens, num_encoder_tokens)
# (num_prompt_tokens, num_encoder_tokens)
return
mask
return
mask
\ No newline at end of file
vllm/model_executor/models/molmo.py
View file @
ca796e19
...
@@ -4,7 +4,7 @@ import math
...
@@ -4,7 +4,7 @@ import math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
,
partial
from
functools
import
cached_property
,
partial
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
,
cast
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
...
@@ -24,7 +24,6 @@ from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
split_tensor_along_last_dim
,
split_tensor_along_last_dim
,
tensor_model_parallel_all_gather
)
tensor_model_parallel_all_gather
)
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.activation
import
(
MulAndSilu
,
QuickGELU
,
from
vllm.model_executor.layers.activation
import
(
MulAndSilu
,
QuickGELU
,
SiluAndMul
)
SiluAndMul
)
...
@@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -42,8 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldConfig
,
MultiModalKwargs
,
from
vllm.multimodal.inputs
import
MultiModalFieldConfig
,
MultiModalKwargs
NestedTensors
)
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
...
@@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
...
@@ -59,6 +57,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
is_pp_missing_parameter
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
,
merge_multimodal_embeddings
)
maybe_prefix
,
merge_multimodal_embeddings
)
from
.vision
import
select_patch_features
# TODO: hard-coded for now. Consider making it configurable.
# TODO: hard-coded for now. Consider making it configurable.
VIT_LAYERS
=
[
-
2
,
-
9
]
VIT_LAYERS
=
[
-
2
,
-
9
]
...
@@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
...
@@ -1602,16 +1601,10 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
assert
self
.
img_patch_id
is
not
None
assert
self
.
img_patch_id
is
not
None
# Extract the patch tokens scattered in _get_mm_embeds
patch_embeddings
=
json_map_leaves
(
lambda
x
:
x
[
~
x
.
isnan
()].
view
(
-
1
,
*
x
.
shape
[
1
:]),
cast
(
JSONTree
[
torch
.
Tensor
],
multimodal_embeddings
),
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
cast
(
NestedTensors
,
patch
_embeddings
),
select_patch_features
(
multimodal
_embeddings
),
self
.
img_patch_id
,
self
.
img_patch_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/pixtral.py
View file @
ca796e19
...
@@ -4,7 +4,7 @@ import math
...
@@ -4,7 +4,7 @@ import math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
,
cast
from
typing
import
List
,
Literal
,
Optional
,
Set
,
Tuple
,
TypedDict
,
Union
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput
...
@@ -22,7 +22,6 @@ from transformers.tokenization_utils_base import TextInput
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.model_executor.layers.activation
import
get_act_and_mul_fn
from
vllm.model_executor.layers.activation
import
get_act_and_mul_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists
...
@@ -48,7 +47,8 @@ from vllm.utils import flatten_2d_lists
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
from
.utils
import
(
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
.vision
import
VisionEncoderInfo
,
resolve_visual_encoder_outputs
from
.vision
import
(
VisionEncoderInfo
,
resolve_visual_encoder_outputs
,
scatter_patch_features
,
select_patch_features
)
try
:
try
:
from
xformers
import
ops
as
xops
from
xformers
import
ops
as
xops
...
@@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict):
...
@@ -77,7 +77,7 @@ class PixtralImagePixelInputs(TypedDict):
Shape: `(batch_size, num_images, num_embeds)`
Shape: `(batch_size, num_images, num_embeds)`
"""
"""
num_
patche
s
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
num_
embed
s
:
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]]
"""Shape: `(batch_size, num_images)`"""
"""Shape: `(batch_size, num_images)`"""
...
@@ -153,7 +153,7 @@ class PixtralProcessorAdapter:
...
@@ -153,7 +153,7 @@ class PixtralProcessorAdapter:
images_processed
=
list
[
torch
.
Tensor
]()
images_processed
=
list
[
torch
.
Tensor
]()
images_tokens
=
list
[
torch
.
Tensor
]()
images_tokens
=
list
[
torch
.
Tensor
]()
images_embed_is_patch
=
list
[
torch
.
Tensor
]()
images_embed_is_patch
=
list
[
torch
.
Tensor
]()
images_num_
patche
s
=
list
[
int
]()
images_num_
embed
s
=
list
[
int
]()
for
image
in
images
:
for
image
in
images
:
image_inputs
=
self
.
image_processor
(
ImageChunk
(
image
=
image
))
image_inputs
=
self
.
image_processor
(
ImageChunk
(
image
=
image
))
...
@@ -163,13 +163,13 @@ class PixtralProcessorAdapter:
...
@@ -163,13 +163,13 @@ class PixtralProcessorAdapter:
images_processed
.
append
(
image_processed
)
images_processed
.
append
(
image_processed
)
images_tokens
.
append
(
image_tokens
)
images_tokens
.
append
(
image_tokens
)
images_embed_is_patch
.
append
(
image_tokens
==
image_token_id
)
images_embed_is_patch
.
append
(
image_tokens
==
image_token_id
)
images_num_
patche
s
.
append
(
len
(
image_tokens
))
images_num_
embed
s
.
append
(
len
(
image_tokens
))
return
{
return
{
"input_ids"
:
torch
.
cat
(
images_tokens
)[
None
].
expand
(
len
(
text
),
-
1
),
"input_ids"
:
torch
.
cat
(
images_tokens
)[
None
].
expand
(
len
(
text
),
-
1
),
"images"
:
images_processed
,
"images"
:
images_processed
,
"embed_is_patch"
:
images_embed_is_patch
,
"embed_is_patch"
:
images_embed_is_patch
,
"num_
patche
s"
:
torch
.
tensor
(
images_num_
patche
s
),
"num_
embed
s"
:
torch
.
tensor
(
images_num_
embed
s
),
}
}
...
@@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
...
@@ -273,7 +273,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
return
dict
(
return
dict
(
images
=
MultiModalFieldConfig
.
batched
(
"image"
),
images
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
embed_is_patch
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_
patche
s
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_
embed
s
=
MultiModalFieldConfig
.
batched
(
"image"
),
)
)
def
_get_prompt_updates
(
def
_get_prompt_updates
(
...
@@ -365,16 +365,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -365,16 +365,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
spatial_merge_size
=
self
.
vision_args
.
spatial_merge_size
,
spatial_merge_size
=
self
.
vision_args
.
spatial_merge_size
,
use_mlp_bias
=
False
,
use_mlp_bias
=
False
,
)
)
if
self
.
vision_args
.
add_pre_mm_projector_layer_norm
:
self
.
pre_mm_projector_norm
=
RMSNorm
(
self
.
vision_args
.
hidden_size
,
eps
=
1e-5
)
if
self
.
vision_args
.
mm_projector_id
==
PATCH_MERGE
:
self
.
patch_merger
=
PatchMerger
(
vision_encoder_dim
=
self
.
vision_args
.
hidden_size
,
spatial_merge_size
=
self
.
vision_args
.
spatial_merge_size
,
use_mlp_bias
=
False
,
)
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
self
.
vision_language_adapter
=
VisionLanguageAdapter
(
self
.
vision_args
,
dim
=
config
.
text_config
.
hidden_size
)
self
.
vision_args
,
dim
=
config
.
text_config
.
hidden_size
)
...
@@ -403,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -403,16 +394,16 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise
ValueError
(
"Incorrect type of embed_is_patch. "
raise
ValueError
(
"Incorrect type of embed_is_patch. "
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
f
"Got type:
{
type
(
embed_is_patch
)
}
"
)
num_
patche
s
=
kwargs
.
pop
(
"num_
patche
s"
)
num_
embed
s
=
kwargs
.
pop
(
"num_
embed
s"
)
if
not
isinstance
(
num_
patche
s
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
num_
embed
s
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of num_
patche
s. "
raise
ValueError
(
"Incorrect type of num_
embed
s. "
f
"Got type:
{
type
(
num_
patche
s
)
}
"
)
f
"Got type:
{
type
(
num_
embed
s
)
}
"
)
return
PixtralImagePixelInputs
(
return
PixtralImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
images
=
flatten_bn
(
images
),
images
=
flatten_bn
(
images
),
embed_is_patch
=
embed_is_patch
,
embed_is_patch
=
embed_is_patch
,
num_
patches
=
num_patche
s
,
num_
embeds
=
num_embed
s
,
)
)
def
_process_image_input
(
def
_process_image_input
(
...
@@ -442,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -442,33 +433,6 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
image_embeds
=
torch
.
split
(
image_embeds
,
feature_sizes
)
image_embeds
=
torch
.
split
(
image_embeds
,
feature_sizes
)
return
image_embeds
return
image_embeds
def
_get_mm_embeds
(
self
,
features
:
torch
.
Tensor
,
# Shape: (num_patch, d)
num_patches
:
torch
.
Tensor
,
# Shape: (num_images,)
embed_is_patch
:
torch
.
Tensor
,
# Shape: (num_images, num_embeds)
)
->
tuple
[
torch
.
Tensor
,
...]:
"""Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
"""
# Insert columns of nan values according to `embed_is_patch`. This work
# ideally should be done in `_process_image_input`, but
# `_process_image_input` is used in both V0 and V1 path. It's safer to
# put the logic here.
# FIXME: Move this logic to `_process_image_input` when v0 is
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
num_patches_per_image
:
list
[
int
]
=
num_patches
.
tolist
()
embeds_flat
=
features
.
new_full
(
(
sum
(
num_patches_per_image
),
*
features
.
shape
[
1
:]),
fill_value
=
torch
.
nan
,
)
embeds_flat
[
embed_is_patch
.
view
(
-
1
)]
=
features
return
embeds_flat
.
split
(
num_patches_per_image
)
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
self
,
**
kwargs
:
object
)
->
Optional
[
MultiModalEmbeddings
]:
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
...
@@ -481,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -481,9 +445,9 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
return
image_features
return
image_features
return
flatten_2d_lists
(
return
flatten_2d_lists
(
s
elf
.
_get_mm_embed
s
(
*
args
)
for
args
in
zip
(
s
catter_patch_feature
s
(
*
args
)
for
args
in
zip
(
image_features
,
image_features
,
image_input
[
"num_
patche
s"
],
image_input
[
"num_
embed
s"
],
image_input
[
"embed_is_patch"
],
image_input
[
"embed_is_patch"
],
))
))
...
@@ -494,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -494,15 +458,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
inputs_embeds
=
self
.
language_model
.
get_input_embeddings
(
input_ids
)
if
multimodal_embeddings
is
not
None
:
if
multimodal_embeddings
is
not
None
:
# Extract the patch tokens
patch_embeddings
=
json_map_leaves
(
lambda
x
:
x
[
~
x
.
isnan
()].
view
(
-
1
,
*
x
.
shape
[
1
:]),
cast
(
JSONTree
[
torch
.
Tensor
],
multimodal_embeddings
),
)
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
input_ids
,
inputs_embeds
,
inputs_embeds
,
cast
(
NestedTensors
,
patch
_embeddings
),
select_patch_features
(
multimodal
_embeddings
),
self
.
vision_args
.
image_token_id
,
self
.
vision_args
.
image_token_id
,
)
)
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/vision.py
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
from
typing
import
Final
,
Generic
,
Optional
,
Protocol
,
TypeVar
,
Union
,
cast
import
torch
import
torch
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -9,9 +9,12 @@ from transformers import PretrainedConfig
...
@@ -9,9 +9,12 @@ from transformers import PretrainedConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention.selector
import
(
backend_name_to_enum
,
from
vllm.attention.selector
import
(
backend_name_to_enum
,
get_global_forced_attn_backend
)
get_global_forced_attn_backend
)
from
vllm.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.platforms
import
_Backend
,
current_platform
from
.interfaces
import
MultiModalEmbeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
_C
=
TypeVar
(
"_C"
,
bound
=
PretrainedConfig
)
...
@@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
...
@@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
if
post_layer_norm
is
not
None
and
uses_last_layer
:
if
post_layer_norm
is
not
None
and
uses_last_layer
:
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
hs_pool
[
-
1
]
=
post_layer_norm
(
encoder_outputs
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
return
torch
.
cat
(
hs_pool
,
dim
=-
1
)
def
scatter_patch_features
(
features
:
torch
.
Tensor
,
num_embeds
:
torch
.
Tensor
,
embed_is_patch
:
torch
.
Tensor
,
)
->
tuple
[
torch
.
Tensor
,
...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
The rest of the values in the tensor are set to NaN so that they
can be filtered out by :func`select_patch_features`.
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
num_embeds: The number of image embeddings for each image.
Shape: `(num_images,)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
"""
num_embeds_per_image
:
list
[
int
]
=
num_embeds
.
tolist
()
embeds_flat
=
features
.
new_full
(
(
sum
(
num_embeds_per_image
),
features
.
shape
[
-
1
]),
fill_value
=
torch
.
nan
,
)
embeds_flat
[
embed_is_patch
.
view
(
-
1
)]
=
features
.
flatten
(
0
,
-
2
)
return
embeds_flat
.
split
(
num_embeds_per_image
)
def
select_patch_features
(
multimodal_embeddings
:
MultiModalEmbeddings
)
->
MultiModalEmbeddings
:
"""
Given the outputs of :func:`scatter_patch_features`, return only
the values that correspond to patch features.
"""
selected_features
=
json_map_leaves
(
lambda
x
:
x
[
~
x
.
isnan
()].
view
(
-
1
,
*
x
.
shape
[
1
:]),
cast
(
JSONTree
[
torch
.
Tensor
],
multimodal_embeddings
),
)
return
cast
(
MultiModalEmbeddings
,
selected_features
)
vllm/multimodal/processing.py
View file @
ca796e19
...
@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
...
@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from
.hasher
import
MultiModalHasher
from
.hasher
import
MultiModalHasher
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
NestedTensors
,
PlaceholderRange
)
from
.parse
import
(
DictEmbeddingItems
,
EmbeddingItems
,
MultiModalDataItems
,
from
.parse
import
(
DictEmbeddingItems
,
EmbeddingItems
,
MultiModalDataItems
,
MultiModalDataParser
)
MultiModalDataParser
)
...
@@ -511,8 +511,35 @@ def iter_token_matches(
...
@@ -511,8 +511,35 @@ def iter_token_matches(
start_idx
+=
1
start_idx
+=
1
def
replace_token_matches
(
token_ids
:
list
[
int
],
match_ids
:
list
[
int
],
new_ids
:
list
[
int
],
)
->
list
[
int
]:
"""
Replace each occurrence of :code:`match_ids` in :code:`token_ids`
with :code:`new_ids`.
Note that empty matches are ignored.
"""
out_seqs
=
list
[
list
[
int
]]()
prev_end_idx
=
0
for
match
in
iter_token_matches
(
token_ids
,
match_ids
):
start_idx
=
match
.
start_idx
end_idx
=
match
.
end_idx
out_seqs
.
append
(
token_ids
[
prev_end_idx
:
start_idx
])
out_seqs
.
append
(
new_ids
)
prev_end_idx
=
end_idx
out_seqs
.
append
(
token_ids
[
prev_end_idx
:])
return
flatten_2d_lists
(
out_seqs
)
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_
PromptTargetMatch
(
ABC
):
class
PromptTargetMatch
(
ABC
):
_origin
:
BoundPromptUpdate
_origin
:
BoundPromptUpdate
@
property
@
property
...
@@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC):
...
@@ -535,7 +562,7 @@ class _PromptTargetMatch(ABC):
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_PromptTargetIndexMatch
(
_
PromptTargetMatch
):
class
_PromptTargetIndexMatch
(
PromptTargetMatch
):
match_idx
:
int
match_idx
:
int
@
property
@
property
...
@@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch):
...
@@ -548,7 +575,7 @@ class _PromptTargetIndexMatch(_PromptTargetMatch):
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_PromptTargetTokenMatch
(
_
PromptTargetMatch
):
class
_PromptTargetTokenMatch
(
PromptTargetMatch
):
match
:
_TokenMatch
match
:
_TokenMatch
@
property
@
property
...
@@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch):
...
@@ -561,7 +588,7 @@ class _PromptTargetTokenMatch(_PromptTargetMatch):
@
dataclass
(
repr
=
False
)
@
dataclass
(
repr
=
False
)
class
_PromptTargetTextMatch
(
_
PromptTargetMatch
):
class
_PromptTargetTextMatch
(
PromptTargetMatch
):
match
:
re
.
Match
[
str
]
match
:
re
.
Match
[
str
]
@
property
@
property
...
@@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo:
...
@@ -594,7 +621,7 @@ class PlaceholderFeaturesInfo:
def
find_token_matches
(
def
find_token_matches
(
prompt
:
list
[
int
],
prompt
:
list
[
int
],
prompt_updates
:
Sequence
[
BoundPromptUpdate
],
prompt_updates
:
Sequence
[
BoundPromptUpdate
],
)
->
Sequence
[
_
PromptTargetMatch
]:
)
->
Sequence
[
PromptTargetMatch
]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def
get_matches
(
update
:
BoundPromptUpdate
):
def
get_matches
(
update
:
BoundPromptUpdate
):
...
@@ -620,7 +647,7 @@ def find_token_matches(
...
@@ -620,7 +647,7 @@ def find_token_matches(
def
find_text_matches
(
def
find_text_matches
(
prompt
:
str
,
prompt
:
str
,
prompt_updates
:
Sequence
[
BoundPromptUpdate
],
prompt_updates
:
Sequence
[
BoundPromptUpdate
],
)
->
Sequence
[
_
PromptTargetMatch
]:
)
->
Sequence
[
PromptTargetMatch
]:
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
"""Return each target of :code:`prompt_updates` found in :code:`prompt`."""
def
get_matches
(
update
:
BoundPromptUpdate
):
def
get_matches
(
update
:
BoundPromptUpdate
):
...
@@ -645,15 +672,15 @@ def find_text_matches(
...
@@ -645,15 +672,15 @@ def find_text_matches(
def
_resolve_matches
(
def
_resolve_matches
(
prompt
:
PromptSeq
,
prompt
:
PromptSeq
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_
PromptTargetMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
)
->
list
[
_
PromptTargetMatch
]:
)
->
list
[
PromptTargetMatch
]:
"""
"""
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
Resolve :code:`mm_matches` to ensure that there are no overlapping matches,
and sort them such that earlier matches take priority over later ones.
and sort them such that earlier matches take priority over later ones.
"""
"""
matches
=
[
m
for
matches
in
mm_matches
.
values
()
for
m
in
matches
]
matches
=
[
m
for
matches
in
mm_matches
.
values
()
for
m
in
matches
]
seen_matches
:
list
[
Optional
[
_
PromptTargetMatch
]]
=
[
None
]
*
len
(
prompt
)
seen_matches
:
list
[
Optional
[
PromptTargetMatch
]]
=
[
None
]
*
len
(
prompt
)
for
match
in
matches
:
for
match
in
matches
:
for
idx
in
range
(
match
.
start_idx
,
match
.
end_idx
):
for
idx
in
range
(
match
.
start_idx
,
match
.
end_idx
):
...
@@ -669,7 +696,7 @@ def _resolve_matches(
...
@@ -669,7 +696,7 @@ def _resolve_matches(
def
_apply_matches
(
def
_apply_matches
(
prompt
:
_S
,
prompt
:
_S
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_
PromptTargetMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
_S
]:
)
->
list
[
_S
]:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
...
@@ -718,7 +745,7 @@ def _apply_matches(
...
@@ -718,7 +745,7 @@ def _apply_matches(
def
apply_token_matches
(
def
apply_token_matches
(
prompt
:
list
[
int
],
prompt
:
list
[
int
],
mm_matches
:
Mapping
[
str
,
Sequence
[
_
PromptTargetMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
int
]:
)
->
list
[
int
]:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
...
@@ -732,7 +759,7 @@ def apply_token_matches(
...
@@ -732,7 +759,7 @@ def apply_token_matches(
def
apply_text_matches
(
def
apply_text_matches
(
prompt
:
str
,
prompt
:
str
,
mm_matches
:
Mapping
[
str
,
Sequence
[
_
PromptTargetMatch
]],
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
str
:
)
->
str
:
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
"""Apply the updates in :code:`mm_matches` to :code:`prompt`."""
...
@@ -826,33 +853,62 @@ class ProcessingCache:
...
@@ -826,33 +853,62 @@ class ProcessingCache:
@
staticmethod
@
staticmethod
def
get_lru_cache
(
def
get_lru_cache
(
capacity_gb
:
in
t
,
capacity_gb
:
floa
t
,
value_type
:
type
[
_V
],
value_type
:
type
[
_V
],
*
,
debug
:
bool
=
False
,
)
->
LRUCache
[
str
,
_V
]:
)
->
LRUCache
[
str
,
_V
]:
def
get_size
(
leaf
:
object
)
->
int
:
def
get_leaf_size
(
leaf
:
object
)
->
int
:
# MultiModalKwargs is not a subclass of dict
if
isinstance
(
leaf
,
MultiModalKwargs
):
return
get_item_size
(
leaf
.
data
)
# MultiModalKwargsItem is not a subclass of dict
if
isinstance
(
leaf
,
MultiModalKwargsItem
):
leaf_data
=
{
k
:
v
.
data
for
k
,
v
in
leaf
.
items
()}
return
get_item_size
(
leaf_data
)
# sys.getsizeof doesn't work for tensors
if
isinstance
(
leaf
,
torch
.
Tensor
):
if
isinstance
(
leaf
,
torch
.
Tensor
):
return
leaf
.
nbytes
# sys.getsizeof doesn't work for tensors
return
leaf
.
nbytes
return
sys
.
getsizeof
(
leaf
)
return
sys
.
getsizeof
(
leaf
)
return
LRUCache
[
str
,
_V
](
def
get_item_size
(
GiB_bytes
*
capacity_gb
,
value
:
Union
[
MultiModalKwargs
,
MultiModalKwargsItem
,
getsizeof
=
lambda
x
:
json_reduce_leaves
(
Mapping
[
str
,
NestedTensors
]]
)
->
int
:
size
=
json_reduce_leaves
(
lambda
a
,
b
:
a
+
b
,
lambda
a
,
b
:
a
+
b
,
json_map_leaves
(
get_size
,
x
),
json_map_leaves
(
get_leaf_size
,
value
),
),
)
)
if
debug
:
logger
.
debug
(
"Calculated size of %s to be %.2f GiB"
,
type
(
value
),
size
/
GiB_bytes
)
return
size
def
__init__
(
self
,
capacity_gb
:
int
)
->
None
:
return
LRUCache
(
GiB_bytes
*
capacity_gb
,
getsizeof
=
get_item_size
)
def
__init__
(
self
,
capacity_gb
:
float
,
*
,
debug_cache_hit_ratio_steps
:
Optional
[
int
]
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# DEBUG: Set to None to disable
self
.
debug_cache_hit_ratio_steps
=
debug_cache_hit_ratio_steps
self
.
debug_cache_hit_ratio_steps
:
Optional
[
int
]
=
None
self
.
debug_cache_hits
=
0
self
.
debug_cache_hits
=
0
self
.
debug_cache_total
=
0
self
.
debug_cache_total
=
0
self
.
_cache
=
self
.
get_lru_cache
(
capacity_gb
,
MultiModalKwargsItem
)
self
.
_cache
=
self
.
get_lru_cache
(
capacity_gb
,
MultiModalKwargsItem
,
debug
=
bool
(
debug_cache_hit_ratio_steps
),
)
def
_maybe_log_cache_stats
(
self
)
->
None
:
def
_maybe_log_cache_stats
(
self
)
->
None
:
steps
=
self
.
debug_cache_hit_ratio_steps
steps
=
self
.
debug_cache_hit_ratio_steps
...
@@ -863,6 +919,9 @@ class ProcessingCache:
...
@@ -863,6 +919,9 @@ class ProcessingCache:
if
total
>
0
and
total
%
steps
==
0
:
if
total
>
0
and
total
%
steps
==
0
:
logger
.
debug
(
"ProcessingCache: hit_ratio = %.2f"
,
logger
.
debug
(
"ProcessingCache: hit_ratio = %.2f"
,
self
.
debug_cache_hits
/
total
)
self
.
debug_cache_hits
/
total
)
logger
.
debug
(
"ProcessingCache: size = %.2f / %.2f GiB"
,
self
.
_cache
.
currsize
/
GiB_bytes
,
self
.
_cache
.
maxsize
/
GiB_bytes
)
def
get
(
def
get
(
self
,
self
,
...
@@ -1055,14 +1114,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1055,14 +1114,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Given the original multi-modal items for this modality
Given the original multi-modal items for this modality
and HF-processed data, output the updates to perform.
and HF-processed data, output the updates to perform.
Notes:
The information returned by this method is used to update token inputs
- You should not assume
th
at
HF processor
always performs prompt
which bypass
th
e
HF processor
. It is also used to update the output of
updates: in :meth:`_apply_hf_processor_missing`, this method
HF processor if the HF process does not apply prompt updates to text
is called on text-only and multimodal-only inputs separately,
inputs.
instead of passing them in the same call.
- The update information returned by this method is also used to
Moreover, this information is critical to determine the token positions
determine the placeholder token positions for each
multi
-
modal
in order to construct :class:`~vllm-
multimodal
.input.PlaceholderRange`
item.
for each multi-modal
item.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
...
@@ -1357,6 +1416,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1357,6 +1416,22 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
it
=
(
update
.
bind
(
tokenizer
)
for
update
in
prompt_updates
)
it
=
(
update
.
bind
(
tokenizer
)
for
update
in
prompt_updates
)
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
def
_apply_token_matches
(
self
,
prompt
:
list
[
int
],
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
list
[
int
]:
return
apply_token_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
def
_apply_text_matches
(
self
,
prompt
:
str
,
mm_matches
:
Mapping
[
str
,
Sequence
[
PromptTargetMatch
]],
mm_item_counts
:
Mapping
[
str
,
int
],
)
->
str
:
return
apply_text_matches
(
prompt
,
mm_matches
,
mm_item_counts
)
def
_apply_prompt_updates
(
def
_apply_prompt_updates
(
self
,
self
,
token_ids
:
list
[
int
],
token_ids
:
list
[
int
],
...
@@ -1388,7 +1463,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1388,7 +1463,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_match_counts
.
get
(
modality
,
0
)
>=
item_count
mm_match_counts
.
get
(
modality
,
0
)
>=
item_count
for
modality
,
item_count
in
mm_item_counts
.
items
()
for
modality
,
item_count
in
mm_item_counts
.
items
()
):
# yapf: disable
):
# yapf: disable
token_ids
=
apply_token_matches
(
token_ids
=
self
.
_
apply_token_matches
(
token_ids
,
token_ids
,
mm_token_matches
,
mm_token_matches
,
mm_item_counts
,
mm_item_counts
,
...
@@ -1406,7 +1481,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1406,7 +1481,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
modality
:
find_text_matches
(
text
,
updates
)
modality
:
find_text_matches
(
text
,
updates
)
for
modality
,
updates
in
mm_prompt_updates
.
items
()
for
modality
,
updates
in
mm_prompt_updates
.
items
()
}
}
text
=
apply_text_matches
(
text
=
self
.
_
apply_text_matches
(
text
,
text
,
mm_text_matches
,
mm_text_matches
,
mm_item_counts
,
mm_item_counts
,
...
...
vllm/multimodal/profiling.py
View file @
ca796e19
...
@@ -73,7 +73,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
...
@@ -73,7 +73,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
height
:
int
,
height
:
int
,
num_images
:
int
,
num_images
:
int
,
)
->
list
[
Image
.
Image
]:
)
->
list
[
Image
.
Image
]:
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
255
)
return
[
image
]
*
num_images
return
[
image
]
*
num_images
def
_get_dummy_videos
(
def
_get_dummy_videos
(
...
@@ -84,7 +84,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
...
@@ -84,7 +84,7 @@ class BaseDummyInputsBuilder(ABC, Generic[_I]):
num_frames
:
int
,
num_frames
:
int
,
num_videos
:
int
,
num_videos
:
int
,
)
->
list
[
npt
.
NDArray
]:
)
->
list
[
npt
.
NDArray
]:
video
=
np
.
zeros
((
num_frames
,
width
,
height
,
3
))
video
=
np
.
full
((
num_frames
,
width
,
height
,
3
)
,
255
)
return
[
video
]
*
num_videos
return
[
video
]
*
num_videos
...
...
vllm/transformers_utils/tokenizers/mistral.py
View file @
ca796e19
...
@@ -143,10 +143,6 @@ def make_mistral_chat_completion_request(
...
@@ -143,10 +143,6 @@ def make_mistral_chat_completion_request(
if
last_message
[
"role"
]
==
"assistant"
:
if
last_message
[
"role"
]
==
"assistant"
:
last_message
[
"prefix"
]
=
True
last_message
[
"prefix"
]
=
True
last_message
=
cast
(
Dict
[
str
,
Any
],
messages
[
-
1
])
if
last_message
[
"role"
]
==
"assistant"
:
last_message
[
"prefix"
]
=
True
# mistral-common requires AssistantMessage content to be string [1].
# mistral-common requires AssistantMessage content to be string [1].
#
#
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
...
...
vllm/triton_utils/__init__.py
View file @
ca796e19
...
@@ -2,11 +2,4 @@
...
@@ -2,11 +2,4 @@
from
vllm.triton_utils.importing
import
HAS_TRITON
from
vllm.triton_utils.importing
import
HAS_TRITON
__all__
=
[
"HAS_TRITON"
]
__all__
=
[
"HAS_TRITON"
]
\ No newline at end of file
if
HAS_TRITON
:
from
vllm.triton_utils.custom_cache_manager
import
(
maybe_set_triton_cache_manager
)
__all__
+=
[
"maybe_set_triton_cache_manager"
]
vllm/triton_utils/custom_cache_manager.py
deleted
100644 → 0
View file @
e983c804
# SPDX-License-Identifier: Apache-2.0
import
os
from
triton.runtime.cache
import
(
FileCacheManager
,
default_cache_dir
,
default_dump_dir
,
default_override_dir
)
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def
maybe_set_triton_cache_manager
()
->
None
:
"""Set environment variable to tell Triton to use a
custom cache manager"""
cache_manger
=
os
.
environ
.
get
(
"TRITON_CACHE_MANAGER"
,
None
)
if
cache_manger
is
None
:
manager
=
"vllm.triton_utils.custom_cache_manager:CustomCacheManager"
logger
.
info
(
"Setting Triton cache manager to: %s"
,
manager
)
os
.
environ
[
"TRITON_CACHE_MANAGER"
]
=
manager
class
CustomCacheManager
(
FileCacheManager
):
"""Re-implements Triton's cache manager, ensuring that a
unique cache directory is created for each process. This is
needed to avoid collisions when running with tp>1 and
using multi-processing as the distributed backend.
Note this issue was fixed by triton-lang/triton/pull/4295,
but the fix is not yet included in triton==v3.0.0. However,
it should be included in the subsequent version.
"""
def
__init__
(
self
,
key
,
override
=
False
,
dump
=
False
):
self
.
key
=
key
self
.
lock_path
=
None
if
dump
:
self
.
cache_dir
=
default_dump_dir
()
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
self
.
lock_path
=
os
.
path
.
join
(
self
.
cache_dir
,
"lock"
)
os
.
makedirs
(
self
.
cache_dir
,
exist_ok
=
True
)
elif
override
:
self
.
cache_dir
=
default_override_dir
()
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
else
:
# create cache directory if it doesn't exist
self
.
cache_dir
=
os
.
getenv
(
"TRITON_CACHE_DIR"
,
""
).
strip
()
or
default_cache_dir
()
if
self
.
cache_dir
:
self
.
cache_dir
=
f
"
{
self
.
cache_dir
}
_
{
os
.
getpid
()
}
"
self
.
cache_dir
=
os
.
path
.
join
(
self
.
cache_dir
,
self
.
key
)
self
.
lock_path
=
os
.
path
.
join
(
self
.
cache_dir
,
"lock"
)
os
.
makedirs
(
self
.
cache_dir
,
exist_ok
=
True
)
else
:
raise
RuntimeError
(
"Could not create or locate cache dir"
)
vllm/v1/engine/processor.py
View file @
ca796e19
...
@@ -119,16 +119,21 @@ class Processor:
...
@@ -119,16 +119,21 @@ class Processor:
def
_validate_structured_output
(
self
,
params
:
SamplingParams
)
->
None
:
def
_validate_structured_output
(
self
,
params
:
SamplingParams
)
->
None
:
if
not
params
.
guided_decoding
or
not
self
.
decoding_config
:
if
not
params
.
guided_decoding
or
not
self
.
decoding_config
:
return
return
if
self
.
decoding_config
.
guided_decoding_backend
!=
"xgrammar"
:
raise
ValueError
(
supported_backends
=
[
"xgrammar"
]
"Only xgrammar structured output is supported in V1."
)
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
if
(
params
.
guided_decoding
.
backend
if
engine_level_backend
not
in
supported_backends
:
and
params
.
guided_decoding
.
backend
!=
'xgrammar'
):
raise
ValueError
(
f
"Only
{
supported_backends
}
structured output is "
raise
ValueError
(
"supported in V1."
)
"Only xgrammar structured output is supported in V1."
)
if
params
.
guided_decoding
.
backend
:
if
self
.
vllm_config
.
speculative_config
:
if
params
.
guided_decoding
.
backend
!=
engine_level_backend
:
raise
ValueError
(
"Structured output is not supported with "
raise
ValueError
(
"Request-level structured output backend "
"speculative decoding."
)
"must match engine-level backend. "
f
"
{
params
.
guided_decoding
.
backend
}
"
f
" !=
{
engine_level_backend
}
"
)
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
raise
ValueError
(
"Structured output is not supported on TPU."
)
...
...
vllm/v1/outputs.py
View file @
ca796e19
...
@@ -46,7 +46,7 @@ class SamplerOutput:
...
@@ -46,7 +46,7 @@ class SamplerOutput:
# [num_reqs, max_num_generated_tokens]
# [num_reqs, max_num_generated_tokens]
# Different requests can have different number of generated tokens.
# Different requests can have different number of generated tokens.
# All requests are padded to max_num_generated_tokens.
# All requests are padded to max_num_generated_tokens.
#
INVALID
_TOKEN_ID (-1 by default) is used for padding.
#
PLACEHOLDER
_TOKEN_ID (-1 by default) is used for padding.
sampled_token_ids
:
torch
.
Tensor
sampled_token_ids
:
torch
.
Tensor
logprobs_tensors
:
Optional
[
LogprobsTensors
]
logprobs_tensors
:
Optional
[
LogprobsTensors
]
...
...
vllm/v1/sample/ops/utils.py
0 → 100644
View file @
ca796e19
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Union
import
torch
def
compiled_softmax
(
logits
:
torch
.
Tensor
,
temperature
:
Union
[
float
,
torch
.
Tensor
]
=
1.0
,
)
->
torch
.
Tensor
:
"""Faster softmax kernel generated by torch.compile.
Args:
logits: [n, vocab_size]
temperature: [n] or float
"""
# NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
torch
.
_dynamo
.
mark_dynamic
(
logits
,
index
=
0
)
if
isinstance
(
temperature
,
torch
.
Tensor
):
torch
.
_dynamo
.
mark_dynamic
(
temperature
,
index
=
0
)
return
_softmax
(
logits
,
temperature
)
@
torch
.
compile
def
_softmax
(
logits
:
torch
.
Tensor
,
temperature
:
Union
[
float
,
torch
.
Tensor
],
)
->
torch
.
Tensor
:
logits
=
logits
/
temperature
return
torch
.
softmax
(
logits
,
dim
=-
1
,
dtype
=
torch
.
float32
)
vllm/v1/sample/rejection_sampler.py
View file @
ca796e19
...
@@ -3,25 +3,32 @@ from typing import Optional
...
@@ -3,25 +3,32 @@ from typing import Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch.nn.utils.rnn
import
pad_sequence
import
triton
import
triton.language
as
tl
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.utils
import
random_sample
from
vllm.v1.sample.ops.utils
import
compiled_softmax
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
INVALID_TOKEN_ID
=
-
1
PLACEHOLDER_TOKEN_ID
:
tl
.
constexpr
=
-
1
GREEDY_TEMPERATURE
:
tl
.
constexpr
=
-
1
# Maximum number of speculative draft tokens allowed per request in a single
# step. This value is chosen to be large enough to handle typical use cases.
MAX_SPEC_LEN
=
32
class
RejectionSampler
(
nn
.
Module
):
class
RejectionSampler
(
nn
.
Module
):
"""
"""
The implementation strictly follows the algorithm described in
The implementation strictly follows the algorithm described in
https://arxiv.org/abs/2211.17192.
https://arxiv.org/abs/2211.17192.
However, we want to clarify the terminology used in the implementation:
However, we want to clarify the terminology used in the implementation:
accepted tokens: tokens that are accepted based on the relationship
accepted tokens: tokens that are accepted based on the relationship
between the "raw" draft and target probabilities.
between the "raw" draft and target probabilities.
recovered tokens: tokens that are sampled based on the adjusted probability
recovered tokens: tokens that are sampled based on the adjusted probability
distribution, which is derived from both the draft and target
distribution, which is derived from both the draft and target
probabilities.
probabilities.
bonus tokens:
bonus tokens:
If all proposed tokens are accepted, the bonus token is added to the
If all proposed tokens are accepted, the bonus token is added to the
...
@@ -31,48 +38,42 @@ class RejectionSampler(nn.Module):
...
@@ -31,48 +38,42 @@ class RejectionSampler(nn.Module):
sampling process. For example, we can use top_p, top_k sampling for
sampling process. For example, we can use top_p, top_k sampling for
bonus tokens, while spec decode does not support these sampling
bonus tokens, while spec decode does not support these sampling
strategies.
strategies.
output tokens:
output tokens:
Tokens are finally generated with the rejection sampler.
Tokens are finally generated with the rejection sampler.
output tokens = accepted tokens + recovered tokens + bonus tokens
output tokens = accepted tokens + recovered tokens + bonus tokens
"""
"""
def
__init__
(
self
):
super
().
__init__
()
def
forward
(
def
forward
(
self
,
self
,
draft_token_ids
:
list
[
list
[
int
]],
metadata
:
SpecDecodeMetadata
,
# [num_tokens, vocab_size]
draft_probs
:
Optional
[
torch
.
Tensor
],
draft_probs
:
Optional
[
torch
.
Tensor
],
bonus_token_ids_tensor
:
torch
.
Tensor
,
# [batch_size, 1]
# [num_tokens, vocab_size]
target_probs
:
torch
.
Tensor
,
# [num_total_tokens, vocab_size]
target_logits
:
torch
.
Tensor
,
# [batch_size, 1]
bonus_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
'''
'''
Args:
Args:
draft_token_ids (List[List[int]]):
metadata:
A 2D list of token IDs for each request in the batch.
Metadata for spec decoding.
Each request might have different number of draft tokens.
It may also contain empty lists for requests that have
no draft tokens.
draft_probs (Optional[torch.Tensor]):
draft_probs (Optional[torch.Tensor]):
Probability distribution for the draft tokens. Shape is
Probability distribution for the draft tokens. Shape is
[batch_size, max_spec_len, vocab_size]. Can be None if
[num_tokens, vocab_size]. Can be None if probabilities are
probabilities are not provided, which is the case for
not provided, which is the case for ngram spec decode.
ngram spec decode.
target_logits (torch.Tensor):
Target model's logits probability distribution.
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
bonus_token_ids_tensor (torch.Tensor):
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
Bonus tokens are added to the end of the sequence if all
proposed tokens are accepted. We generate the bonus tokens
proposed tokens are accepted. We generate the bonus tokens
outside of the rejection sampler with the default sampling
outside of the rejection sampler with the default sampling
strategy. It allows for more flexibility in the sampling
strategy. It allows for more flexibility in the sampling
process such as top_p, top_k sampling.
process such as top_p, top_k sampling.
target_probs (torch.Tensor):
Target model probability distribution.
Shape is [num_total_tokens, vocab_size]. num_total_tokens
is the total number of tokens from all requests. Here,
probabilities from different requests are flattened into
a single tensor because this is the shape of the output
logits.
sampling_metadata (SamplingMetadata):
sampling_metadata (SamplingMetadata):
Additional metadata needed for sampling, such as temperature,
Additional metadata needed for sampling, such as temperature,
top-k/top-p parameters, or other relevant information.
top-k/top-p parameters, or other relevant information.
...
@@ -80,268 +81,481 @@ class RejectionSampler(nn.Module):
...
@@ -80,268 +81,481 @@ class RejectionSampler(nn.Module):
output_token_ids (torch.Tensor):
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
A tensor containing the final output token IDs.
'''
'''
assert
metadata
.
max_spec_len
<=
MAX_SPEC_LEN
# NOTE: The following input preparationg can be moved
# [num_tokens, vocab_size]
# to the model runner with a persistent manner for better
target_probs
=
compute_probs
(
# performance.
target_logits
,
# Convert draft token IDs to a tensor, split by sample_lens, then pad.
metadata
.
cu_num_draft_tokens
,
draft_token_ids
=
[
sampling_metadata
,
torch
.
tensor
(
x
,
dtype
=
int
,
device
=
'cpu'
)
for
x
in
draft_token_ids
)
]
draft_token_ids_tensor
=
pad_sequence
(
draft_token_ids
,
output_token_ids
=
rejection_sample
(
batch_first
=
True
,
metadata
.
draft_token_ids
,
padding_value
=
INVALID_TOKEN_ID
)
metadata
.
num_draft_tokens
,
metadata
.
max_spec_len
,
# NOTE: CPU <-> GPU synchronization happens here.
metadata
.
cu_num_draft_tokens
,
draft_token_ids_tensor
=
draft_token_ids_tensor
.
to
(
target_probs
.
device
)
draft_probs
,
target_probs
,
# Create one-hot tensor for draft token ids.
bonus_token_ids
,
# This is used for ngram where we don't have draft_probs.
sampling_metadata
,
if
draft_probs
is
None
and
not
sampling_metadata
.
all_greedy
:
)
vocab_size
=
target_probs
.
size
(
-
1
)
draft_probs
=
_create_greedy_token_probs
(
draft_token_ids_tensor
,
vocab_size
,
target_probs
.
device
)
sample_lens
=
[
len
(
x
)
+
1
for
x
in
draft_token_ids
]
target_probs
=
_convert_2d_probs
(
target_probs
,
sample_lens
)
return
self
.
forward_native
(
draft_token_ids_tensor
,
draft_probs
,
bonus_token_ids_tensor
,
target_probs
,
sampling_metadata
)
# TODO: The following method can be optimized for better performance.
def
forward_native
(
self
,
draft_token_ids_tensor
:
torch
.
Tensor
,
# [batch_size, max_spec_len, vocab_size]
draft_probs
:
Optional
[
torch
.
Tensor
],
bonus_token_ids_tensor
:
torch
.
Tensor
,
# [batch_size, max_spec_len + 1, vocab_size]
target_probs
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
# Add 1 to include the 'bonus' token.
if
sampling_metadata
.
all_greedy
:
# Produce a mask that remains 1 (True) until the first
# mismatch (cumprod turns 0 after a mismatch).
target_token_ids_tensor
=
target_probs
.
argmax
(
dim
=-
1
)
accept_mask
=
(
target_token_ids_tensor
[:,
:
-
1
]
==
draft_token_ids_tensor
).
cumprod
(
dim
=
1
)
# Identify valid positions (non-padding).
valid_mask
=
target_token_ids_tensor
!=
INVALID_TOKEN_ID
# Generate mask with bonus token.
generate_mask
=
torch
.
cat
([
accept_mask
,
torch
.
zeros
(
accept_mask
.
size
(
0
),
1
,
device
=
accept_mask
.
device
)
],
dim
=
1
).
to
(
torch
.
bool
)
&
valid_mask
zeros_mask
=
(
generate_mask
==
0
)
first_zero_idx
=
zeros_mask
.
float
().
argmax
(
dim
=
1
)
# Figure out which rows actually contain at least one zero.
rows_with_zero
=
zeros_mask
.
any
(
dim
=
1
)
# Use indexing to set the first zero in each of those rows to 1.
generate_mask
[
rows_with_zero
,
first_zero_idx
[
rows_with_zero
]]
=
1
output_token_ids
=
target_token_ids_tensor
output_token_ids
[
~
generate_mask
]
=
INVALID_TOKEN_ID
else
:
# Reference: https://arxiv.org/pdf/2211.17192
# 1. Extract the probabilities of the draft tokens.
# [batch_size, max_spec_len]
batch_size
=
draft_token_ids_tensor
.
size
(
0
)
max_spec_len
=
draft_token_ids_tensor
.
size
(
1
)
invalid_idx
=
draft_token_ids_tensor
==
INVALID_TOKEN_ID
draft_token_ids_tensor
[
invalid_idx
]
=
0
assert
draft_probs
is
not
None
draft_token_probs
=
draft_probs
.
gather
(
dim
=-
1
,
index
=
draft_token_ids_tensor
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
target_token_probs
=
target_probs
.
gather
(
dim
=-
1
,
index
=
draft_token_ids_tensor
.
unsqueeze
(
-
1
)).
squeeze
(
-
1
)
# Force the probabilities of invalid tokens to inf
# so that they are not accepted.
draft_token_probs
[
invalid_idx
]
=
float
(
'inf'
)
# 2. Generate uniform samples.
# [batch_size, max_spec_len + 1]
uniform_samples
=
_create_uniform_samples
(
sampling_metadata
.
generators
,
batch_size
,
max_spec_len
,
target_probs
.
device
)
# 3. Accept or reject the samples.
# [batch_size, max_spec_len]
# If the draft token probabilities are 0, set them to the smallest
# positive normal value representable by float32.
safe_draft_probs
=
torch
.
where
(
draft_token_probs
>
0
,
draft_token_probs
,
torch
.
finfo
(
torch
.
float32
).
tiny
)
accepted
=
uniform_samples
<=
target_token_probs
/
safe_draft_probs
accept_mask
=
accepted
.
cumprod
(
dim
=
1
)
# Set the token ids to the draft token ids if accepted, otherwise
# set them to INVALID_TOKEN_ID.
accepted_token_ids
=
(
draft_token_ids_tensor
*
accept_mask
+
INVALID_TOKEN_ID
*
(
1
-
accept_mask
))
# 4. Adjust the distribution for the recovered tokens.
# Clamp the bonus probabilities to the smallest positive normal
# value representable by float32.
bonus_prob
=
torch
.
clamp
(
target_probs
[:,
:
-
1
,
:]
-
draft_probs
,
min
=
torch
.
finfo
(
torch
.
float32
).
tiny
)
normalized_bonus_prob
=
bonus_prob
/
bonus_prob
.
sum
(
dim
=-
1
,
keepdim
=
True
)
# 5. Sample recovered token ids.
recovered_token_ids
=
random_sample
(
normalized_bonus_prob
,
sampling_metadata
.
generators
).
reshape
(
batch_size
,
max_spec_len
)
# 6. Get the final output token ids.
# output_token_ids = accepted_token_ids +
# recovered_token_ids +
# bonus_token_id
recovered_bonus_token_ids
=
torch
.
cat
(
[
recovered_token_ids
,
bonus_token_ids_tensor
],
dim
=
1
)
# Generate mask with bonus tokens.
generate_mask
=
torch
.
cat
([
accept_mask
,
torch
.
zeros
(
batch_size
,
1
,
device
=
accept_mask
.
device
)
],
dim
=
1
).
to
(
torch
.
bool
)
zeros_mask
=
(
generate_mask
==
0
)
first_zero_idx
=
zeros_mask
.
float
().
argmax
(
dim
=
1
)
output_token_ids
=
torch
.
cat
([
accepted_token_ids
,
torch
.
full
((
batch_size
,
1
),
fill_value
=
INVALID_TOKEN_ID
,
device
=
accept_mask
.
device
)
],
dim
=
1
)
output_token_ids
[
torch
.
arange
(
batch_size
),
first_zero_idx
]
=
recovered_bonus_token_ids
[
torch
.
arange
(
batch_size
),
first_zero_idx
]
return
output_token_ids
return
output_token_ids
def
compute_probs
(
self
,
logits
:
torch
.
Tensor
,
@
staticmethod
sampling_metadata
:
SamplingMetadata
,
def
parse_output
(
sample_lens
:
list
[
int
])
->
torch
.
Tensor
:
output_token_ids
:
torch
.
Tensor
,
"""
vocab_size
:
int
,
Compute probability distribution from logits based on sampling metadata.
)
->
list
[
list
[
int
]]:
output_token_ids_np
=
output_token_ids
.
cpu
().
numpy
()
This function applies temperature scaling to the logits and converts
# Create mask for valid tokens.
them to probabilities using softmax. Note that division by
valid_mask
=
((
output_token_ids_np
!=
PLACEHOLDER_TOKEN_ID
)
&
temperature is not performed inplace to preserve the original logits
(
output_token_ids_np
<
vocab_size
))
tensor, which will be used by the original sampler to get bonus tokens.
outputs
=
[
row
[
valid_mask
[
i
]].
tolist
()
Args:
for
i
,
row
in
enumerate
(
output_token_ids_np
)
logits: Input logits tensor to be converted to probabilities
]
sampling_metadata: Metadata containing sampling parameters such
return
outputs
as temperature and whether greedy sampling is used
sample_lens: List of sample lengths used for repeating
temperature values
def
rejection_sample
(
# [num_tokens]
Returns:
draft_token_ids
:
torch
.
Tensor
,
torch.Tensor: Probability distribution (softmax of scaled logits)
# [batch_size]
if non-greedy sampling is used, otherwise returns the
num_draft_tokens
:
list
[
int
],
original logits
max_spec_len
:
int
,
"""
# [batch_size]
cu_num_draft_tokens
:
torch
.
Tensor
,
# [num_tokens, vocab_size]
draft_probs
:
Optional
[
torch
.
Tensor
],
# [num_tokens, vocab_size]
target_probs
:
torch
.
Tensor
,
# [batch_size, 1]
bonus_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
assert
draft_token_ids
.
ndim
==
1
assert
draft_probs
is
None
or
draft_probs
.
ndim
==
2
assert
cu_num_draft_tokens
.
ndim
==
1
assert
target_probs
.
ndim
==
2
batch_size
=
len
(
num_draft_tokens
)
num_tokens
=
draft_token_ids
.
shape
[
0
]
vocab_size
=
target_probs
.
shape
[
-
1
]
device
=
target_probs
.
device
assert
draft_token_ids
.
is_contiguous
()
assert
draft_probs
is
None
or
draft_probs
.
is_contiguous
()
assert
target_probs
.
is_contiguous
()
assert
bonus_token_ids
.
is_contiguous
()
assert
target_probs
.
shape
==
(
num_tokens
,
vocab_size
)
# Create output buffer.
output_token_ids
=
torch
.
empty
(
(
batch_size
,
max_spec_len
+
1
),
dtype
=
torch
.
int32
,
# Consistent with SamplerOutput.sampled_token_ids.
device
=
device
,
)
output_token_ids
.
fill_
(
PLACEHOLDER_TOKEN_ID
)
if
sampling_metadata
.
all_greedy
:
is_greedy
=
None
else
:
is_greedy
=
sampling_metadata
.
temperature
==
GREEDY_TEMPERATURE
if
not
sampling_metadata
.
all_random
:
# Rejection sampling for greedy sampling requests.
target_argmax
=
target_probs
.
argmax
(
dim
=-
1
)
rejection_greedy_sample_kernel
[(
batch_size
,
)](
output_token_ids
,
cu_num_draft_tokens
,
draft_token_ids
,
target_argmax
,
bonus_token_ids
,
is_greedy
,
max_spec_len
,
num_warps
=
1
,
)
if
sampling_metadata
.
all_greedy
:
if
sampling_metadata
.
all_greedy
:
return
logits
return
output_token_ids
assert
sampling_metadata
.
temperature
is
not
None
# We should optimize the following code as
# Generate uniform probabilities for rejection sampling.
# it will cause CPU -> GPU synchronization.
# [num_tokens]
temperature
=
torch
.
repeat_interleave
(
uniform_probs
=
generate_uniform_probs
(
sampling_metadata
.
temperature
,
num_tokens
,
torch
.
tensor
(
sample_lens
,
num_draft_tokens
,
device
=
sampling_metadata
.
temperature
.
device
))
sampling_metadata
.
generators
,
temperature
=
temperature
.
unsqueeze
(
dim
=
1
)
device
,
logits
=
logits
/
temperature
)
return
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
# Sample recovered tokens for each position.
# [num_tokens]
def
_create_greedy_token_probs
(
recovered_token_ids
=
sample_recovered_tokens
(
token_ids
:
torch
.
Tensor
,
max_spec_len
,
vocab_size
:
int
,
num_draft_tokens
,
out_device
:
torch
.
device
,
cu_num_draft_tokens
,
draft_token_ids
,
draft_probs
,
target_probs
,
sampling_metadata
,
device
,
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel
[(
batch_size
,
)](
output_token_ids
,
cu_num_draft_tokens
,
draft_token_ids
,
draft_probs
,
target_probs
,
bonus_token_ids
,
recovered_token_ids
,
uniform_probs
,
is_greedy
,
max_spec_len
,
vocab_size
,
IS_NGRAM
=
draft_probs
is
None
,
num_warps
=
1
,
)
return
output_token_ids
def
compute_probs
(
logits
:
torch
.
Tensor
,
# [num_tokens, vocab_size]
cu_num_draft_tokens
:
torch
.
Tensor
,
# [batch_size]
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
batch_size
,
num_tokens
=
token_ids
.
shape
"""Compute probability distribution from logits based on sampling metadata.
token_probs
=
torch
.
zeros
(
batch_size
,
This function applies temperature scaling to the logits and converts
num_tokens
,
them to probabilities using softmax. For greedy decoding, it returns
vocab_size
,
the original logits.
dtype
=
torch
.
float
,
device
=
out_device
)
Args:
logits: Input logits tensor to be converted to probabilities.
# Ignore INVALID_TOKEN_ID.
cu_num_draft_tokens: Cumulative number of draft tokens.
valid_mask
=
(
token_ids
!=
INVALID_TOKEN_ID
)
sampling_metadata: Metadata containing sampling parameters such as
valid_indices
=
token_ids
.
clone
()
temperature and whether greedy sampling is used.
valid_indices
[
~
valid_mask
]
=
0
Returns:
token_probs
.
scatter_
(
dim
=
2
,
torch.Tensor: Probability distribution (softmax of scaled logits)
index
=
valid_indices
.
unsqueeze
(
-
1
),
if non-greedy sampling is used, otherwise returns the
src
=
valid_mask
.
unsqueeze
(
-
1
).
float
())
original logits.
return
token_probs
def
_convert_2d_probs
(
probs
:
torch
.
Tensor
,
# [num_total_tokens, vocab_size]
sample_lens
:
list
[
int
])
->
torch
.
Tensor
:
"""
"""
Converts a 2D tensor of probabilities to a 3D tensor with padding.
assert
logits
.
ndim
==
2
[num_total_tokens, vocab_size] ->
assert
cu_num_draft_tokens
.
ndim
==
1
[batch_size, max_spec_len + 1, vocab_size]
if
sampling_metadata
.
all_greedy
:
return
logits
num_tokens
=
logits
.
shape
[
0
]
batch_size
=
cu_num_draft_tokens
.
shape
[
0
]
expanded_temperature
=
torch
.
empty
(
(
num_tokens
,
1
),
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
)
expand_kernel
[(
batch_size
,
)](
expanded_temperature
,
sampling_metadata
.
temperature
,
cu_num_draft_tokens
,
GREEDY_TEMPERATURE
,
# replace_from
1
,
# replace_to
MAX_NUM_TOKENS
=
MAX_SPEC_LEN
,
num_warps
=
1
,
)
output_prob
=
compiled_softmax
(
logits
,
expanded_temperature
)
return
output_prob
def
generate_uniform_probs
(
num_tokens
:
int
,
num_draft_tokens
:
list
[
int
],
generators
:
dict
[
int
,
torch
.
Generator
],
device
:
torch
.
device
,
)
->
torch
.
Tensor
:
"""
"""
cumulative_lens
=
torch
.
cumsum
(
torch
.
tensor
(
sample_lens
,
Generates a batch of uniform random samples, with optional seeding
device
=
probs
.
device
),
if available.
dim
=
0
)
split_indices
=
cumulative_lens
[:
-
1
].
tolist
()
# Exclude last index
This method creates a tensor of shape `(num_tokens, )` filled
with uniform random values in the range [0, 1). If `generators` is provided,
# Split into chunks without loops
the requests with their own seeds will use the provided `torch.Generator`
chunks
=
torch
.
tensor_split
(
probs
,
split_indices
,
dim
=
0
)
for reproducibility. The samples for the other requests will be generated
without a seed.
# Pad all sequences to maximum length
padded_probs
=
pad_sequence
(
chunks
,
batch_first
=
True
,
padding_value
=
0.0
)
Args:
return
padded_probs
num_tokens : int
Total number of tokens.
num_draft_tokens : List[List[int]]
def
_create_uniform_samples
(
seeded_seqs
:
dict
[
int
,
torch
.
Generator
],
Number of draft tokens per request.
batch_size
:
int
,
k
:
int
,
generators : Optional[Dict[int, torch.Generator]]
device
:
torch
.
device
)
->
torch
.
Tensor
:
A dictionary mapping indices in the batch to
`torch.Generator` objects.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(num_tokens, )` containing uniform
random values in the range [0, 1).
"""
"""
Generates a batch of uniform random samples, with optional seeding
uniform_probs
=
torch
.
rand
(
for specific sequences.
(
num_tokens
,
),
dtype
=
torch
.
float32
,
This method creates a tensor of shape `(batch_size, k)` filled
device
=
device
,
with uniform random values in the range [0, 1). If `seeded_seqs`
)
is provided, the sequences corresponding to specific indices
start_idx
=
0
will be generated using the provided `torch.Generator` for
for
req_idx
,
n
in
enumerate
(
num_draft_tokens
):
reproducibility. The other sequences will be generated without
# Do not generate random numbers for requests with no draft tokens.
a seed.
# This can be important for reproducibility.
if
n
==
0
:
Args:
continue
seeded_seqs : Optional[Dict[int, torch.Generator]]
end_idx
=
start_idx
+
n
A dictionary mapping indices in the batch to
generator
=
generators
.
get
(
req_idx
)
`torch.Generator` objects.
if
generator
is
not
None
:
batch_size : int
uniform_probs
[
start_idx
:
end_idx
].
uniform_
(
generator
=
generator
)
The number of sequences to generate.
start_idx
=
end_idx
k : int
return
uniform_probs
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
def
sample_recovered_tokens
(
max_spec_len
:
int
,
Returns:
num_draft_tokens
:
list
[
int
],
uniform_rand : torch.Tensor
# [batch_size]
A tensor of shape `(batch_size, k)` containing uniform
cu_num_draft_tokens
:
torch
.
Tensor
,
random values in the range [0, 1).
# [num_tokens]
"""
draft_token_ids
:
torch
.
Tensor
,
# [num_tokens, vocab_size]
uniform_rand
=
torch
.
rand
(
batch_size
,
draft_probs
:
Optional
[
torch
.
Tensor
],
k
,
# [num_tokens, vocab_size]
dtype
=
torch
.
float32
,
target_probs
:
torch
.
Tensor
,
device
=
device
)
sampling_metadata
:
SamplingMetadata
,
# Apply seeded generators only where needed
device
:
torch
.
device
,
if
seeded_seqs
:
)
->
torch
.
Tensor
:
for
idx
,
generator
in
seeded_seqs
.
items
():
# NOTE(woosuk): Create only one distribution for each request.
uniform_rand
[
idx
].
uniform_
(
0
,
1
,
generator
=
generator
)
batch_size
=
len
(
num_draft_tokens
)
return
uniform_rand
vocab_size
=
target_probs
.
shape
[
-
1
]
q
=
torch
.
empty
(
(
batch_size
,
vocab_size
),
dtype
=
torch
.
float32
,
device
=
device
,
)
q
.
exponential_
()
for
i
,
generator
in
sampling_metadata
.
generators
.
items
():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if
num_draft_tokens
[
i
]
>
0
:
q
[
i
].
exponential_
(
generator
=
generator
)
recovered_token_ids
=
torch
.
empty_like
(
draft_token_ids
)
sample_recovered_tokens_kernel
[(
batch_size
,
max_spec_len
)](
recovered_token_ids
,
cu_num_draft_tokens
,
draft_token_ids
,
draft_probs
,
target_probs
,
q
,
vocab_size
,
triton
.
next_power_of_2
(
vocab_size
),
IS_NGRAM
=
draft_probs
is
None
,
)
return
recovered_token_ids
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@
triton
.
jit
(
do_not_specialize
=
[
"max_spec_len"
])
def
rejection_greedy_sample_kernel
(
output_token_ids_ptr
,
# [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr
,
# [batch_size]
draft_token_ids_ptr
,
# [num_tokens]
target_argmax_ptr
,
# [num_tokens]
bonus_token_ids_ptr
,
# [batch_size]
is_greedy_ptr
,
# [batch_size] or None
max_spec_len
,
):
req_idx
=
tl
.
program_id
(
0
)
# FIXME(woosuk): Because is_greedy_ptr is not None at profiling run,
# re-compilation may happen during runtime when is_greedy_ptr is None.
if
is_greedy_ptr
is
None
:
is_greedy
=
True
else
:
is_greedy
=
tl
.
load
(
is_greedy_ptr
+
req_idx
)
if
not
is_greedy
:
# Early exit for non-greedy sampling requests.
return
if
req_idx
==
0
:
start_idx
=
0
else
:
start_idx
=
tl
.
load
(
cu_num_draft_tokens_ptr
+
req_idx
-
1
)
end_idx
=
tl
.
load
(
cu_num_draft_tokens_ptr
+
req_idx
)
num_draft_tokens
=
end_idx
-
start_idx
rejected
=
False
for
pos
in
range
(
num_draft_tokens
):
if
not
rejected
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
target_argmax_id
=
tl
.
load
(
target_argmax_ptr
+
start_idx
+
pos
)
tl
.
store
(
output_token_ids_ptr
+
req_idx
*
(
max_spec_len
+
1
)
+
pos
,
target_argmax_id
)
if
draft_token_id
!=
target_argmax_id
:
# Reject.
rejected
=
True
if
not
rejected
:
# If all tokens are accepted, append the bonus token.
bonus_token_id
=
tl
.
load
(
bonus_token_ids_ptr
+
req_idx
)
tl
.
store
(
output_token_ids_ptr
+
req_idx
*
(
max_spec_len
+
1
)
+
num_draft_tokens
,
bonus_token_id
)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@
triton
.
jit
(
do_not_specialize
=
[
"max_spec_len"
])
def
rejection_random_sample_kernel
(
output_token_ids_ptr
,
# [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr
,
# [batch_size]
draft_token_ids_ptr
,
# [num_tokens]
draft_probs_ptr
,
# [num_tokens, vocab_size] or None
target_probs_ptr
,
# [num_tokens, vocab_size]
bonus_token_ids_ptr
,
# [batch_size]
recovered_token_ids_ptr
,
# [num_tokens]
uniform_probs_ptr
,
# [num_tokens]
is_greedy_ptr
,
# [batch_size]
max_spec_len
,
vocab_size
,
IS_NGRAM
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
is_greedy
=
tl
.
load
(
is_greedy_ptr
+
req_idx
)
if
is_greedy
:
# Early exit for greedy sampling requests.
return
if
req_idx
==
0
:
start_idx
=
0
else
:
start_idx
=
tl
.
load
(
cu_num_draft_tokens_ptr
+
req_idx
-
1
)
end_idx
=
tl
.
load
(
cu_num_draft_tokens_ptr
+
req_idx
)
num_draft_tokens
=
end_idx
-
start_idx
rejected
=
False
for
pos
in
range
(
num_draft_tokens
):
if
not
rejected
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
if
IS_NGRAM
:
draft_prob
=
1
else
:
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
)
target_prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
)
uniform_prob
=
tl
.
load
(
uniform_probs_ptr
+
start_idx
+
pos
)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if
draft_prob
>
0
and
target_prob
/
draft_prob
>=
uniform_prob
:
# Accept.
token_id
=
draft_token_id
else
:
# Reject. Use recovered token.
rejected
=
True
token_id
=
tl
.
load
(
recovered_token_ids_ptr
+
start_idx
+
pos
)
tl
.
store
(
output_token_ids_ptr
+
req_idx
*
(
max_spec_len
+
1
)
+
pos
,
token_id
)
if
not
rejected
:
# If all tokens are accepted, append the bonus token.
bonus_token_id
=
tl
.
load
(
bonus_token_ids_ptr
+
req_idx
)
tl
.
store
(
output_token_ids_ptr
+
req_idx
*
(
max_spec_len
+
1
)
+
num_draft_tokens
,
bonus_token_id
)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@
triton
.
jit
(
do_not_specialize
=
[
"replace_from"
,
"replace_to"
])
def
expand_kernel
(
output_ptr
,
# [num_tokens]
input_ptr
,
# [batch_size]
cu_num_tokens_ptr
,
# [batch_size]
replace_from
,
replace_to
,
MAX_NUM_TOKENS
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
if
req_idx
==
0
:
# noqa: SIM108
start_idx
=
0
else
:
start_idx
=
tl
.
load
(
cu_num_tokens_ptr
+
req_idx
-
1
)
end_idx
=
tl
.
load
(
cu_num_tokens_ptr
+
req_idx
)
num_tokens
=
end_idx
-
start_idx
src_val
=
tl
.
load
(
input_ptr
+
req_idx
)
src_val
=
tl
.
where
(
src_val
==
replace_from
,
replace_to
,
src_val
)
offset
=
tl
.
arange
(
0
,
MAX_NUM_TOKENS
)
tl
.
store
(
output_ptr
+
start_idx
+
offset
,
src_val
,
mask
=
offset
<
num_tokens
)
@
triton
.
jit
def
sample_recovered_tokens_kernel
(
output_token_ids_ptr
,
# [num_tokens]
cu_num_draft_tokens_ptr
,
# [batch_size]
draft_token_ids_ptr
,
# [num_tokens]
draft_probs_ptr
,
# [num_tokens, vocab_size] or None
target_probs_ptr
,
# [num_tokens, vocab_size]
q_ptr
,
# [batch_size, vocab_size]
vocab_size
,
PADDED_VOCAB_SIZE
:
tl
.
constexpr
,
IS_NGRAM
:
tl
.
constexpr
,
):
req_idx
=
tl
.
program_id
(
0
)
if
req_idx
==
0
:
start_idx
=
0
else
:
start_idx
=
tl
.
load
(
cu_num_draft_tokens_ptr
+
req_idx
-
1
)
end_idx
=
tl
.
load
(
cu_num_draft_tokens_ptr
+
req_idx
)
num_draft_tokens
=
end_idx
-
start_idx
# Early exit for out-of-range positions.
pos
=
tl
.
program_id
(
1
)
if
pos
>=
num_draft_tokens
:
return
vocab_offset
=
tl
.
arange
(
0
,
PADDED_VOCAB_SIZE
)
if
IS_NGRAM
:
draft_token_id
=
tl
.
load
(
draft_token_ids_ptr
+
start_idx
+
pos
)
orig_prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
)
# Temporarily zero out the probability of the draft token.
# This is essentially the same as target_prob - draft_prob, except that
# n-gram does not have draft_prob. We regard it as 1.
tl
.
store
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
,
0
)
prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
0
)
else
:
draft_prob
=
tl
.
load
(
draft_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
0
)
target_prob
=
tl
.
load
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
0
)
prob
=
tl
.
maximum
(
target_prob
-
draft_prob
,
0
)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q
=
tl
.
load
(
q_ptr
+
req_idx
*
vocab_size
+
vocab_offset
,
mask
=
vocab_offset
<
vocab_size
,
other
=
float
(
"-inf"
))
recovered_id
=
tl
.
argmax
(
prob
/
q
,
axis
=-
1
)
tl
.
store
(
output_token_ids_ptr
+
start_idx
+
pos
,
recovered_id
)
if
IS_NGRAM
:
# Restore the original probability.
tl
.
store
(
target_probs_ptr
+
(
start_idx
+
pos
)
*
vocab_size
+
draft_token_id
,
orig_prob
)
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment