Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1137 additions
and
272 deletions
+1137
-272
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+15
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+32
-60
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+2
-2
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+43
-10
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+52
-11
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+54
-15
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+621
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+26
-24
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+57
-44
vllm/multimodal/__init__.py
vllm/multimodal/__init__.py
+5
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+59
-24
vllm/multimodal/image.py
vllm/multimodal/image.py
+1
-1
vllm/outputs.py
vllm/outputs.py
+9
-8
vllm/sampling_params.py
vllm/sampling_params.py
+8
-7
vllm/scalar_type.py
vllm/scalar_type.py
+35
-0
vllm/scripts.py
vllm/scripts.py
+7
-6
vllm/sequence.py
vllm/sequence.py
+36
-39
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+22
-15
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+50
-1
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+3
-1
No files found.
vllm/model_executor/models/opt.py
View file @
e661d594
...
...
@@ -237,14 +237,19 @@ class OPTDecoder(nn.Module):
for
_
in
range
(
config
.
num_hidden_layers
)
])
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
pos_embeds
=
self
.
embed_positions
(
positions
)
if
self
.
project_in
is
not
None
:
inputs_embeds
,
_
=
self
.
project_in
(
inputs_embeds
)
...
...
@@ -272,14 +277,22 @@ class OPTModel(nn.Module):
super
().
__init__
()
self
.
decoder
=
OPTDecoder
(
config
,
cache_config
,
quant_config
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
decoder
.
get_input_embeddings
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
=
inputs_embeds
)
class
OPTForCausalLM
(
nn
.
Module
):
...
...
vllm/model_executor/models/paligemma.py
View file @
e661d594
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
import
torch
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
PaliGemmaConfig
,
SiglipVisionConfig
,
SiglipVisionModel
from
transformers
import
PaliGemmaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
...
...
@@ -19,9 +17,11 @@ from vllm.model_executor.models.gemma import GemmaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
SupportsVision
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
merge_vision_embeddings
logger
=
init_logger
(
__name__
)
...
...
@@ -33,55 +33,22 @@ _KEYS_TO_MODIFY_MAPPING = {
def
get_max_paligemma_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
text_config
=
hf_config
.
text_config
return
text_config
.
num_image_tokens
def
dummy_seq_data_for_paligemma
(
hf_config
:
PaliGemmaConfig
,
seq_len
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
):
if
image_feature_size_override
is
None
:
image_feature_size
=
hf_config
.
text_config
.
num_image_tokens
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_paligemma
(
hf_config
:
SiglipVisionConfig
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
width
=
height
=
hf_config
.
image_size
if
image_width_override
is
not
None
:
width
=
image_width_override
if
image_height_override
is
not
None
:
height
=
image_height_override
vision_config
=
hf_config
.
vision_config
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
return
get_max_siglip_image_tokens
(
vision_config
)
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
seq_data
=
dummy_seq_data_for_
paligemma
(
hf
_config
,
seq_data
=
dummy_seq_data_for_
siglip
(
vision
_config
,
seq_len
,
image_token_id
=
hf_config
.
image_token_index
,
)
mm_data
=
dummy_image_for_
paligemma
(
vision_config
)
mm_data
=
dummy_image_for_
siglip
(
vision_config
)
return
seq_data
,
mm_data
...
...
@@ -133,12 +100,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
def
__init__
(
self
,
vision_hidden_size
:
int
,
projection_dim
:
int
):
super
().
__init__
()
self
.
linear
=
ColumnParallelLinear
(
vision_hidden_size
,
projection_dim
,
bias
=
True
)
self
.
linear
=
nn
.
Linear
(
vision_hidden_size
,
projection_dim
,
bias
=
True
)
def
forward
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
linear
(
image_features
)
hidden_states
=
self
.
linear
(
image_features
)
return
hidden_states
...
...
@@ -211,30 +176,37 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
target_dtype
=
vision_tower
.
get_input_embeddings
().
weight
.
dtype
image_outputs
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
),
output_hidden_states
=
True
)
selected_image_features
=
image_outputs
.
last_hidden_state
image_features
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
))
return
selected_
image_features
return
image_features
def
_process_image_pixels
(
self
,
inputs
:
PaliGemmaImagePixelInputs
)
->
torch
.
Tensor
:
self
,
inputs
:
PaliGemmaImagePixelInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
)
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
,
)
def
_process_image_input
(
self
,
image_input
:
PaliGemmaImageInputs
)
->
torch
.
Tensor
:
self
,
image_input
:
PaliGemmaImageInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
image_features
=
self
.
_process_image_pixels
(
image_input
,
)
return
self
.
multi_modal_projector
(
image_features
)
...
...
@@ -345,6 +317,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
raise
RuntimeError
(
"Some weights are not initialized from checkpoints:
"
f
"
{
unloaded_params
}
"
)
logger
.
warning
(
"Some weights are not initialized from checkpoints:
%s"
,
unloaded_params
)
vllm/model_executor/models/phi3v.py
View file @
e661d594
...
...
@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
BatchedTensors
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
...
@@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class
Phi3VImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
data
:
Batched
Tensor
s
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
...
...
vllm/model_executor/models/qwen.py
View file @
e661d594
...
...
@@ -15,7 +15,7 @@ import re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -35,6 +35,7 @@ from vllm.utils import print_warning_once
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
QWenMLP
(
nn
.
Module
):
...
...
@@ -194,6 +195,7 @@ class QWenModel(nn.Module):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -203,10 +205,10 @@ class QWenModel(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
h
=
nn
.
ModuleList
([
QWenBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
]
)
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
QWenBlock
(
config
,
cache_config
,
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
)
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
...
...
@@ -215,18 +217,29 @@ class QWenModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
h
)):
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
wte
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -267,9 +280,23 @@ class QWenLMHeadModel(nn.Module):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
...
...
@@ -301,6 +328,9 @@ class QWenLMHeadModel(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -318,6 +348,9 @@ class QWenLMHeadModel(nn.Module):
"Only text inputs are allowed. Images won't be handled "
"until Qwen-VL models are fully supported."
)
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/qwen2.py
View file @
e661d594
...
...
@@ -32,7 +32,7 @@ import re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
...
@@ -51,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
...
@@ -234,6 +235,7 @@ class Qwen2Model(nn.Module):
config
:
Qwen2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
config
=
config
...
...
@@ -244,30 +246,52 @@ class Qwen2Model(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
Qwen2DecoderLayer
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen2DecoderLayer
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
if
get_pp_group
().
is_first_rank
:
if
inputs_embeds
is
not
None
:
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
,
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -350,7 +374,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -359,6 +383,20 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -389,6 +427,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
...
@@ -401,7 +441,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
e661d594
...
...
@@ -31,7 +31,8 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
...
@@ -52,6 +53,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
Qwen2MoeMLP
(
nn
.
Module
):
...
...
@@ -315,6 +318,7 @@ class Qwen2MoeModel(nn.Module):
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
...
...
@@ -324,13 +328,15 @@ class Qwen2MoeModel(nn.Module):
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
layers
=
nn
.
ModuleList
([
Qwen2MoeDecoderLayer
(
config
,
layer_idx
,
cache_config
,
quant_config
=
quant_config
)
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
Qwen2MoeDecoderLayer
(
config
=
config
,
layer_idx
=
int
(
prefix
.
split
(
"."
)[
-
1
]),
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
...
...
@@ -339,14 +345,25 @@ class Qwen2MoeModel(nn.Module):
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
for
i
in
range
(
len
(
self
.
layers
)):
if
get_pp_group
().
is_first_rank
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
residual
)
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
...
...
@@ -380,7 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
...
@@ -389,6 +406,20 @@ class Qwen2MoeForCausalLM(nn.Module):
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
Optional
[
torch
.
Tensor
],
...
...
@@ -435,6 +466,9 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
not
in
params_dict
:
continue
...
...
@@ -448,6 +482,9 @@ class Qwen2MoeForCausalLM(nn.Module):
if
weight_name
not
in
name
:
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
...
...
@@ -460,6 +497,9 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
remapped_kv_scale_name
=
name
.
replace
(
...
...
@@ -474,7 +514,6 @@ class Qwen2MoeForCausalLM(nn.Module):
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/siglip.py
0 → 100644
View file @
e661d594
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
import
math
from
typing
import
Optional
,
Tuple
import
torch
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
SiglipConfig
,
SiglipVisionConfig
from
transformers.models.siglip.modeling_siglip
import
SiglipAttention
from
vllm_flash_attn
import
flash_attn_func
from
xformers.ops
import
memory_efficient_attention
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
assert
image_size
%
patch_size
==
0
return
image_size
//
patch_size
def
get_siglip_num_patches
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
grid_length
=
get_siglip_patch_grid_length
(
image_size
=
image_size
,
patch_size
=
patch_size
)
return
grid_length
*
grid_length
def
get_siglip_image_feature_size
(
hf_config
:
SiglipVisionConfig
)
->
int
:
return
get_siglip_num_patches
(
image_size
=
hf_config
.
image_size
,
patch_size
=
hf_config
.
patch_size
)
def
get_max_siglip_image_tokens
(
hf_config
:
SiglipVisionConfig
)
->
int
:
return
get_siglip_image_feature_size
(
hf_config
)
def
dummy_seq_data_for_siglip
(
hf_config
:
SiglipVisionConfig
,
seq_len
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
):
if
image_feature_size_override
is
None
:
image_feature_size
=
get_siglip_image_feature_size
(
hf_config
)
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_siglip
(
hf_config
:
SiglipVisionConfig
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
width
=
height
=
hf_config
.
image_size
if
image_width_override
is
not
None
:
width
=
image_width_override
if
image_height_override
is
not
None
:
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
{
"image"
:
image
}
def
input_processor_for_siglip
(
model_config
:
ModelConfig
,
hf_config
:
SiglipVisionConfig
,
llm_inputs
:
LLMInputs
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
return
llm_inputs
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
)
if
image_feature_size_override
is
None
:
image_feature_size
=
get_siglip_image_feature_size
(
hf_config
)
else
:
image_feature_size
=
image_feature_size_override
new_prompt
,
new_token_ids
=
repeat_and_pad_image_tokens
(
tokenizer
,
llm_inputs
.
get
(
"prompt"
),
llm_inputs
[
"prompt_token_ids"
],
image_token_id
=
image_token_id
,
repeat_count
=
image_feature_size
,
)
# NOTE: Create a defensive copy of the original inputs
return
LLMInputs
(
prompt_token_ids
=
new_token_ids
,
prompt
=
new_prompt
,
multi_modal_data
=
multi_modal_data
,
)
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class
SiglipVisionEmbeddings
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipVisionConfig
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
image_size
=
config
.
image_size
self
.
patch_size
=
config
.
patch_size
self
.
patch_embedding
=
nn
.
Conv2d
(
in_channels
=
config
.
num_channels
,
out_channels
=
self
.
embed_dim
,
kernel_size
=
self
.
patch_size
,
stride
=
self
.
patch_size
,
padding
=
"valid"
,
)
self
.
num_patches
=
(
self
.
image_size
//
self
.
patch_size
)
**
2
self
.
num_positions
=
self
.
num_patches
self
.
position_embedding
=
VocabParallelEmbedding
(
self
.
num_positions
,
self
.
embed_dim
)
self
.
register_buffer
(
"position_ids"
,
torch
.
arange
(
self
.
num_positions
,
dtype
=
torch
.
int64
).
expand
(
(
1
,
-
1
)),
persistent
=
False
,
)
def
interpolate_pos_encoding
(
self
,
embeddings
:
torch
.
Tensor
,
height
:
int
,
width
:
int
)
->
torch
.
Tensor
:
"""
This method is an adapted method for SigLIP (due to SigLIP not having
class embedding unlike other ViTs) that allows the model to interpolate
the pre-trained position encodings such that it can be usable on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings
=
self
.
position_embedding
.
weight
.
unsqueeze
(
0
)
num_patches
=
embeddings
.
shape
[
1
]
num_positions
=
position_embeddings
.
shape
[
1
]
if
num_patches
==
num_positions
and
height
==
width
:
return
position_embeddings
dim
=
embeddings
.
shape
[
-
1
]
height
=
height
//
self
.
patch_size
width
=
width
//
self
.
patch_size
# we add a small number to avoid floating point error
# in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height
,
width
=
height
+
0.1
,
width
+
0.1
patch_pos_embed
=
position_embeddings
.
reshape
(
1
,
int
(
math
.
sqrt
(
num_positions
)),
int
(
math
.
sqrt
(
num_positions
)),
dim
)
patch_pos_embed
=
patch_pos_embed
.
permute
(
0
,
3
,
1
,
2
)
patch_pos_embed
=
nn
.
functional
.
interpolate
(
patch_pos_embed
,
scale_factor
=
(
height
/
math
.
sqrt
(
num_positions
),
width
/
math
.
sqrt
(
num_positions
),
),
mode
=
"bicubic"
,
align_corners
=
False
,
)
if
(
int
(
height
)
!=
patch_pos_embed
.
shape
[
-
2
]
or
int
(
width
)
!=
patch_pos_embed
.
shape
[
-
1
]):
raise
ValueError
(
"Width or height does not match with "
"the interpolated position embeddings"
)
patch_pos_embed
=
patch_pos_embed
.
permute
(
0
,
2
,
3
,
1
).
view
(
1
,
-
1
,
dim
)
return
patch_pos_embed
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
interpolate_pos_encoding
:
bool
=
False
)
->
torch
.
Tensor
:
_
,
_
,
height
,
width
=
pixel_values
.
shape
target_dtype
=
self
.
patch_embedding
.
weight
.
dtype
patch_embeds
=
self
.
patch_embedding
(
pixel_values
.
to
(
dtype
=
target_dtype
))
# shape = [*, width, grid, grid]
embeddings
=
patch_embeds
.
flatten
(
2
).
transpose
(
1
,
2
)
if
interpolate_pos_encoding
:
embeddings
=
embeddings
+
self
.
interpolate_pos_encoding
(
embeddings
,
height
,
width
)
else
:
embeddings
=
embeddings
+
self
.
position_embedding
(
self
.
position_ids
)
return
embeddings
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class
SiglipTPAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
config
.
num_attention_heads
if
self
.
total_num_heads
%
tp_size
!=
0
:
raise
ValueError
(
f
"Number of attention heads (
{
self
.
total_num_heads
}
) "
"must be divisible by the tensor model parallel size"
f
" (
{
tp_size
}
)."
)
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
head_dim
=
self
.
embed_dim
//
self
.
total_num_heads
if
self
.
head_dim
*
self
.
total_num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got "
"`embed_dim`: {self.embed_dim} and `num_heads`:"
f
"
{
self
.
num_heads
}
)."
)
self
.
qkv_size
=
self
.
num_heads
*
self
.
head_dim
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
dropout
=
config
.
attention_dropout
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
self
.
embed_dim
,
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
total_num_heads
,
quant_config
=
quant_config
,
)
self
.
out_proj
=
RowParallelLinear
(
input_size
=
self
.
embed_dim
,
output_size
=
self
.
embed_dim
,
quant_config
=
quant_config
,
)
self
.
attn_fn
=
self
.
_basic_attention_forward
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
"""Input shape: Batch x Time x Channel"""
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
split
(
[
self
.
qkv_size
]
*
3
,
dim
=-
1
)
attn_output
=
self
.
attn_fn
(
q
=
query_states
,
k
=
key_states
,
v
=
value_states
,
batch_size
=
batch_size
,
q_len
=
q_len
,
)
attn_output
,
_
=
self
.
out_proj
(
attn_output
)
return
attn_output
def
_basic_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
):
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k_v_seq_len
=
k
.
shape
[
-
2
]
attn_weights
=
torch
.
matmul
(
q
,
k
.
transpose
(
2
,
3
))
*
self
.
scale
if
attn_weights
.
size
()
!=
(
batch_size
,
self
.
num_heads
,
q_len
,
k_v_seq_len
,
):
raise
ValueError
(
"Attention weights should be of size "
f
"
{
(
batch_size
,
self
.
num_heads
,
q_len
,
k_v_seq_len
)
}
, but is"
f
"
{
attn_weights
.
size
()
}
"
)
# upcast attention to fp32
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
).
to
(
q
.
dtype
)
attn_weights
=
nn
.
functional
.
dropout
(
attn_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
torch
.
matmul
(
attn_weights
,
v
)
if
attn_output
.
size
()
!=
(
batch_size
,
self
.
num_heads
,
q_len
,
self
.
head_dim
,
):
raise
ValueError
(
"`attn_output` should be of size "
f
"
{
(
batch_size
,
self
.
num_heads
,
q_len
,
self
.
head_dim
)
}
, but is"
f
"
{
attn_output
.
size
()
}
"
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
)
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class
SiglipFlashAttention2
(
SiglipTPAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
attn_fn
=
self
.
_flash_attention_forward
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def
_flash_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
,
*
args
,
**
kwargs
):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the
query, key, and value. (B, S, H, D)
"""
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
attn_output
=
flash_attn_func
(
q
,
k
,
v
,
dropout_p
=
self
.
dropout
,
causal
=
False
,
)
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
).
contiguous
()
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
class
SiglipSdpaAttention
(
SiglipTPAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
is_causal
=
False
self
.
attn_fn
=
self
.
_sdpa_attention_forward
def
_sdpa_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
):
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q
,
k
,
v
,
dropout_p
=
self
.
dropout
,
is_causal
=
False
,
scale
=
self
.
scale
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
view
(
batch_size
,
q_len
,
self
.
embed_dim
)
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
class
SiglipxFormersAttention
(
SiglipTPAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
attn_fn
=
self
.
_xformers_attention_forward
def
_xformers_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
):
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
attn_output
=
memory_efficient_attention
(
q
,
k
,
v
,
p
=
0.0
,
scale
=
self
.
scale
)
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
).
contiguous
()
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES
=
{
"eager"
:
SiglipTPAttention
,
"flash_attention_2"
:
SiglipFlashAttention2
,
"sdpa"
:
SiglipSdpaAttention
,
"xformers"
:
SiglipxFormersAttention
,
}
class
SiglipMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
activation_fn
=
get_act_fn
(
config
.
hidden_act
)
# For quantization, we require the hidden size to be a multiple of 64
quantizable
=
(
config
.
hidden_size
%
64
==
0
and
config
.
intermediate_size
%
64
==
0
)
self
.
fc1
=
ColumnParallelLinear
(
config
.
hidden_size
,
config
.
intermediate_size
,
quant_config
=
quant_config
if
quantizable
else
None
,
)
self
.
fc2
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
quant_config
=
quant_config
if
quantizable
else
None
,
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
fc1
(
hidden_states
)
hidden_states
=
self
.
activation_fn
(
hidden_states
)
hidden_states
,
_
=
self
.
fc2
(
hidden_states
)
return
hidden_states
class
SiglipEncoderLayer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
# TODO(ChristopherCho): use TP'ed Attention block
self
.
self_attn
=
SiglipAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
config
,
quant_config
=
quant_config
,
)
self
.
layer_norm2
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
]:
residual
=
hidden_states
hidden_states
=
self
.
layer_norm1
(
hidden_states
)
hidden_states
,
_
=
self
.
self_attn
(
hidden_states
=
hidden_states
)
hidden_states
=
residual
+
hidden_states
residual
=
hidden_states
hidden_states
=
self
.
layer_norm2
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
residual
+
hidden_states
return
hidden_states
,
None
class
SiglipEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
self
.
layers
=
nn
.
ModuleList
([
SiglipEncoderLayer
(
config
,
quant_config
=
quant_config
,
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
def
forward
(
self
,
inputs_embeds
:
torch
.
Tensor
,
)
->
Tuple
:
hidden_states
=
inputs_embeds
for
encoder_layer
in
self
.
layers
:
hidden_states
,
_
=
encoder_layer
(
hidden_states
)
return
hidden_states
class
SiglipMultiheadAttentionPoolingHead
(
nn
.
Module
):
"""Multihead Attention Pooling."""
def
__init__
(
self
,
config
:
SiglipVisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
probe
=
nn
.
Parameter
(
torch
.
randn
(
1
,
1
,
config
.
hidden_size
))
# TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
self
.
attention
=
torch
.
nn
.
MultiheadAttention
(
config
.
hidden_size
,
config
.
num_attention_heads
,
batch_first
=
True
)
self
.
layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
config
=
config
,
quant_config
=
quant_config
)
def
forward
(
self
,
hidden_state
:
torch
.
Tensor
)
->
torch
.
Tensor
:
batch_size
=
hidden_state
.
shape
[
0
]
probe
=
self
.
probe
.
repeat
(
batch_size
,
1
,
1
)
hidden_state
=
self
.
attention
(
probe
,
hidden_state
,
hidden_state
)[
0
]
residual
=
hidden_state
hidden_state
=
self
.
layernorm
(
hidden_state
)
hidden_state
=
residual
+
self
.
mlp
(
hidden_state
)
return
hidden_state
[:,
0
]
class
SiglipVisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
config
:
SiglipVisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
config
=
config
embed_dim
=
config
.
hidden_size
self
.
embeddings
=
SiglipVisionEmbeddings
(
config
)
self
.
encoder
=
SiglipEncoder
(
config
,
quant_config
=
quant_config
,
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
self
.
use_head
=
(
True
if
not
hasattr
(
config
,
"vision_use_head"
)
else
config
.
vision_use_head
)
if
self
.
use_head
:
self
.
head
=
SiglipMultiheadAttentionPoolingHead
(
config
=
config
,
quant_config
=
quant_config
)
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
interpolate_pos_encoding
:
bool
=
True
,
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embeddings
(
pixel_values
,
interpolate_pos_encoding
=
interpolate_pos_encoding
,
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
# TODO: add this back when pooled_output is used in inference
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
return
last_hidden_state
class
SiglipVisionModel
(
nn
.
Module
):
config_class
=
SiglipVisionConfig
main_input_name
=
"pixel_values"
def
__init__
(
self
,
config
:
SiglipVisionConfig
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
()
self
.
vision_model
=
SiglipVisionTransformer
(
config
,
quant_config
,
)
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
vision_model
.
embeddings
.
patch_embedding
def
forward
(
self
,
pixel_values
:
torch
.
Tensor
,
interpolate_pos_encoding
:
bool
=
False
,
)
->
torch
.
Tensor
:
return
self
.
vision_model
(
pixel_values
=
pixel_values
,
interpolate_pos_encoding
=
interpolate_pos_encoding
,
)
vllm/model_executor/models/utils.py
View file @
e661d594
...
...
@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters
=
False
for
p
in
module
.
parameters
():
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
# we use per-parameter offloading
...
...
@@ -94,35 +95,36 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
break
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty
(
size
=
p
.
data
.
size
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
'cpu'
,
pin_memory
=
pin_memory
)
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
stride
=
p
.
data
.
stride
(),
dtype
=
p
.
data
.
dtype
,
layout
=
p
.
data
.
layout
,
device
=
'cpu'
,
pin_memory
=
pin_memory
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
_CPU_OFFLOAD_BYTES
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
offloaded_parameters
=
True
if
offloaded_parameters
:
original_forward
=
module
.
forward
def
forward
(
*
args
,
**
kwargs
):
module
.
forward
=
original_forward
device_state
=
{
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k
:
v
.
to
(
device
,
non_blocking
=
True
)
for
k
,
v
in
module
.
state_dict
().
items
()
}
output
=
functional_call
(
module
,
device_state
,
args
=
args
,
kwargs
=
kwargs
)
module
.
forward
=
forward
return
output
state_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
module
.
state_dict
()
original_forward
=
module
.
forward
def
forward
(
*
args
,
**
kwargs
):
module
.
forward
=
original_forward
device_state
=
{
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k
:
v
.
to
(
device
,
non_blocking
=
True
)
for
k
,
v
in
state_dict
.
items
()
}
output
=
functional_call
(
module
,
device_state
,
args
=
args
,
kwargs
=
kwargs
)
module
.
forward
=
forward
return
output
module
.
forward
=
forward
return
module
...
...
vllm/model_executor/sampling_metadata.py
View file @
e661d594
import
random
from
array
import
array
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER
=
False
@
dataclass
...
...
@@ -117,6 +120,7 @@ class SamplingMetadata:
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
pin_memory
:
bool
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
"SamplingMetadata"
:
(
seq_groups
,
...
...
@@ -124,7 +128,7 @@ class SamplingMetadata:
categorized_sample_indices
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
)
device
,
generators
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
target_device
=
device
,
...
...
@@ -159,6 +163,7 @@ def _prepare_seq_groups(
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
"""Prepare sequence groups and indices for sampling.
...
...
@@ -169,8 +174,10 @@ def _prepare_seq_groups(
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generator
s
,
`SequenceGroupToSample.generator`.
generators: A store of per-request random number generators used
for seeded requests.
Returns:
seq_groups: A list of sequence group to sample.
...
...
@@ -216,8 +223,10 @@ def _prepare_seq_groups(
if
seq_group_metadata
.
is_prompt
:
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
if
generators
is
not
None
:
generators
[
seq_group_metadata
.
request_id
]
=
generator
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
...
...
@@ -234,6 +243,9 @@ def _prepare_seq_groups(
prompt_logprob_len
=
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
if
sampling_params
.
seed
is
not
None
and
generators
is
not
None
:
generator
=
generators
.
get
(
seq_group_metadata
.
request_id
)
# Update indices to select from the model output.
"""
This blocks computes selected_token_indices which is used in the
...
...
@@ -278,9 +290,6 @@ def _prepare_seq_groups(
logit_idx
+=
sample_len
sample_idx
+=
sample_len
if
sampling_params
.
seed
is
not
None
:
generator
=
seq_group_metadata
.
state
.
generator
seq_groups
.
append
(
SequenceGroupToSample
(
seq_ids
=
seq_ids
,
...
...
@@ -329,8 +338,8 @@ class SamplingTensors:
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens
:
List
[
List
[
int
]
]
=
[]
output_tokens
:
List
[
List
[
int
]
]
=
[]
prompt_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
array
]
=
[]
top_ks
:
List
[
int
]
=
[]
temperatures
:
List
[
float
]
=
[]
top_ps
:
List
[
float
]
=
[]
...
...
@@ -340,14 +349,16 @@ class SamplingTensors:
repetition_penalties
:
List
[
float
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
prompt_best_of
:
List
[
int
]
=
[]
do_penalties
=
False
do_top_p_top_k
=
False
do_min_p
=
False
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
if
_USE_TRITON_SAMPLER
:
prompt_best_of
:
List
[
int
]
=
[]
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
assert
sampling_metadata
.
seq_groups
is
not
None
for
seq_group
in
sampling_metadata
.
seq_groups
:
...
...
@@ -359,9 +370,6 @@ class SamplingTensors:
r
=
sampling_params
.
repetition_penalty
top_p
=
sampling_params
.
top_p
min_p
=
sampling_params
.
min_p
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
# k should not be greater than the vocab size.
top_k
=
min
(
sampling_params
.
top_k
,
vocab_size
)
...
...
@@ -382,8 +390,7 @@ class SamplingTensors:
do_penalties
=
True
is_prompt
=
seq_group
.
is_prompt
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# their logprobs
query_len
=
seq_group
.
query_len
...
...
@@ -408,23 +415,27 @@ class SamplingTensors:
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seq_data
.
get_len
(),
*
extra_entropy
,
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
if
_USE_TRITON_SAMPLER
:
if
is_prompt
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
extra_entropy
=
extra_entropy
or
()
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
seq_data
.
get_len
(),
*
extra_entropy
,
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
if
do_penalties
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
...
...
@@ -432,13 +443,15 @@ class SamplingTensors:
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prompt_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
prompt_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
if
seq_group
.
do_sample
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
list
(
seq_data
.
prompt_token_ids
)
)
output_tokens
.
append
(
list
(
seq_data
.
output_token_ids
)
)
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
_array
)
output_tokens
.
append
(
seq_data
.
output_token_ids
_array
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
...
...
@@ -454,9 +467,9 @@ class SamplingTensors:
frequency_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
sampling_seeds
:
List
[
int
],
sample_indices
:
List
[
int
],
prompt_tokens
:
List
[
List
[
int
]
],
output_tokens
:
List
[
List
[
int
]],
vocab_siz
e
:
int
,
extra_seeds_to_generate
:
int
,
device
:
torch
.
device
,
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
vocab_size
:
int
,
extra_seeds_to_generat
e
:
int
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
# Note that the performance will be very bad without
# pinned memory.
...
...
@@ -540,7 +553,7 @@ class SamplingTensors:
device
=
"cpu"
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
).
T
.
contiguous
()
).
t
()
.
contiguous
()
# Because the memory is pinned, we can do non-blocking
# transfer to device.
...
...
vllm/multimodal/__init__.py
View file @
e661d594
from
.base
import
(
BatchedTensors
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
)
from
.base
import
(
BatchedTensorInputs
,
BatchedTensors
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
NestedTensors
)
from
.registry
import
MultiModalRegistry
MULTIMODAL_REGISTRY
=
MultiModalRegistry
()
...
...
@@ -12,11 +13,13 @@ See also:
"""
__all__
=
[
"BatchedTensorInputs"
,
"BatchedTensors"
,
"MultiModalDataBuiltins"
,
"MultiModalDataDict"
,
"MultiModalInputs"
,
"MultiModalPlugin"
,
"NestedTensors"
,
"MULTIMODAL_REGISTRY"
,
"MultiModalRegistry"
,
]
vllm/multimodal/base.py
View file @
e661d594
import
sys
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
TypedDict
,
TypeVar
,
Union
)
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
import
torch
import
torch.types
from
PIL
import
Image
from
torch
import
nn
from
typing_extensions
import
TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
json_map_leaves
logger
=
init_logger
(
__name__
)
Batch
edTensors
=
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]
]
Nest
edTensors
=
Union
[
GenericSequence
[
torch
.
Tensor
]
,
torch
.
Tensor
]
"""
If each input tensor in the batch has the same size, this is a single batched
tensor; otherwise, this is a list of tensors with one element per batch.
Use a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
"""
BatchedTensors
:
TypeAlias
=
JSONTree
[
torch
.
Tensor
]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs
:
TypeAlias
=
Dict
[
str
,
JSONTree
[
torch
.
Tensor
]]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
if
sys
.
version_info
<
(
3
,
9
):
...
...
@@ -27,7 +42,7 @@ if sys.version_info < (3, 9):
pass
else
:
class
_MultiModalInputsBase
(
UserDict
[
str
,
torch
.
Tensor
]):
class
_MultiModalInputsBase
(
UserDict
[
str
,
Nested
Tensor
s
]):
pass
...
...
@@ -38,33 +53,44 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
@
staticmethod
def
try_concat
(
tensors
:
List
[
torch
.
Tensor
],
*
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensors
:
unbatched_shape
=
tensors
[
0
].
shape
[
1
:]
def
_try_concat
(
tensors
:
List
[
NestedTensors
])
->
BatchedTensors
:
"""
If each input tensor in the batch has the same shape, return a single
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
"""
# may be list rather than tensors
if
isinstance
(
tensors
[
0
],
list
):
return
[[
t
for
t
in
tensor
[
0
]]
for
tensor
in
cast
(
List
[
List
[
torch
.
Tensor
]],
tensors
)]
for
tensor
in
tensors
:
tensors_
=
cast
(
List
[
torch
.
Tensor
],
tensors
)
unbatched_shape
=
tensors_
[
0
].
shape
[
1
:]
for
tensor
in
tensors_
:
if
tensor
.
shape
[
1
:]
!=
unbatched_shape
:
return
[
tensor
.
squeeze
(
0
).
to
(
device
=
device
)
for
tensor
in
tensors
]
return
[
tensor
.
squeeze
(
0
)
for
tensor
in
tensors_
]
return
torch
.
cat
(
tensors
,
dim
=
0
)
.
to
(
device
=
device
)
return
torch
.
cat
(
tensors
_
,
dim
=
0
)
@
staticmethod
def
batch
(
inputs_list
:
List
[
"MultiModalInputs"
],
device
:
torch
.
types
.
Device
,
)
->
Dict
[
str
,
BatchedTensors
]:
"""Batch multiple inputs together into a dictionary."""
def
batch
(
inputs_list
:
List
[
"MultiModalInputs"
])
->
BatchedTensorInputs
:
"""
Batch multiple inputs together into a dictionary.
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if
len
(
inputs_list
)
==
0
:
return
{}
keys
=
inputs_list
[
0
].
keys
()
item_lists
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
item_lists
:
Dict
[
str
,
List
[
Nested
Tensor
s
]]
=
defaultdict
(
list
)
for
inputs
in
inputs_list
:
if
inputs
.
keys
()
!=
keys
:
...
...
@@ -75,10 +101,19 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists
[
k
].
append
(
v
)
return
{
k
:
MultiModalInputs
.
try_concat
(
item_list
,
device
=
device
)
k
:
MultiModalInputs
.
_
try_concat
(
item_list
)
for
k
,
item_list
in
item_lists
.
items
()
}
@
staticmethod
def
as_kwargs
(
batched_inputs
:
BatchedTensorInputs
,
*
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensorInputs
:
return
json_map_leaves
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
True
),
batched_inputs
)
class
MultiModalDataBuiltins
(
TypedDict
,
total
=
False
):
"""Modality types that are predefined by vLLM."""
...
...
vllm/multimodal/image.py
View file @
e661d594
...
...
@@ -113,7 +113,7 @@ class ImagePlugin(MultiModalPlugin):
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
data
:
object
)
->
MultiModalInputs
:
model_config
=
ctx
.
model_config
if
isinstance
(
data
,
Image
.
Image
):
if
isinstance
(
data
,
(
Image
.
Image
,
list
)
):
image_processor
=
self
.
_get_hf_image_processor
(
model_config
)
if
image_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
...
...
vllm/outputs.py
View file @
e661d594
...
...
@@ -29,7 +29,7 @@ class CompletionOutput:
index
:
int
text
:
str
token_ids
:
Tuple
[
int
,
...]
cumulative_logprob
:
float
cumulative_logprob
:
Optional
[
float
]
logprobs
:
Optional
[
SampleLogprobs
]
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
...
...
@@ -124,13 +124,14 @@ class RequestOutput:
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
(),
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
)
for
seq
in
top_n_seqs
CompletionOutput
(
seqs
.
index
(
seq
),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_output_token_ids
(),
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
seq
.
output_logprobs
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
)
for
seq
in
top_n_seqs
]
# Every sequence in the sequence group should have the same prompt.
...
...
vllm/sampling_params.py
View file @
e661d594
...
...
@@ -92,11 +92,12 @@ class SamplingParams:
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
result includes the log probabilities on the `logprobs` most likely
tokens, as well the chosen tokens. The API will always return the
log probability of the sampled token, so there may be up to
`logprobs+1` elements in the response.
When set to None, no probability is returned. If set to a non-None
value, the result includes the log probabilities of the specified
number of most likely tokens, as well as the chosen tokens.
Note that the implementation follows the OpenAI API: The API will
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
...
...
@@ -168,8 +169,8 @@ class SamplingParams:
self
.
ignore_eos
=
ignore_eos
self
.
max_tokens
=
max_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
logprobs
self
.
prompt_logprobs
=
prompt_logprobs
self
.
logprobs
=
1
if
logprobs
is
True
else
logprobs
self
.
prompt_logprobs
=
1
if
prompt_logprobs
is
True
else
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
...
...
vllm/scalar_type.py
0 → 100644
View file @
e661d594
from
._core_ext
import
NanRepr
,
ScalarType
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means its zero
class
scalar_types
:
int4
=
ScalarType
.
int_
(
4
,
None
)
uint4
=
ScalarType
.
uint
(
4
,
None
)
int8
=
ScalarType
.
int_
(
8
,
None
)
uint8
=
ScalarType
.
uint
(
8
,
None
)
float8_e4m3fn
=
ScalarType
.
float_
(
4
,
3
,
True
,
NanRepr
.
EXTD_RANGE_MAX_MIN
.
value
)
float8_e5m2
=
ScalarType
.
float_IEEE754
(
5
,
2
)
float16_e8m7
=
ScalarType
.
float_IEEE754
(
8
,
7
)
float16_e5m10
=
ScalarType
.
float_IEEE754
(
5
,
10
)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f
=
ScalarType
.
float_
(
3
,
2
,
True
,
NanRepr
.
NONE
.
value
)
# "gptq" types
uint4b8
=
ScalarType
.
uint
(
4
,
8
)
uint8b128
=
ScalarType
.
uint
(
8
,
128
)
# colloquial names
bfloat16
=
float16_e8m7
float16
=
float16_e5m10
vllm/scripts.py
View file @
e661d594
# The CLI entrypoint to vLLM.
import
argparse
import
asyncio
import
os
import
signal
import
sys
from
typing
import
Optional
from
typing
import
List
,
Optional
from
openai
import
OpenAI
from
openai.types.chat
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
...
...
@@ -25,7 +27,7 @@ def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
args
.
model
=
args
.
model_tag
run_server
(
args
)
asyncio
.
run
(
run_server
(
args
)
)
def
interactive_cli
(
args
:
argparse
.
Namespace
)
->
None
:
...
...
@@ -62,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
def
chat
(
system_prompt
:
Optional
[
str
],
model_name
:
str
,
client
:
OpenAI
)
->
None
:
conversation
=
[]
conversation
:
List
[
ChatCompletionMessageParam
]
=
[]
if
system_prompt
is
not
None
:
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
print
(
"Please enter a message for the chat model:"
)
while
True
:
input_message
=
input
(
"> "
)
message
=
{
"role"
:
"user"
,
"content"
:
input_message
}
conversation
.
append
(
message
)
conversation
.
append
({
"role"
:
"user"
,
"content"
:
input_message
})
chat_completion
=
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
conversation
)
...
...
@@ -78,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
response_message
=
chat_completion
.
choices
[
0
].
message
output
=
response_message
.
content
conversation
.
append
(
response_message
)
conversation
.
append
(
response_message
)
# type: ignore
print
(
output
)
...
...
vllm/sequence.py
View file @
e661d594
...
...
@@ -3,6 +3,7 @@ import copy
import
enum
import
math
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
...
...
@@ -119,10 +120,10 @@ class SequenceData:
prompt_token_ids
:
List
[
int
],
output_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
self
.
_prompt_token_ids
:
List
[
int
]
=
list
(
prompt_token_ids
)
self
.
_prompt_token_ids
=
array
(
'l'
,
prompt_token_ids
)
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
prompt_token_ids
)
self
.
_output_token_ids
:
List
[
int
]
=
(
list
(
output_token_ids
)
if
output_token_ids
is
not
None
else
[])
self
.
_output_token_ids
=
array
(
'l'
,
output_token_ids
if
output_token_ids
is
not
None
else
[])
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
...
...
@@ -132,8 +133,8 @@ class SequenceData:
self
.
_update_cached_all_tokens
()
def
_update_cached_all_tokens
(
self
):
self
.
_cached_all_token_ids
:
List
[
int
]
=
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
self
.
_cached_all_token_ids
:
List
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
@
property
def
prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
...
...
@@ -141,19 +142,27 @@ class SequenceData:
@
prompt_token_ids
.
setter
def
prompt_token_ids
(
self
,
new_prompt_token_ids
)
->
None
:
self
.
_prompt_token_ids
=
list
(
new_prompt_token_ids
)
self
.
_prompt_token_ids
=
array
(
'l'
,
new_prompt_token_ids
)
self
.
_prompt_token_ids_tuple
=
tuple
(
new_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
@
property
def
prompt_token_ids_array
(
self
)
->
array
:
return
self
.
_prompt_token_ids
@
property
def
output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
tuple
(
self
.
_output_token_ids
)
@
output_token_ids
.
setter
def
output_token_ids
(
self
,
new_output_token_ids
)
->
None
:
self
.
_output_token_ids
=
list
(
new_output_token_ids
)
self
.
_output_token_ids
=
array
(
'l'
,
new_output_token_ids
)
self
.
_update_cached_all_tokens
()
@
property
def
output_token_ids_array
(
self
)
->
array
:
return
self
.
_output_token_ids
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
_output_token_ids
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
...
...
@@ -402,14 +411,6 @@ class Sequence:
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
@
dataclass
class
SequenceGroupState
:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator
:
Optional
=
None
# type: ignore
class
SequenceGroup
:
"""A group of sequences that are generated from the same prompt.
...
...
@@ -443,6 +444,7 @@ class SequenceGroup:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
self
.
request_id
=
request_id
self
.
seqs
=
seqs
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
sampling_params
=
sampling_params
self
.
metrics
=
RequestMetrics
(
arrival_time
=
arrival_time
,
...
...
@@ -452,31 +454,29 @@ class SequenceGroup:
time_in_queue
=
None
)
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
self
.
_first_seq
=
next
(
iter
(
self
.
seqs_dict
.
values
()))
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return
self
.
_first_seq
.
prompt
return
self
.
seqs
[
0
]
.
prompt
@
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
return
self
.
_first_seq
.
prompt_token_ids
return
self
.
seqs
[
0
]
.
prompt_token_ids
@
property
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
return
self
.
_first_seq
.
multi_modal_data
return
self
.
seqs
[
0
]
.
multi_modal_data
@
property
def
lora_int_id
(
self
)
->
int
:
...
...
@@ -512,7 +512,7 @@ class SequenceGroup:
# in TPOT, rather than recalculating TTFT (since from the )
# POV of the user, there is simply a long generation delay.
if
(
self
.
metrics
.
first_token_time
is
None
and
self
.
get_
seqs
()
[
0
].
get_output_len
()
==
1
):
and
self
.
seqs
[
0
].
get_output_len
()
==
1
):
self
.
metrics
.
first_token_time
=
time
def
maybe_set_first_scheduled_time
(
self
,
time
:
float
)
->
None
:
...
...
@@ -548,9 +548,9 @@ class SequenceGroup:
self
,
status
:
Optional
[
SequenceStatus
]
=
None
,
)
->
List
[
Sequence
]:
return
list
(
self
.
seqs_dict
.
values
())
if
status
is
None
else
[
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
seq
.
status
==
statu
s
]
if
status
is
None
:
return
self
.
seq
s
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
status
==
status
]
def
is_encoder_decoder
(
self
)
->
bool
:
return
self
.
encoder_seq
is
not
None
...
...
@@ -559,22 +559,20 @@ class SequenceGroup:
return
self
.
encoder_seq
def
get_unfinished_seqs
(
self
)
->
List
[
Sequence
]:
return
[
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
not
seq
.
is_finished
()
]
return
[
seq
for
seq
in
self
.
seqs
if
not
seq
.
is_finished
()]
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
return
[
seq
for
seq
in
self
.
seqs
_dict
.
values
()
if
seq
.
is_finished
()]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
is_finished
()]
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
for
seq
in
self
.
seqs
_dict
.
values
()
:
for
seq
in
self
.
seqs
:
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
num_uncomputed_tokens
=
0
for
seq
in
self
.
get_
seqs
()
:
for
seq
in
self
.
seqs
:
if
not
seq
.
is_finished
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
num_uncomputed_tokens
...
...
@@ -583,7 +581,7 @@ class SequenceGroup:
# Optimization. We don't need to call get_seqs if we don't need to
# filter by states.
if
status
is
None
:
return
len
(
self
.
seqs
_dict
)
return
len
(
self
.
seqs
)
return
len
(
self
.
get_seqs
(
status
))
...
...
@@ -602,23 +600,25 @@ class SequenceGroup:
if
seq
.
seq_id
in
self
.
seqs_dict
:
raise
ValueError
(
f
"Sequence
{
seq
.
seq_id
}
already exists."
)
self
.
seqs_dict
[
seq
.
seq_id
]
=
seq
self
.
seqs
.
append
(
seq
)
def
remove
(
self
,
seq_id
:
int
)
->
None
:
if
seq_id
not
in
self
.
seqs_dict
:
seq
=
self
.
seqs_dict
.
pop
(
seq_id
,
None
)
if
seq
is
None
:
raise
ValueError
(
f
"Sequence
{
seq_id
}
not found."
)
del
self
.
seqs
_dict
[
seq_id
]
self
.
seqs
.
remove
(
seq
)
def
is_finished
(
self
)
->
bool
:
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_
seqs
()
)
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
# Every sequence should be in the same stage.
return
self
.
get_
seqs
()
[
0
].
is_prefill
()
return
self
.
seqs
[
0
].
is_prefill
()
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
f
"num_seqs=
{
len
(
self
.
seqs
_dict
)
}
)"
)
f
"num_seqs=
{
len
(
self
.
seqs
)
}
)"
)
class
SequenceGroupMetadata
:
...
...
@@ -639,7 +639,6 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
...
...
@@ -665,7 +664,6 @@ class SequenceGroupMetadata:
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
...
...
@@ -681,7 +679,6 @@ class SequenceGroupMetadata:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
encoder_seq_data
=
encoder_seq_data
self
.
cross_block_table
=
cross_block_table
self
.
_token_chunk_size
=
token_chunk_size
...
...
vllm/spec_decode/batch_expansion.py
View file @
e661d594
...
...
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceGroupState
,
get_all_seq_ids
)
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
...
@@ -16,6 +16,8 @@ SeqId = int
TargetSeqId
=
int
TokenId
=
int
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class
BatchExpansionTop1Scorer
(
SpeculativeScorer
):
"""Implements a speculative scorer that uses batch expansion to get
...
...
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score
=
self
.
_get_token_ids_to_score
(
proposal_token_ids
[
batch_index
])
# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params
=
input_seq_group_metadata
.
sampling_params
non_bonus_sampling_params
=
DEFAULT_SIMPLE_SAMPLING_PARAMS
\
if
sampling_params
.
temperature
else
sampling_params
target_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
token_ids
in
token_ids_to_score
:
last_index
=
len
(
token_ids_to_score
)
-
1
for
i
,
token_ids
in
enumerate
(
token_ids_to_score
):
target_sampling_params
=
sampling_params
if
i
==
last_index
\
else
non_bonus_sampling_params
target_seq_group_metadata_list
.
append
(
self
.
_create_single_target_seq_group_metadata
(
input_seq_group_metadata
,
input_seq_id
,
next
(
target_seq_ids_iter
),
token_ids
,
sampling_params
=
target_sampling_params
,
))
return
target_seq_group_metadata_list
@
staticmethod
def
_create_single_target_seq_group_metadata
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_id
:
SeqId
,
target_seq_id
:
TargetSeqId
,
token_ids
:
List
[
TokenId
],
sampling_params
:
SamplingParams
,
)
->
SequenceGroupMetadata
:
"""Create a single target SequenceGroupMetadata.
...
...
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for
data
in
new_seq_data_dict
.
values
():
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generator
=
torch
.
Generator
(
device
=
seq_group_metadata
.
state
.
generator
.
device
)
generator
.
set_state
(
seq_group_metadata
.
state
.
generator
.
get_state
())
state
=
SequenceGroupState
(
generator
=
generator
)
else
:
state
=
None
return
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
seq_data
=
new_seq_data_dict
,
sampling_params
=
seq_group_metadata
.
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
lora_request
=
None
,
token_chunk_size
=
1
,
state
=
state
,
)
def
_split_scoring_output
(
...
...
vllm/spec_decode/draft_model_runner.py
View file @
e661d594
...
...
@@ -11,10 +11,22 @@ except ModuleNotFoundError:
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalInputs
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SamplerOutput
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
...
...
@@ -78,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states
=
return_hidden_states
,
)
self
.
flashinfer_decode_workspace_buffer
=
None
self
.
flashinfer_decode_wrapper
=
None
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
num_queries
):
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
...
...
@@ -285,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
input_tokens
is
not
None
if
self
.
flashinfer_decode_workspace_buffer
is
None
:
self
.
flashinfer_decode_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_decode_wrapper
=
\
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_decode_workspace_buffer
,
"NHD"
)
self
.
flashinfer_prefill_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_prefill_wrapper
=
\
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_prefill_workspace_buffer
,
"NHD"
)
model_input
.
attn_metadata
.
prefill_wrapper
=
\
self
.
flashinfer_prefill_wrapper
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
graph_runners
[
model_input
.
virtual_engine
][
batch_size
].
flashinfer_decode_wrapper
else
:
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
flashinfer_decode_wrapper
model_input
.
attn_metadata
.
begin_forward
()
# Detect exec mode
assert
model_input
.
attn_metadata
is
not
None
use_cuda_graph
=
False
...
...
@@ -323,7 +371,8 @@ class TP1DraftModelRunner(ModelRunner):
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
**
multi_modal_kwargs
,
**
MultiModalInputs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
)
# Compute the logits.
...
...
vllm/spec_decode/medusa_worker.py
View file @
e661d594
...
...
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
...
...
Prev
1
…
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment