Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e661d594
Commit
e661d594
authored
Aug 12, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.4' into v0.5.4-dtk24.04.1
parents
6b16ea2e
4db5176d
Changes
374
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1137 additions
and
272 deletions
+1137
-272
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+15
-2
vllm/model_executor/models/paligemma.py
vllm/model_executor/models/paligemma.py
+32
-60
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+2
-2
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+43
-10
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+52
-11
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+54
-15
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+621
-0
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+26
-24
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+57
-44
vllm/multimodal/__init__.py
vllm/multimodal/__init__.py
+5
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+59
-24
vllm/multimodal/image.py
vllm/multimodal/image.py
+1
-1
vllm/outputs.py
vllm/outputs.py
+9
-8
vllm/sampling_params.py
vllm/sampling_params.py
+8
-7
vllm/scalar_type.py
vllm/scalar_type.py
+35
-0
vllm/scripts.py
vllm/scripts.py
+7
-6
vllm/sequence.py
vllm/sequence.py
+36
-39
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+22
-15
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+50
-1
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+3
-1
No files found.
vllm/model_executor/models/opt.py
View file @
e661d594
...
@@ -237,14 +237,19 @@ class OPTDecoder(nn.Module):
...
@@ -237,14 +237,19 @@ class OPTDecoder(nn.Module):
for
_
in
range
(
config
.
num_hidden_layers
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
])
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
get_input_embeddings
(
input_ids
)
pos_embeds
=
self
.
embed_positions
(
positions
)
pos_embeds
=
self
.
embed_positions
(
positions
)
if
self
.
project_in
is
not
None
:
if
self
.
project_in
is
not
None
:
inputs_embeds
,
_
=
self
.
project_in
(
inputs_embeds
)
inputs_embeds
,
_
=
self
.
project_in
(
inputs_embeds
)
...
@@ -272,14 +277,22 @@ class OPTModel(nn.Module):
...
@@ -272,14 +277,22 @@ class OPTModel(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
decoder
=
OPTDecoder
(
config
,
cache_config
,
quant_config
)
self
.
decoder
=
OPTDecoder
(
config
,
cache_config
,
quant_config
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
decoder
.
get_input_embeddings
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
return
self
.
decoder
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
inputs_embeds
=
inputs_embeds
)
class
OPTForCausalLM
(
nn
.
Module
):
class
OPTForCausalLM
(
nn
.
Module
):
...
...
vllm/model_executor/models/paligemma.py
View file @
e661d594
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
from
typing
import
Iterable
,
List
,
Literal
,
Optional
,
Tuple
,
TypedDict
import
torch
import
torch
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PaliGemmaConfig
,
SiglipVisionConfig
,
SiglipVisionModel
from
transformers
import
PaliGemmaConfig
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
ColumnParallelLinear
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
...
@@ -19,9 +17,11 @@ from vllm.model_executor.models.gemma import GemmaModel
...
@@ -19,9 +17,11 @@ from vllm.model_executor.models.gemma import GemmaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
SupportsVision
from
.interfaces
import
SupportsVision
from
.siglip
import
(
SiglipVisionModel
,
dummy_image_for_siglip
,
dummy_seq_data_for_siglip
,
get_max_siglip_image_tokens
)
from
.utils
import
merge_vision_embeddings
from
.utils
import
merge_vision_embeddings
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -33,55 +33,22 @@ _KEYS_TO_MODIFY_MAPPING = {
...
@@ -33,55 +33,22 @@ _KEYS_TO_MODIFY_MAPPING = {
def
get_max_paligemma_image_tokens
(
ctx
:
InputContext
):
def
get_max_paligemma_image_tokens
(
ctx
:
InputContext
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
text_config
=
hf_config
.
text_config
vision_config
=
hf_config
.
vision_config
return
text_config
.
num_image_tokens
def
dummy_seq_data_for_paligemma
(
hf_config
:
PaliGemmaConfig
,
seq_len
:
int
,
*
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
):
if
image_feature_size_override
is
None
:
image_feature_size
=
hf_config
.
text_config
.
num_image_tokens
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
def
dummy_image_for_paligemma
(
hf_config
:
SiglipVisionConfig
,
*
,
image_width_override
:
Optional
[
int
]
=
None
,
image_height_override
:
Optional
[
int
]
=
None
,
):
width
=
height
=
hf_config
.
image_size
if
image_width_override
is
not
None
:
width
=
image_width_override
if
image_height_override
is
not
None
:
height
=
image_height_override
image
=
Image
.
new
(
"RGB"
,
(
width
,
height
),
color
=
0
)
return
get_max_siglip_image_tokens
(
vision_config
)
return
{
"image"
:
image
}
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
def
dummy_data_for_paligemma
(
ctx
:
InputContext
,
seq_len
:
int
):
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
hf_config
=
ctx
.
get_hf_config
(
PaliGemmaConfig
)
vision_config
=
hf_config
.
vision_config
vision_config
=
hf_config
.
vision_config
seq_data
=
dummy_seq_data_for_
paligemma
(
seq_data
=
dummy_seq_data_for_
siglip
(
hf
_config
,
vision
_config
,
seq_len
,
seq_len
,
image_token_id
=
hf_config
.
image_token_index
,
image_token_id
=
hf_config
.
image_token_index
,
)
)
mm_data
=
dummy_image_for_
paligemma
(
vision_config
)
mm_data
=
dummy_image_for_
siglip
(
vision_config
)
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
@@ -133,12 +100,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
...
@@ -133,12 +100,10 @@ class PaliGemmaMultiModalProjector(nn.Module):
def
__init__
(
self
,
vision_hidden_size
:
int
,
projection_dim
:
int
):
def
__init__
(
self
,
vision_hidden_size
:
int
,
projection_dim
:
int
):
super
().
__init__
()
super
().
__init__
()
self
.
linear
=
ColumnParallelLinear
(
vision_hidden_size
,
self
.
linear
=
nn
.
Linear
(
vision_hidden_size
,
projection_dim
,
bias
=
True
)
projection_dim
,
bias
=
True
)
def
forward
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
image_features
:
torch
.
Tensor
)
->
torch
.
Tensor
:
hidden_states
,
_
=
self
.
linear
(
image_features
)
hidden_states
=
self
.
linear
(
image_features
)
return
hidden_states
return
hidden_states
...
@@ -211,30 +176,37 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -211,30 +176,37 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
data
=
self
.
_validate_pixel_values
(
pixel_values
),
data
=
self
.
_validate_pixel_values
(
pixel_values
),
)
)
def
_image_pixels_to_features
(
self
,
vision_tower
:
SiglipVisionModel
,
def
_image_pixels_to_features
(
pixel_values
:
torch
.
Tensor
)
->
torch
.
Tensor
:
self
,
vision_tower
:
SiglipVisionModel
,
pixel_values
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
target_dtype
=
vision_tower
.
get_input_embeddings
().
weight
.
dtype
target_dtype
=
vision_tower
.
get_input_embeddings
().
weight
.
dtype
image_outputs
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
),
image_features
=
vision_tower
(
pixel_values
.
to
(
dtype
=
target_dtype
))
output_hidden_states
=
True
)
selected_image_features
=
image_outputs
.
last_hidden_state
return
selected_
image_features
return
image_features
def
_process_image_pixels
(
def
_process_image_pixels
(
self
,
inputs
:
PaliGemmaImagePixelInputs
)
->
torch
.
Tensor
:
self
,
inputs
:
PaliGemmaImagePixelInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
assert
self
.
vision_tower
is
not
None
pixel_values
=
inputs
[
"data"
]
pixel_values
=
inputs
[
"data"
]
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
)
return
self
.
_image_pixels_to_features
(
self
.
vision_tower
,
pixel_values
,
)
def
_process_image_input
(
def
_process_image_input
(
self
,
image_input
:
PaliGemmaImageInputs
)
->
torch
.
Tensor
:
self
,
image_input
:
PaliGemmaImageInputs
,
)
->
torch
.
Tensor
:
assert
self
.
vision_tower
is
not
None
assert
self
.
vision_tower
is
not
None
image_features
=
self
.
_process_image_pixels
(
image_input
)
image_features
=
self
.
_process_image_pixels
(
image_input
,
)
return
self
.
multi_modal_projector
(
image_features
)
return
self
.
multi_modal_projector
(
image_features
)
...
@@ -345,6 +317,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
...
@@ -345,6 +317,6 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
unloaded_params
=
params_dict
.
keys
()
-
loaded_params
if
unloaded_params
:
if
unloaded_params
:
raise
RuntimeError
(
logger
.
warning
(
"Some weights are not initialized from checkpoints:
"
"Some weights are not initialized from checkpoints:
%s"
,
f
"
{
unloaded_params
}
"
)
unloaded_params
)
vllm/model_executor/models/phi3v.py
View file @
e661d594
...
@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -36,7 +36,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.clip
import
CLIPVisionModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.models.llama
import
LlamaModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
BatchedTensors
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.multimodal.image
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
...
@@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
...
@@ -261,7 +261,7 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class
Phi3VImagePixelInputs
(
TypedDict
):
class
Phi3VImagePixelInputs
(
TypedDict
):
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
data
:
Batched
Tensor
s
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
"""
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
...
...
vllm/model_executor/models/qwen.py
View file @
e661d594
...
@@ -15,7 +15,7 @@ import re
...
@@ -15,7 +15,7 @@ import re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -35,6 +35,7 @@ from vllm.utils import print_warning_once
...
@@ -35,6 +35,7 @@ from vllm.utils import print_warning_once
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
QWenMLP
(
nn
.
Module
):
class
QWenMLP
(
nn
.
Module
):
...
@@ -194,6 +195,7 @@ class QWenModel(nn.Module):
...
@@ -194,6 +195,7 @@ class QWenModel(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -203,10 +205,10 @@ class QWenModel(nn.Module):
...
@@ -203,10 +205,10 @@ class QWenModel(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
h
=
nn
.
ModuleList
([
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
QWenBlock
(
config
,
cache_config
,
quant_config
)
config
.
num_hidden_layers
,
for
_
in
range
(
config
.
num_hidden_layers
)
lambda
prefix
:
QWenBlock
(
config
,
cache_config
,
quant_config
),
]
)
prefix
=
f
"
{
prefix
}
.h"
)
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_f
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
def
forward
(
...
@@ -215,18 +217,29 @@ class QWenModel(nn.Module):
...
@@ -215,18 +217,29 @@ class QWenModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
wte
(
input_ids
)
if
get_pp_group
().
is_first_rank
:
residual
=
None
hidden_states
=
self
.
wte
(
input_ids
)
for
i
in
range
(
len
(
self
.
h
)):
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
layer
=
self
.
h
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
ln_f
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -267,9 +280,23 @@ class QWenLMHeadModel(nn.Module):
...
@@ -267,9 +280,23 @@ class QWenLMHeadModel(nn.Module):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
logits
=
self
.
logits_processor
(
self
.
lm_head
,
hidden_states
,
...
@@ -301,6 +328,9 @@ class QWenLMHeadModel(nn.Module):
...
@@ -301,6 +328,9 @@ class QWenLMHeadModel(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -318,6 +348,9 @@ class QWenLMHeadModel(nn.Module):
...
@@ -318,6 +348,9 @@ class QWenLMHeadModel(nn.Module):
"Only text inputs are allowed. Images won't be handled "
"Only text inputs are allowed. Images won't be handled "
"until Qwen-VL models are fully supported."
)
"until Qwen-VL models are fully supported."
)
continue
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/qwen2.py
View file @
e661d594
...
@@ -32,7 +32,7 @@ import re
...
@@ -32,7 +32,7 @@ import re
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
MergedColumnParallelLinear
,
...
@@ -51,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -51,6 +51,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
from
vllm.model_executor.utils
import
pad_weight
,
gemm_bank_conf
...
@@ -234,6 +235,7 @@ class Qwen2Model(nn.Module):
...
@@ -234,6 +235,7 @@ class Qwen2Model(nn.Module):
config
:
Qwen2Config
,
config
:
Qwen2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -244,30 +246,52 @@ class Qwen2Model(nn.Module):
...
@@ -244,30 +246,52 @@ class Qwen2Model(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
Qwen2DecoderLayer
(
config
,
cache_config
,
quant_config
)
config
.
num_hidden_layers
,
for
_
in
range
(
config
.
num_hidden_layers
)
lambda
prefix
:
Qwen2DecoderLayer
(
config
=
config
,
])
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
get_input_embeddings
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
def
forward
(
def
forward
(
self
,
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
get_pp_group
().
is_first_rank
:
residual
=
None
if
inputs_embeds
is
not
None
:
for
i
in
range
(
len
(
self
.
layers
)):
hidden_states
=
inputs_embeds
else
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
hidden_states
,
residual
=
layer
(
positions
,
positions
,
hidden_states
,
hidden_states
,
kv_caches
[
i
],
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
,
attn_metadata
,
residual
,
residual
,
)
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -350,7 +374,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -350,7 +374,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
@@ -359,6 +383,20 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -359,6 +383,20 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
def
sample
(
self
,
self
,
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
...
@@ -389,6 +427,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -389,6 +427,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -401,7 +441,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -401,7 +441,8 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
if
name
is
None
:
continue
continue
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
e661d594
...
@@ -31,7 +31,8 @@ from transformers import PretrainedConfig
...
@@ -31,7 +31,8 @@ from transformers import PretrainedConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_world_size
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
...
@@ -52,6 +53,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -52,6 +53,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
Qwen2MoeMLP
(
nn
.
Module
):
class
Qwen2MoeMLP
(
nn
.
Module
):
...
@@ -315,6 +318,7 @@ class Qwen2MoeModel(nn.Module):
...
@@ -315,6 +318,7 @@ class Qwen2MoeModel(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
padding_idx
=
config
.
pad_token_id
self
.
padding_idx
=
config
.
pad_token_id
...
@@ -324,13 +328,15 @@ class Qwen2MoeModel(nn.Module):
...
@@ -324,13 +328,15 @@ class Qwen2MoeModel(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
layers
=
nn
.
ModuleList
([
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
Qwen2MoeDecoderLayer
(
config
,
config
.
num_hidden_layers
,
layer_idx
,
lambda
prefix
:
Qwen2MoeDecoderLayer
(
config
=
config
,
cache_config
,
layer_idx
=
int
(
quant_config
=
quant_config
)
prefix
.
split
(
"."
)[
-
1
]),
for
layer_idx
in
range
(
config
.
num_hidden_layers
)
cache_config
=
cache_config
,
])
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.layers"
,
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
def
forward
(
def
forward
(
...
@@ -339,14 +345,25 @@ class Qwen2MoeModel(nn.Module):
...
@@ -339,14 +345,25 @@ class Qwen2MoeModel(nn.Module):
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
embed_tokens
(
input_ids
)
if
get_pp_group
().
is_first_rank
:
residual
=
None
hidden_states
=
self
.
embed_tokens
(
input_ids
)
for
i
in
range
(
len
(
self
.
layers
)):
residual
=
None
else
:
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
layers
[
i
]
layer
=
self
.
layers
[
i
]
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
hidden_states
,
residual
=
layer
(
positions
,
hidden_states
,
kv_caches
[
i
],
attn_metadata
,
kv_caches
[
i
-
self
.
start_layer
],
residual
)
attn_metadata
,
residual
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
,
"residual"
:
residual
})
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
return
hidden_states
return
hidden_states
...
@@ -380,7 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -380,7 +397,7 @@ class Qwen2MoeForCausalLM(nn.Module):
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
hidden_states
=
self
.
model
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
...
@@ -389,6 +406,20 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -389,6 +406,20 @@ class Qwen2MoeForCausalLM(nn.Module):
sampling_metadata
)
sampling_metadata
)
return
logits
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
"residual"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
def
sample
(
self
,
self
,
logits
:
Optional
[
torch
.
Tensor
],
logits
:
Optional
[
torch
.
Tensor
],
...
@@ -435,6 +466,9 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -435,6 +466,9 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
not
in
params_dict
:
if
name
not
in
params_dict
:
continue
continue
...
@@ -448,6 +482,9 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -448,6 +482,9 @@ class Qwen2MoeForCausalLM(nn.Module):
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
weight_loader
(
param
,
...
@@ -460,6 +497,9 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -460,6 +497,9 @@ class Qwen2MoeForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Skip layers on other devices.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
if
name
.
endswith
(
"kv_scale"
):
remapped_kv_scale_name
=
name
.
replace
(
remapped_kv_scale_name
=
name
.
replace
(
...
@@ -474,7 +514,6 @@ class Qwen2MoeForCausalLM(nn.Module):
...
@@ -474,7 +514,6 @@ class Qwen2MoeForCausalLM(nn.Module):
continue
continue
else
:
else
:
name
=
remapped_kv_scale_name
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/siglip.py
0 → 100644
View file @
e661d594
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""
import
math
from
typing
import
Optional
,
Tuple
import
torch
from
PIL
import
Image
from
torch
import
nn
from
transformers
import
SiglipConfig
,
SiglipVisionConfig
from
transformers.models.siglip.modeling_siglip
import
SiglipAttention
from
vllm_flash_attn
import
flash_attn_func
from
xformers.ops
import
memory_efficient_attention
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.inputs
import
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return how many patches fit along one side of a square image.

    Args:
        image_size: Side length of the image, in pixels.
        patch_size: Side length of one patch, in pixels.

    Returns:
        ``image_size // patch_size``.

    Raises:
        AssertionError: If ``image_size`` is not a multiple of ``patch_size``.
    """
    patches_per_side, leftover = divmod(image_size, patch_size)
    assert leftover == 0
    return patches_per_side
def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches in a square image.

    The count is the square of the per-side patch count computed by
    ``get_siglip_patch_grid_length``.
    """
    side = get_siglip_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )
    return side * side
def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    """Return the number of patches for the image/patch size in *hf_config*.

    Reads ``hf_config.image_size`` and ``hf_config.patch_size`` and forwards
    them to ``get_siglip_num_patches``.
    """
    image_side = hf_config.image_size
    patch_side = hf_config.patch_size
    return get_siglip_num_patches(image_size=image_side,
                                  patch_size=patch_side)
def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    """Return the maximum number of image tokens for *hf_config*.

    This is the same value as ``get_siglip_image_feature_size(hf_config)``.
    """
    max_tokens = get_siglip_image_feature_size(hf_config)
    return max_tokens
def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build placeholder ``SequenceData`` made of image tokens then zeros.

    Args:
        hf_config: Vision config used to derive the image feature size when
            no override is given.
        seq_len: Target length; the tail is filled with token id ``0``.
        image_token_id: Token id repeated once per image feature.
        image_feature_size_override: If not ``None``, used instead of the
            size derived from *hf_config*.

    Returns:
        ``SequenceData`` wrapping the generated token id list.
    """
    image_feature_size = (get_siglip_image_feature_size(hf_config)
                          if image_feature_size_override is None else
                          image_feature_size_override)

    # A non-positive repeat count yields an empty list, so no padding is
    # appended when the image tokens already reach seq_len.
    image_part = [image_token_id] * image_feature_size
    padding_part = [0] * (seq_len - image_feature_size)
    return SequenceData(image_part + padding_part)
def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create an all-zero RGB image wrapped in a multi-modal data dict.

    Both dimensions default to ``hf_config.image_size``; each can be replaced
    independently through the corresponding override.

    Returns:
        ``{"image": <PIL image>}``.
    """
    default_side = hf_config.image_size
    width = (default_side
             if image_width_override is None else image_width_override)
    height = (default_side
              if image_height_override is None else image_height_override)

    blank = Image.new("RGB", (width, height), color=0)
    return {"image": blank}
def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Expand the image placeholder token in a prompt carrying image data.

    Inputs without an ``"image"`` entry in ``multi_modal_data`` are returned
    unchanged. Otherwise the prompt and its token ids are passed through
    ``repeat_and_pad_image_tokens`` with a repeat count equal to the image
    feature size (or the override, when given).

    Returns:
        The original *llm_inputs*, or a new ``LLMInputs`` with the rewritten
        prompt / token ids and the same ``multi_modal_data``.
    """
    mm_data = llm_inputs.get("multi_modal_data")
    has_image = mm_data is not None and "image" in mm_data
    if not has_image:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    repeat_count = (get_siglip_image_feature_size(hf_config)
                    if image_feature_size_override is None else
                    image_feature_size_override)

    expanded_prompt, expanded_token_ids = repeat_and_pad_image_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        image_token_id=image_token_id,
        repeat_count=repeat_count,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=expanded_token_ids,
        prompt=expanded_prompt,
        multi_modal_data=mm_data,
    )
# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding plus learned position embedding for SigLIP images."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Kernel size equals stride, so each patch is projected exactly once.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # SigLIP has no class token: one position per patch.
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)

        position_ids = torch.arange(self.num_positions, dtype=torch.int64)
        self.register_buffer(
            "position_ids",
            position_ids.expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """Resize the learned position embeddings to match *embeddings*.

        SigLIP variant (no class embedding) of the DINO interpolation
        routine, used so the pre-trained position table can be applied to
        images of a different resolution.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: Patch embeddings; dimension 1 is the patch count and
                the last dimension is the embedding width.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            Position embeddings with a leading dimension of 1.

        Raises:
            ValueError: If the interpolated grid does not match the expected
                patch grid.
        """
        pos_table = self.position_embedding.weight.unsqueeze(0)
        patch_count = embeddings.shape[1]
        position_count = pos_table.shape[1]
        # Nothing to resize when the grid already matches a square input.
        if patch_count == position_count and height == width:
            return pos_table

        dim = embeddings.shape[-1]
        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        grid_h = height // self.patch_size + 0.1
        grid_w = width // self.patch_size + 0.1

        side_float = math.sqrt(position_count)
        side = int(side_float)

        grid = pos_table.reshape(1, side, side, dim)
        grid = grid.permute(0, 3, 1, 2)
        grid = nn.functional.interpolate(
            grid,
            scale_factor=(grid_h / side_float, grid_w / side_float),
            mode="bicubic",
            align_corners=False,
        )
        rows_ok = int(grid_h) == grid.shape[-2]
        cols_ok = int(grid_w) == grid.shape[-1]
        if not (rows_ok and cols_ok):
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        return grid.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed *pixel_values* into a sequence of position-aware patches.

        Args:
            pixel_values: 4-D tensor whose last two dimensions are height
                and width.
            interpolate_pos_encoding: Use interpolated position embeddings
                instead of the fixed table lookup.
        """
        height, width = pixel_values.shape[2:]
        _, _, height, width = pixel_values.shape
        weight_dtype = self.patch_embedding.weight.dtype
        projected = self.patch_embedding(pixel_values.to(
            dtype=weight_dtype))  # shape = [*, width, grid, grid]
        embeddings = projected.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            positions = self.interpolate_pos_encoding(embeddings, height,
                                                      width)
        else:
            positions = self.position_embedding(self.position_ids)
        return embeddings + positions
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):
    """Eager multi-head self-attention with tensor-parallel projections.

    Attention heads are divided evenly across the tensor-parallel group:
    each rank handles ``num_heads = total_num_heads // tp_size`` heads.
    Subclasses swap the attention kernel by rebinding ``self.attn_fn``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections and validate the head configuration.

        Args:
            config: Object providing ``hidden_size``, ``num_attention_heads``
                and ``attention_dropout``.
            quant_config: Optional quantization config forwarded to both
                linear layers.

        Raises:
            ValueError: If the head count is not divisible by the
                tensor-parallel size, or ``hidden_size`` is not divisible by
                the head count.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor model parallel size"
                f" ({tp_size}).")

        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            # Every fragment must be an f-string, and the reported head count
            # is the total one, since that is the value being checked.
            raise ValueError(
                "embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.total_num_heads}).")
        # Width of each of the q / k / v slices held by this rank.
        self.qkv_size = self.num_heads * self.head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.attn_fn = self._basic_attention_forward

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        # The parallel linear layers return a pair; only the first element
        # (the projected tensor) is used here.
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.qkv_size] * 3, dim=-1)

        attn_output = self.attn_fn(
            q=query_states,
            k=key_states,
            v=value_states,
            batch_size=batch_size,
            q_len=q_len,
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output

    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
        """Reference softmax attention over this rank's heads.

        Args:
            q, k, v: Tensors viewable as
                ``(batch_size, q_len, num_heads, head_dim)``.
            batch_size: Leading dimension of the inputs.
            q_len: Sequence length of the inputs.

        Returns:
            Tensor of shape ``(batch_size, q_len, num_heads * head_dim)``.

        Raises:
            ValueError: If an intermediate tensor has an unexpected shape.
        """
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        k_v_seq_len = k.shape[-2]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if attn_weights.size() != (
                batch_size,
                self.num_heads,
                q_len,
                k_v_seq_len,
        ):
            raise ValueError(
                "Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, v)

        if attn_output.size() != (
                batch_size,
                self.num_heads,
                q_len,
                self.head_dim,
        ):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        # Only num_heads * head_dim features exist on this rank; that equals
        # embed_dim only when tp_size == 1, so reshape to the local width.
        attn_output = attn_output.reshape(batch_size, q_len, self.qkv_size)

        return attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):
    """``SiglipTPAttention`` variant that runs ``flash_attn_func``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._flash_attention_forward

    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
                                 **kwargs):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The tensor containing the
                        query, key, and value. (B, S, H, D)

        Returns
        -------
            Tensor of shape ``(batch_size, q_len, num_heads * head_dim)``.
        """

        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        # Dropout must only be active while training; the kernel applies
        # whatever probability it is handed.
        dropout_p = self.dropout if self.training else 0.0

        attn_output = flash_attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            causal=False,
        )

        # Collapse heads using the per-rank width so the reshape stays valid
        # when the heads are sharded across tensor-parallel ranks.
        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.qkv_size).contiguous()

        return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):
    """``SiglipTPAttention`` variant backed by PyTorch SDPA."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.attn_fn = self._sdpa_attention_forward

    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
        """Run ``scaled_dot_product_attention`` over this rank's heads.

        Args:
            q, k, v: Tensors viewable as
                ``(batch_size, q_len, num_heads, head_dim)``.
            batch_size: Leading dimension of the inputs.
            q_len: Sequence length of the inputs.

        Returns:
            Tensor of shape ``(batch_size, q_len, num_heads * head_dim)``.
        """
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        # SDPA applies dropout_p regardless of module mode, so it has to be
        # zeroed explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=dropout_p, is_causal=False, scale=self.scale)

        attn_output = attn_output.transpose(1, 2).contiguous()
        # Use the per-rank width; it equals embed_dim only when tp_size == 1.
        attn_output = attn_output.view(batch_size, q_len, self.qkv_size)

        return attn_output
# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):
    """``SiglipTPAttention`` variant backed by xFormers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._xformers_attention_forward

    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
        """Run ``memory_efficient_attention`` over this rank's heads.

        Dropout is always disabled (``p=0.0``) on this path.

        Args:
            q, k, v: Tensors viewable as
                ``(batch_size, q_len, num_heads, head_dim)``.
            batch_size: Leading dimension of the inputs.
            q_len: Sequence length of the inputs.

        Returns:
            Tensor of shape ``(batch_size, q_len, num_heads * head_dim)``.
        """
        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        attn_output = memory_efficient_attention(q,
                                                 k,
                                                 v,
                                                 p=0.0,
                                                 scale=self.scale)
        # Use the per-rank width; it equals embed_dim only when tp_size == 1.
        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.qkv_size).contiguous()

        return attn_output
# NOTE: Not used - kept for later when we TP the ViT
# Maps an attention implementation name to the attention class providing it.
SIGLIP_ATTENTION_CLASSES = {
    "eager": SiglipTPAttention,
    "flash_attention_2": SiglipFlashAttention2,
    "sdpa": SiglipSdpaAttention,
    "xformers": SiglipxFormersAttention,
}
class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the two parallel linear layers.

        Args:
            config: Object providing ``hidden_act``, ``hidden_size`` and
                ``intermediate_size``.
            quant_config: Optional quantization config; dropped when either
                layer width is not a multiple of 64.
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        # For quantization, we require the hidden size to be a multiple of 64
        hidden_ok = config.hidden_size % 64 == 0
        quantizable = hidden_ok and config.intermediate_size % 64 == 0
        layer_quant_config = quant_config if quantizable else None

        self.fc1 = ColumnParallelLinear(
            config.hidden_size,
            config.intermediate_size,
            quant_config=layer_quant_config,
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size,
            config.hidden_size,
            quant_config=layer_quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2 and return the result."""
        # Both linear layers return a pair; only the first item is used.
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected
class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, both residual."""

    def __init__(
        self,
        config: SiglipConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the attention, MLP and the two layer norms.

        Args:
            config: Provides ``hidden_size`` and ``layer_norm_eps`` and is
                forwarded to the attention and MLP sub-modules.
            quant_config: Optional quantization config, passed to the MLP
                only.
        """
        super().__init__()
        self.embed_dim = config.hidden_size

        # TODO(ChristopherCho): use TP'ed Attention block
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Run one encoder layer.

        Returns:
            ``(hidden_states, None)``; the second slot is always ``None``.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # Only the first element of the attention result is used.
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None
class SiglipEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``SiglipEncoderLayer`` modules."""

    def __init__(
        self,
        config: SiglipConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the layer stack.

        Args:
            config: Provides ``num_hidden_layers`` and is forwarded to every
                layer.
            quant_config: Optional quantization config forwarded to every
                layer.
        """
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([
            SiglipEncoderLayer(
                config,
                quant_config=quant_config,
            ) for _ in range(config.num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Pass *inputs_embeds* through every layer in order.

        Returns:
            The hidden states produced by the last layer (a single tensor,
            not a tuple).
        """
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Each layer returns a pair; the second item is discarded.
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the probe, attention, layer norm and MLP.

        Args:
            config: Provides ``hidden_size``, ``num_attention_heads`` and
                ``layer_norm_eps``; also forwarded to the MLP.
            quant_config: Optional quantization config for the MLP.
        """
        super().__init__()

        width = config.hidden_size
        # Single learned query vector, repeated per batch item in forward().
        self.probe = nn.Parameter(torch.randn(1, 1, width))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            width, config.num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool *hidden_state* with the probe as the only query.

        Returns:
            Index 0 along dimension 1 of the attended, MLP-refined output.
        """
        num_items = hidden_state.shape[0]
        query = self.probe.repeat(num_items, 1, 1)

        # The input serves as both key and value; attention weights are
        # discarded.
        attended = self.attention(query, hidden_state, hidden_state)[0]

        normed = self.layernorm(attended)
        pooled = attended + self.mlp(normed)

        return pooled[:, 0]
class SiglipVisionTransformer(nn.Module):
    """Embeddings, encoder stack and final layer norm of the vision tower."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Assemble the sub-modules.

        The pooling head is created only when ``config.vision_use_head`` is
        truthy; a config without that attribute is treated as ``True``.
        """
        super().__init__()
        self.config = config

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
        )
        self.post_layernorm = nn.LayerNorm(config.hidden_size,
                                           eps=config.layer_norm_eps)
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        """Encode *pixel_values* and return the normalized hidden states."""
        patch_embeddings = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        encoded = self.encoder(inputs_embeds=patch_embeddings)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        # pooled_output = self.head(last_hidden_state)

        return self.post_layernorm(encoded)
class SiglipVisionModel(nn.Module):
    """Thin wrapper exposing ``SiglipVisionTransformer`` as ``vision_model``."""

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        transformer = SiglipVisionTransformer(
            config,
            quant_config,
        )
        self.vision_model = transformer

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Delegate to ``vision_model`` with both arguments passed by name."""
        outputs = self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        return outputs
vllm/model_executor/models/utils.py
View file @
e661d594
...
@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -87,6 +87,7 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# offload parameters to CPU
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
# use pin_memory if possible, which helps cudagraph capture speed
offloaded_parameters
=
False
for
p
in
module
.
parameters
():
for
p
in
module
.
parameters
():
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
if
_CPU_OFFLOAD_BYTES
>=
_CPU_OFFLOAD_MAX_BYTES
:
# we use per-parameter offloading
# we use per-parameter offloading
...
@@ -94,35 +95,36 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
...
@@ -94,35 +95,36 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
break
break
# `torch.empty_like` does not support `pin_memory` argument
# `torch.empty_like` does not support `pin_memory` argument
cpu_data
=
torch
.
empty
(
size
=
p
.
data
.
size
(),
cpu_data
=
torch
.
empty_strided
(
size
=
p
.
data
.
size
(),
dtype
=
p
.
data
.
dtype
,
stride
=
p
.
data
.
stride
(),
layout
=
p
.
data
.
layout
,
dtype
=
p
.
data
.
dtype
,
device
=
'cpu'
,
layout
=
p
.
data
.
layout
,
pin_memory
=
pin_memory
)
device
=
'cpu'
,
pin_memory
=
pin_memory
)
cpu_data
.
copy_
(
p
.
data
)
cpu_data
.
copy_
(
p
.
data
)
p
.
data
=
cpu_data
p
.
data
=
cpu_data
_CPU_OFFLOAD_BYTES
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
_CPU_OFFLOAD_BYTES
+=
p
.
data
.
numel
()
*
p
.
data
.
element_size
()
offloaded_parameters
=
True
if
offloaded_parameters
:
original_forward
=
module
.
forward
def
forward
(
*
args
,
**
kwargs
):
module
.
forward
=
original_forward
device_state
=
{
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k
:
v
.
to
(
device
,
non_blocking
=
True
)
for
k
,
v
in
module
.
state_dict
().
items
()
}
output
=
functional_call
(
module
,
device_state
,
args
=
args
,
kwargs
=
kwargs
)
module
.
forward
=
forward
return
output
state_dict
:
Dict
[
str
,
torch
.
Tensor
]
=
module
.
state_dict
()
original_forward
=
module
.
forward
def
forward
(
*
args
,
**
kwargs
):
module
.
forward
=
original_forward
device_state
=
{
# here we blindly call `to(device)`
# if the parameter is already on the device, it will be a no-op
k
:
v
.
to
(
device
,
non_blocking
=
True
)
for
k
,
v
in
state_dict
.
items
()
}
output
=
functional_call
(
module
,
device_state
,
args
=
args
,
kwargs
=
kwargs
)
module
.
forward
=
forward
module
.
forward
=
forward
return
output
module
.
forward
=
forward
return
module
return
module
...
...
vllm/model_executor/sampling_metadata.py
View file @
e661d594
import
random
import
random
from
array
import
array
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
vllm.model_executor.layers.ops.sample
import
get_num_triton_sampler_splits
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
from
vllm.utils
import
(
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
maybe_expand_dim
)
make_tensor_with_pad
,
maybe_expand_dim
)
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
_SEED_0_REPLACEMENT
=
3403598558
_SEED_0_REPLACEMENT
=
3403598558
# Some triton sampler related code is guarded before it is ready.
_USE_TRITON_SAMPLER
=
False
@
dataclass
@
dataclass
...
@@ -117,6 +120,7 @@ class SamplingMetadata:
...
@@ -117,6 +120,7 @@ class SamplingMetadata:
query_lens
:
Optional
[
List
[
int
]],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
device
:
str
,
pin_memory
:
bool
,
pin_memory
:
bool
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
"SamplingMetadata"
:
)
->
"SamplingMetadata"
:
(
(
seq_groups
,
seq_groups
,
...
@@ -124,7 +128,7 @@ class SamplingMetadata:
...
@@ -124,7 +128,7 @@ class SamplingMetadata:
categorized_sample_indices
,
categorized_sample_indices
,
num_prompts
,
num_prompts
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
)
=
_prepare_seq_groups
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
device
)
device
,
generators
)
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
selected_token_indices
=
async_tensor_h2d
(
selected_token_indices
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
target_device
=
device
,
target_device
=
device
,
...
@@ -159,6 +163,7 @@ def _prepare_seq_groups(
...
@@ -159,6 +163,7 @@ def _prepare_seq_groups(
seq_lens
:
List
[
int
],
seq_lens
:
List
[
int
],
query_lens
:
Optional
[
List
[
int
]],
query_lens
:
Optional
[
List
[
int
]],
device
:
str
,
device
:
str
,
generators
:
Optional
[
Dict
[
str
,
torch
.
Generator
]]
=
None
,
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
)
->
Tuple
[
List
[
SequenceGroupToSample
],
List
[
int
],
Dict
[
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
SamplingType
,
List
[
Tuple
[
int
,
int
]]],
int
]:
"""Prepare sequence groups and indices for sampling.
"""Prepare sequence groups and indices for sampling.
...
@@ -169,8 +174,10 @@ def _prepare_seq_groups(
...
@@ -169,8 +174,10 @@ def _prepare_seq_groups(
Index of prompt len should match with seq_group_metadata_list.
Index of prompt len should match with seq_group_metadata_list.
query_lens: A list of query lengths. Prompt lens include the length
query_lens: A list of query lengths. Prompt lens include the length
of entire prompt tokens, and it could be shorter.
of entire prompt tokens, and it could be shorter.
device: A device to use for random number generator,
device: A device to use for random number generator
s
,
`SequenceGroupToSample.generator`.
`SequenceGroupToSample.generator`.
generators: A store of per-request random number generators used
for seeded requests.
Returns:
Returns:
seq_groups: A list of sequence group to sample.
seq_groups: A list of sequence group to sample.
...
@@ -216,8 +223,10 @@ def _prepare_seq_groups(
...
@@ -216,8 +223,10 @@ def _prepare_seq_groups(
if
seq_group_metadata
.
is_prompt
:
if
seq_group_metadata
.
is_prompt
:
if
sampling_params
.
seed
is
not
None
:
if
sampling_params
.
seed
is
not
None
:
seq_group_metadata
.
state
.
generator
=
torch
.
Generator
(
generator
=
torch
.
Generator
(
device
=
device
).
manual_seed
(
device
=
device
).
manual_seed
(
sampling_params
.
seed
)
sampling_params
.
seed
)
if
generators
is
not
None
:
generators
[
seq_group_metadata
.
request_id
]
=
generator
num_prompts
+=
1
num_prompts
+=
1
num_prefill_sample
=
len
(
seq_ids
)
num_prefill_sample
=
len
(
seq_ids
)
...
@@ -234,6 +243,9 @@ def _prepare_seq_groups(
...
@@ -234,6 +243,9 @@ def _prepare_seq_groups(
prompt_logprob_len
=
0
prompt_logprob_len
=
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
sample_len
=
len
(
seq_ids
)
if
do_sample
else
0
if
sampling_params
.
seed
is
not
None
and
generators
is
not
None
:
generator
=
generators
.
get
(
seq_group_metadata
.
request_id
)
# Update indices to select from the model output.
# Update indices to select from the model output.
"""
"""
This blocks computes selected_token_indices which is used in the
This blocks computes selected_token_indices which is used in the
...
@@ -278,9 +290,6 @@ def _prepare_seq_groups(
...
@@ -278,9 +290,6 @@ def _prepare_seq_groups(
logit_idx
+=
sample_len
logit_idx
+=
sample_len
sample_idx
+=
sample_len
sample_idx
+=
sample_len
if
sampling_params
.
seed
is
not
None
:
generator
=
seq_group_metadata
.
state
.
generator
seq_groups
.
append
(
seq_groups
.
append
(
SequenceGroupToSample
(
SequenceGroupToSample
(
seq_ids
=
seq_ids
,
seq_ids
=
seq_ids
,
...
@@ -329,8 +338,8 @@ class SamplingTensors:
...
@@ -329,8 +338,8 @@ class SamplingTensors:
user-defined seed for each sequence.
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
extra_entropy: extra entropy to use when generating seeds.
"""
"""
prompt_tokens
:
List
[
List
[
int
]
]
=
[]
prompt_tokens
:
List
[
array
]
=
[]
output_tokens
:
List
[
List
[
int
]
]
=
[]
output_tokens
:
List
[
array
]
=
[]
top_ks
:
List
[
int
]
=
[]
top_ks
:
List
[
int
]
=
[]
temperatures
:
List
[
float
]
=
[]
temperatures
:
List
[
float
]
=
[]
top_ps
:
List
[
float
]
=
[]
top_ps
:
List
[
float
]
=
[]
...
@@ -340,14 +349,16 @@ class SamplingTensors:
...
@@ -340,14 +349,16 @@ class SamplingTensors:
repetition_penalties
:
List
[
float
]
=
[]
repetition_penalties
:
List
[
float
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sampling_seeds
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
sample_indices
:
List
[
int
]
=
[]
prompt_best_of
:
List
[
int
]
=
[]
do_penalties
=
False
do_penalties
=
False
do_top_p_top_k
=
False
do_top_p_top_k
=
False
do_min_p
=
False
do_min_p
=
False
# We need one base seed per Triton slice.
if
_USE_TRITON_SAMPLER
:
seeds_to_generate
=
(
extra_seeds_to_generate
+
prompt_best_of
:
List
[
int
]
=
[]
get_num_triton_sampler_splits
(
vocab_size
))
# We need one base seed per Triton slice.
seeds_to_generate
=
(
extra_seeds_to_generate
+
get_num_triton_sampler_splits
(
vocab_size
))
assert
sampling_metadata
.
seq_groups
is
not
None
assert
sampling_metadata
.
seq_groups
is
not
None
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
...
@@ -359,9 +370,6 @@ class SamplingTensors:
...
@@ -359,9 +370,6 @@ class SamplingTensors:
r
=
sampling_params
.
repetition_penalty
r
=
sampling_params
.
repetition_penalty
top_p
=
sampling_params
.
top_p
top_p
=
sampling_params
.
top_p
min_p
=
sampling_params
.
min_p
min_p
=
sampling_params
.
min_p
seed
=
sampling_params
.
seed
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
# k should not be greater than the vocab size.
# k should not be greater than the vocab size.
top_k
=
min
(
sampling_params
.
top_k
,
vocab_size
)
top_k
=
min
(
sampling_params
.
top_k
,
vocab_size
)
...
@@ -382,8 +390,7 @@ class SamplingTensors:
...
@@ -382,8 +390,7 @@ class SamplingTensors:
do_penalties
=
True
do_penalties
=
True
is_prompt
=
seq_group
.
is_prompt
is_prompt
=
seq_group
.
is_prompt
if
(
seq_group
.
is_prompt
if
(
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
# For tokens in the prompt that we only need to get
# For tokens in the prompt that we only need to get
# their logprobs
# their logprobs
query_len
=
seq_group
.
query_len
query_len
=
seq_group
.
query_len
...
@@ -408,23 +415,27 @@ class SamplingTensors:
...
@@ -408,23 +415,27 @@ class SamplingTensors:
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
frequency_penalties
+=
[
f
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
repetition_penalties
+=
[
r
]
*
len
(
seq_ids
)
if
is_prompt
:
if
_USE_TRITON_SAMPLER
:
prompt_best_of
.
append
(
sampling_params
.
best_of
)
if
is_prompt
:
query_len
=
seq_group
.
query_len
prompt_best_of
.
append
(
sampling_params
.
best_of
)
assert
query_len
is
not
None
query_len
=
seq_group
.
query_len
assert
query_len
is
not
None
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seed
=
sampling_params
.
seed
extra_entropy
=
extra_entropy
or
()
is_greedy
=
sampling_params
.
sampling_type
==
SamplingType
.
GREEDY
seq_seeds
=
cls
.
_get_sequence_seeds
(
seed
,
for
seq_id
in
seq_ids
:
seq_data
.
get_len
(),
seq_data
=
seq_group
.
seq_data
[
seq_id
]
*
extra_entropy
,
extra_entropy
=
extra_entropy
or
()
seq_id
,
seq_seeds
=
cls
.
_get_sequence_seeds
(
seeds_to_generate
=
seeds_to_generate
,
seed
,
is_greedy
=
is_greedy
)
seq_data
.
get_len
(),
sampling_seeds
.
append
(
seq_seeds
)
*
extra_entropy
,
sample_indices
.
extend
(
seq_group
.
sample_indices
)
seq_id
,
seeds_to_generate
=
seeds_to_generate
,
is_greedy
=
is_greedy
)
sampling_seeds
.
append
(
seq_seeds
)
sample_indices
.
extend
(
seq_group
.
sample_indices
)
if
do_penalties
:
if
do_penalties
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
for
seq_group
in
sampling_metadata
.
seq_groups
:
...
@@ -432,13 +443,15 @@ class SamplingTensors:
...
@@ -432,13 +443,15 @@ class SamplingTensors:
if
(
seq_group
.
is_prompt
if
(
seq_group
.
is_prompt
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prompt_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
prompt_tokens
.
extend
(
output_tokens
.
extend
([]
for
_
in
range
(
prefill_len
))
array
(
'l'
)
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seq_data
=
seq_group
.
seq_data
[
seq_id
]
prompt_tokens
.
append
(
list
(
seq_data
.
prompt_token_ids
)
)
prompt_tokens
.
append
(
seq_data
.
prompt_token_ids
_array
)
output_tokens
.
append
(
list
(
seq_data
.
output_token_ids
)
)
output_tokens
.
append
(
seq_data
.
output_token_ids
_array
)
sampling_tensors
=
SamplingTensors
.
from_lists
(
sampling_tensors
=
SamplingTensors
.
from_lists
(
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
temperatures
,
top_ps
,
top_ks
,
min_ps
,
presence_penalties
,
...
@@ -454,9 +467,9 @@ class SamplingTensors:
...
@@ -454,9 +467,9 @@ class SamplingTensors:
frequency_penalties
:
List
[
float
],
frequency_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
repetition_penalties
:
List
[
float
],
sampling_seeds
:
List
[
int
],
sample_indices
:
List
[
int
],
sampling_seeds
:
List
[
int
],
sample_indices
:
List
[
int
],
prompt_tokens
:
List
[
List
[
int
]
],
prompt_tokens
:
List
[
array
],
output_tokens
:
List
[
array
],
output_tokens
:
List
[
List
[
int
]],
vocab_siz
e
:
int
,
vocab_size
:
int
,
extra_seeds_to_generat
e
:
int
,
extra_seeds_to_generate
:
int
,
device
:
torch
.
device
,
device
:
torch
.
device
,
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
dtype
:
torch
.
dtype
)
->
"SamplingTensors"
:
# Note that the performance will be very bad without
# Note that the performance will be very bad without
# pinned memory.
# pinned memory.
...
@@ -540,7 +553,7 @@ class SamplingTensors:
...
@@ -540,7 +553,7 @@ class SamplingTensors:
device
=
"cpu"
,
device
=
"cpu"
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
).
T
.
contiguous
()
).
t
()
.
contiguous
()
# Because the memory is pinned, we can do non-blocking
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# transfer to device.
...
...
vllm/multimodal/__init__.py
View file @
e661d594
from
.base
import
(
BatchedTensors
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
from
.base
import
(
BatchedTensorInputs
,
BatchedTensors
,
MultiModalDataBuiltins
,
MultiModalInputs
,
MultiModalPlugin
)
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
NestedTensors
)
from
.registry
import
MultiModalRegistry
from
.registry
import
MultiModalRegistry
MULTIMODAL_REGISTRY
=
MultiModalRegistry
()
MULTIMODAL_REGISTRY
=
MultiModalRegistry
()
...
@@ -12,11 +13,13 @@ See also:
...
@@ -12,11 +13,13 @@ See also:
"""
"""
__all__
=
[
__all__
=
[
"BatchedTensorInputs"
,
"BatchedTensors"
,
"BatchedTensors"
,
"MultiModalDataBuiltins"
,
"MultiModalDataBuiltins"
,
"MultiModalDataDict"
,
"MultiModalDataDict"
,
"MultiModalInputs"
,
"MultiModalInputs"
,
"MultiModalPlugin"
,
"MultiModalPlugin"
,
"NestedTensors"
,
"MULTIMODAL_REGISTRY"
,
"MULTIMODAL_REGISTRY"
,
"MultiModalRegistry"
,
"MultiModalRegistry"
,
]
]
vllm/multimodal/base.py
View file @
e661d594
import
sys
import
sys
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
typing
import
(
Any
,
Callable
,
Dict
,
List
,
Optional
,
Type
,
TypedDict
,
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
TypeVar
,
Union
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
import
torch
import
torch
import
torch.types
import
torch.types
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
typing_extensions
import
TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
json_map_leaves
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
Batch
edTensors
=
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]
]
Nest
edTensors
=
Union
[
GenericSequence
[
torch
.
Tensor
]
,
torch
.
Tensor
]
"""
"""
If each input tensor in the batch has the same size, this is a single batched
Use a list instead of a tensor if the dimensions of each element do not match.
tensor; otherwise, this is a list of tensors with one element per batch.
Currently only supports up to singly nested list of tensors.
"""
BatchedTensors
:
TypeAlias
=
JSONTree
[
torch
.
Tensor
]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs
:
TypeAlias
=
Dict
[
str
,
JSONTree
[
torch
.
Tensor
]]
"""
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
"""
if
sys
.
version_info
<
(
3
,
9
):
if
sys
.
version_info
<
(
3
,
9
):
...
@@ -27,7 +42,7 @@ if sys.version_info < (3, 9):
...
@@ -27,7 +42,7 @@ if sys.version_info < (3, 9):
pass
pass
else
:
else
:
class
_MultiModalInputsBase
(
UserDict
[
str
,
torch
.
Tensor
]):
class
_MultiModalInputsBase
(
UserDict
[
str
,
Nested
Tensor
s
]):
pass
pass
...
@@ -38,33 +53,44 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -38,33 +53,44 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
"""
@
staticmethod
@
staticmethod
def
try_concat
(
def
_try_concat
(
tensors
:
List
[
NestedTensors
])
->
BatchedTensors
:
tensors
:
List
[
torch
.
Tensor
],
"""
*
,
If each input tensor in the batch has the same shape, return a single
device
:
torch
.
types
.
Device
,
batched tensor; otherwise, return a list of :class:`NestedTensors` with
)
->
BatchedTensors
:
one element per item in the batch.
unbatched_shape
=
tensors
[
0
].
shape
[
1
:]
"""
# may be list rather than tensors
if
isinstance
(
tensors
[
0
],
list
):
return
[[
t
for
t
in
tensor
[
0
]]
for
tensor
in
cast
(
List
[
List
[
torch
.
Tensor
]],
tensors
)]
for
tensor
in
tensors
:
tensors_
=
cast
(
List
[
torch
.
Tensor
],
tensors
)
unbatched_shape
=
tensors_
[
0
].
shape
[
1
:]
for
tensor
in
tensors_
:
if
tensor
.
shape
[
1
:]
!=
unbatched_shape
:
if
tensor
.
shape
[
1
:]
!=
unbatched_shape
:
return
[
return
[
tensor
.
squeeze
(
0
)
for
tensor
in
tensors_
]
tensor
.
squeeze
(
0
).
to
(
device
=
device
)
for
tensor
in
tensors
]
return
torch
.
cat
(
tensors
,
dim
=
0
)
.
to
(
device
=
device
)
return
torch
.
cat
(
tensors
_
,
dim
=
0
)
@
staticmethod
@
staticmethod
def
batch
(
def
batch
(
inputs_list
:
List
[
"MultiModalInputs"
])
->
BatchedTensorInputs
:
inputs_list
:
List
[
"MultiModalInputs"
],
"""
device
:
torch
.
types
.
Device
,
Batch multiple inputs together into a dictionary.
)
->
Dict
[
str
,
BatchedTensors
]:
"""Batch multiple inputs together into a dictionary."""
The resulting dictionary has the same keys as the inputs.
If the corresponding value from each input is a tensor and they all
share the same shape, the output value is a single batched tensor;
otherwise, the output value is a list containing the original value
from each input.
"""
if
len
(
inputs_list
)
==
0
:
if
len
(
inputs_list
)
==
0
:
return
{}
return
{}
keys
=
inputs_list
[
0
].
keys
()
keys
=
inputs_list
[
0
].
keys
()
item_lists
:
Dict
[
str
,
List
[
torch
.
Tensor
]]
=
defaultdict
(
list
)
item_lists
:
Dict
[
str
,
List
[
Nested
Tensor
s
]]
=
defaultdict
(
list
)
for
inputs
in
inputs_list
:
for
inputs
in
inputs_list
:
if
inputs
.
keys
()
!=
keys
:
if
inputs
.
keys
()
!=
keys
:
...
@@ -75,10 +101,19 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -75,10 +101,19 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists
[
k
].
append
(
v
)
item_lists
[
k
].
append
(
v
)
return
{
return
{
k
:
MultiModalInputs
.
try_concat
(
item_list
,
device
=
device
)
k
:
MultiModalInputs
.
_
try_concat
(
item_list
)
for
k
,
item_list
in
item_lists
.
items
()
for
k
,
item_list
in
item_lists
.
items
()
}
}
@
staticmethod
def
as_kwargs
(
batched_inputs
:
BatchedTensorInputs
,
*
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensorInputs
:
return
json_map_leaves
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
True
),
batched_inputs
)
class
MultiModalDataBuiltins
(
TypedDict
,
total
=
False
):
class
MultiModalDataBuiltins
(
TypedDict
,
total
=
False
):
"""Modality types that are predefined by vLLM."""
"""Modality types that are predefined by vLLM."""
...
...
vllm/multimodal/image.py
View file @
e661d594
...
@@ -113,7 +113,7 @@ class ImagePlugin(MultiModalPlugin):
...
@@ -113,7 +113,7 @@ class ImagePlugin(MultiModalPlugin):
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
def
_default_input_mapper
(
self
,
ctx
:
InputContext
,
data
:
object
)
->
MultiModalInputs
:
data
:
object
)
->
MultiModalInputs
:
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
if
isinstance
(
data
,
Image
.
Image
):
if
isinstance
(
data
,
(
Image
.
Image
,
list
)
):
image_processor
=
self
.
_get_hf_image_processor
(
model_config
)
image_processor
=
self
.
_get_hf_image_processor
(
model_config
)
if
image_processor
is
None
:
if
image_processor
is
None
:
raise
RuntimeError
(
"No HuggingFace processor is available "
raise
RuntimeError
(
"No HuggingFace processor is available "
...
...
vllm/outputs.py
View file @
e661d594
...
@@ -29,7 +29,7 @@ class CompletionOutput:
...
@@ -29,7 +29,7 @@ class CompletionOutput:
index
:
int
index
:
int
text
:
str
text
:
str
token_ids
:
Tuple
[
int
,
...]
token_ids
:
Tuple
[
int
,
...]
cumulative_logprob
:
float
cumulative_logprob
:
Optional
[
float
]
logprobs
:
Optional
[
SampleLogprobs
]
logprobs
:
Optional
[
SampleLogprobs
]
finish_reason
:
Optional
[
str
]
=
None
finish_reason
:
Optional
[
str
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
...
@@ -124,13 +124,14 @@ class RequestOutput:
...
@@ -124,13 +124,14 @@ class RequestOutput:
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
outputs
=
[
outputs
=
[
CompletionOutput
(
seqs
.
index
(
seq
),
CompletionOutput
(
seq
.
get_output_text_to_return
(
text_buffer_length
),
seqs
.
index
(
seq
),
seq
.
get_output_token_ids
(),
seq
.
get_output_text_to_return
(
text_buffer_length
),
seq
.
get_cumulative_logprob
(),
seq
.
get_output_token_ids
(),
seq
.
output_logprobs
if
include_logprobs
else
None
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
output_logprobs
if
include_logprobs
else
None
,
seq
.
stop_reason
)
for
seq
in
top_n_seqs
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
)
for
seq
in
top_n_seqs
]
]
# Every sequence in the sequence group should have the same prompt.
# Every sequence in the sequence group should have the same prompt.
...
...
vllm/sampling_params.py
View file @
e661d594
...
@@ -92,11 +92,12 @@ class SamplingParams:
...
@@ -92,11 +92,12 @@ class SamplingParams:
min_tokens: Minimum number of tokens to generate per output sequence
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
When set to None, no probability is returned. If set to a non-None
result includes the log probabilities on the `logprobs` most likely
value, the result includes the log probabilities of the specified
tokens, as well the chosen tokens. The API will always return the
number of most likely tokens, as well as the chosen tokens.
log probability of the sampled token, so there may be up to
Note that the implementation follows the OpenAI API: The API will
`logprobs+1` elements in the response.
always return the log probability of the sampled token, so there
may be up to `logprobs+1` elements in the response.
prompt_logprobs: Number of log probabilities to return per prompt token.
prompt_logprobs: Number of log probabilities to return per prompt token.
detokenize: Whether to detokenize the output. Defaults to True.
detokenize: Whether to detokenize the output. Defaults to True.
skip_special_tokens: Whether to skip special tokens in the output.
skip_special_tokens: Whether to skip special tokens in the output.
...
@@ -168,8 +169,8 @@ class SamplingParams:
...
@@ -168,8 +169,8 @@ class SamplingParams:
self
.
ignore_eos
=
ignore_eos
self
.
ignore_eos
=
ignore_eos
self
.
max_tokens
=
max_tokens
self
.
max_tokens
=
max_tokens
self
.
min_tokens
=
min_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
logprobs
self
.
logprobs
=
1
if
logprobs
is
True
else
logprobs
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
1
if
prompt_logprobs
is
True
else
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
# not support returning only a list of token IDs.
...
...
vllm/scalar_type.py
0 → 100644
View file @
e661d594
from
._core_ext
import
NanRepr
,
ScalarType
# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
# flags:
# - no-flags: means it follows IEEE 754 conventions
# - f: means finite values only (no infinities)
# - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
# `[u]int<size_bits>[b<bias>]`
# - if bias is not present it means its zero
class
scalar_types
:
int4
=
ScalarType
.
int_
(
4
,
None
)
uint4
=
ScalarType
.
uint
(
4
,
None
)
int8
=
ScalarType
.
int_
(
8
,
None
)
uint8
=
ScalarType
.
uint
(
8
,
None
)
float8_e4m3fn
=
ScalarType
.
float_
(
4
,
3
,
True
,
NanRepr
.
EXTD_RANGE_MAX_MIN
.
value
)
float8_e5m2
=
ScalarType
.
float_IEEE754
(
5
,
2
)
float16_e8m7
=
ScalarType
.
float_IEEE754
(
8
,
7
)
float16_e5m10
=
ScalarType
.
float_IEEE754
(
5
,
10
)
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
float6_e3m2f
=
ScalarType
.
float_
(
3
,
2
,
True
,
NanRepr
.
NONE
.
value
)
# "gptq" types
uint4b8
=
ScalarType
.
uint
(
4
,
8
)
uint8b128
=
ScalarType
.
uint
(
8
,
128
)
# colloquial names
bfloat16
=
float16_e8m7
float16
=
float16_e5m10
vllm/scripts.py
View file @
e661d594
# The CLI entrypoint to vLLM.
# The CLI entrypoint to vLLM.
import
argparse
import
argparse
import
asyncio
import
os
import
os
import
signal
import
signal
import
sys
import
sys
from
typing
import
Optional
from
typing
import
List
,
Optional
from
openai
import
OpenAI
from
openai
import
OpenAI
from
openai.types.chat
import
ChatCompletionMessageParam
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.api_server
import
run_server
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
from
vllm.entrypoints.openai.cli_args
import
make_arg_parser
...
@@ -25,7 +27,7 @@ def serve(args: argparse.Namespace) -> None:
...
@@ -25,7 +27,7 @@ def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
# EngineArgs expects the model name to be passed as --model.
args
.
model
=
args
.
model_tag
args
.
model
=
args
.
model_tag
run_server
(
args
)
asyncio
.
run
(
run_server
(
args
)
)
def
interactive_cli
(
args
:
argparse
.
Namespace
)
->
None
:
def
interactive_cli
(
args
:
argparse
.
Namespace
)
->
None
:
...
@@ -62,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
...
@@ -62,15 +64,14 @@ def complete(model_name: str, client: OpenAI) -> None:
def
chat
(
system_prompt
:
Optional
[
str
],
model_name
:
str
,
def
chat
(
system_prompt
:
Optional
[
str
],
model_name
:
str
,
client
:
OpenAI
)
->
None
:
client
:
OpenAI
)
->
None
:
conversation
=
[]
conversation
:
List
[
ChatCompletionMessageParam
]
=
[]
if
system_prompt
is
not
None
:
if
system_prompt
is
not
None
:
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
conversation
.
append
({
"role"
:
"system"
,
"content"
:
system_prompt
})
print
(
"Please enter a message for the chat model:"
)
print
(
"Please enter a message for the chat model:"
)
while
True
:
while
True
:
input_message
=
input
(
"> "
)
input_message
=
input
(
"> "
)
message
=
{
"role"
:
"user"
,
"content"
:
input_message
}
conversation
.
append
({
"role"
:
"user"
,
"content"
:
input_message
})
conversation
.
append
(
message
)
chat_completion
=
client
.
chat
.
completions
.
create
(
model
=
model_name
,
chat_completion
=
client
.
chat
.
completions
.
create
(
model
=
model_name
,
messages
=
conversation
)
messages
=
conversation
)
...
@@ -78,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
...
@@ -78,7 +79,7 @@ def chat(system_prompt: Optional[str], model_name: str,
response_message
=
chat_completion
.
choices
[
0
].
message
response_message
=
chat_completion
.
choices
[
0
].
message
output
=
response_message
.
content
output
=
response_message
.
content
conversation
.
append
(
response_message
)
conversation
.
append
(
response_message
)
# type: ignore
print
(
output
)
print
(
output
)
...
...
vllm/sequence.py
View file @
e661d594
...
@@ -3,6 +3,7 @@ import copy
...
@@ -3,6 +3,7 @@ import copy
import
enum
import
enum
import
math
import
math
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
...
@@ -119,10 +120,10 @@ class SequenceData:
...
@@ -119,10 +120,10 @@ class SequenceData:
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
output_token_ids
:
Optional
[
List
[
int
]]
=
None
,
output_token_ids
:
Optional
[
List
[
int
]]
=
None
,
)
->
None
:
)
->
None
:
self
.
_prompt_token_ids
:
List
[
int
]
=
list
(
prompt_token_ids
)
self
.
_prompt_token_ids
=
array
(
'l'
,
prompt_token_ids
)
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
prompt_token_ids
)
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
prompt_token_ids
)
self
.
_output_token_ids
:
List
[
int
]
=
(
self
.
_output_token_ids
=
array
(
list
(
output_token_ids
)
if
output_token_ids
is
not
None
else
[])
'l'
,
output_token_ids
if
output_token_ids
is
not
None
else
[])
self
.
cumulative_logprob
=
0.0
self
.
cumulative_logprob
=
0.0
# The number of tokens that are computed (that run against the model).
# The number of tokens that are computed (that run against the model).
...
@@ -132,8 +133,8 @@ class SequenceData:
...
@@ -132,8 +133,8 @@ class SequenceData:
self
.
_update_cached_all_tokens
()
self
.
_update_cached_all_tokens
()
def
_update_cached_all_tokens
(
self
):
def
_update_cached_all_tokens
(
self
):
self
.
_cached_all_token_ids
:
List
[
int
]
=
(
self
.
_prompt_token_ids
+
self
.
_cached_all_token_ids
:
List
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
self
.
_output_token_ids
)
@
property
@
property
def
prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
...
@@ -141,19 +142,27 @@ class SequenceData:
...
@@ -141,19 +142,27 @@ class SequenceData:
@
prompt_token_ids
.
setter
@
prompt_token_ids
.
setter
def
prompt_token_ids
(
self
,
new_prompt_token_ids
)
->
None
:
def
prompt_token_ids
(
self
,
new_prompt_token_ids
)
->
None
:
self
.
_prompt_token_ids
=
list
(
new_prompt_token_ids
)
self
.
_prompt_token_ids
=
array
(
'l'
,
new_prompt_token_ids
)
self
.
_prompt_token_ids_tuple
=
tuple
(
new_prompt_token_ids
)
self
.
_prompt_token_ids_tuple
=
tuple
(
new_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
self
.
_update_cached_all_tokens
()
@
property
def
prompt_token_ids_array
(
self
)
->
array
:
return
self
.
_prompt_token_ids
@
property
@
property
def
output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
tuple
(
self
.
_output_token_ids
)
return
tuple
(
self
.
_output_token_ids
)
@
output_token_ids
.
setter
@
output_token_ids
.
setter
def
output_token_ids
(
self
,
new_output_token_ids
)
->
None
:
def
output_token_ids
(
self
,
new_output_token_ids
)
->
None
:
self
.
_output_token_ids
=
list
(
new_output_token_ids
)
self
.
_output_token_ids
=
array
(
'l'
,
new_output_token_ids
)
self
.
_update_cached_all_tokens
()
self
.
_update_cached_all_tokens
()
@
property
def
output_token_ids_array
(
self
)
->
array
:
return
self
.
_output_token_ids
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
_output_token_ids
.
append
(
token_id
)
self
.
_output_token_ids
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
...
@@ -402,14 +411,6 @@ class Sequence:
...
@@ -402,14 +411,6 @@ class Sequence:
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
@
dataclass
class
SequenceGroupState
:
"""Mutable state tied to a specific sequence group"""
# torch.Generator used in seeded sampling
generator
:
Optional
=
None
# type: ignore
class
SequenceGroup
:
class
SequenceGroup
:
"""A group of sequences that are generated from the same prompt.
"""A group of sequences that are generated from the same prompt.
...
@@ -443,6 +444,7 @@ class SequenceGroup:
...
@@ -443,6 +444,7 @@ class SequenceGroup:
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
seqs
=
seqs
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
seqs_dict
=
{
seq
.
seq_id
:
seq
for
seq
in
seqs
}
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
self
.
metrics
=
RequestMetrics
(
arrival_time
=
arrival_time
,
self
.
metrics
=
RequestMetrics
(
arrival_time
=
arrival_time
,
...
@@ -452,31 +454,29 @@ class SequenceGroup:
...
@@ -452,31 +454,29 @@ class SequenceGroup:
time_in_queue
=
None
)
time_in_queue
=
None
)
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
prompt_logprobs
:
Optional
[
PromptLogprobs
]
=
None
self
.
state
=
SequenceGroupState
()
self
.
embeddings
=
embeddings
self
.
embeddings
=
embeddings
self
.
pooling_params
=
pooling_params
self
.
pooling_params
=
pooling_params
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
encoder_seq
=
encoder_seq
self
.
encoder_seq
=
encoder_seq
self
.
trace_headers
=
trace_headers
self
.
trace_headers
=
trace_headers
self
.
_first_seq
=
next
(
iter
(
self
.
seqs_dict
.
values
()))
@
property
@
property
def
prompt
(
self
)
->
Optional
[
str
]:
def
prompt
(
self
)
->
Optional
[
str
]:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
# We use the prompt of an arbitrary sequence.
return
self
.
_first_seq
.
prompt
return
self
.
seqs
[
0
]
.
prompt
@
property
@
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
def
prompt_token_ids
(
self
)
->
List
[
int
]:
# All sequences in the group should have the same prompt.
# All sequences in the group should have the same prompt.
# We use the prompt of an arbitrary sequence.
# We use the prompt of an arbitrary sequence.
return
self
.
_first_seq
.
prompt_token_ids
return
self
.
seqs
[
0
]
.
prompt_token_ids
@
property
@
property
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
# All sequences in the group should have the same multi-modal data.
# All sequences in the group should have the same multi-modal data.
# We use the multi-modal data of an arbitrary sequence.
# We use the multi-modal data of an arbitrary sequence.
return
self
.
_first_seq
.
multi_modal_data
return
self
.
seqs
[
0
]
.
multi_modal_data
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
...
@@ -512,7 +512,7 @@ class SequenceGroup:
...
@@ -512,7 +512,7 @@ class SequenceGroup:
# in TPOT, rather than recalculating TTFT (since from the )
# in TPOT, rather than recalculating TTFT (since from the )
# POV of the user, there is simply a long generation delay.
# POV of the user, there is simply a long generation delay.
if
(
self
.
metrics
.
first_token_time
is
None
if
(
self
.
metrics
.
first_token_time
is
None
and
self
.
get_
seqs
()
[
0
].
get_output_len
()
==
1
):
and
self
.
seqs
[
0
].
get_output_len
()
==
1
):
self
.
metrics
.
first_token_time
=
time
self
.
metrics
.
first_token_time
=
time
def
maybe_set_first_scheduled_time
(
self
,
time
:
float
)
->
None
:
def
maybe_set_first_scheduled_time
(
self
,
time
:
float
)
->
None
:
...
@@ -548,9 +548,9 @@ class SequenceGroup:
...
@@ -548,9 +548,9 @@ class SequenceGroup:
self
,
self
,
status
:
Optional
[
SequenceStatus
]
=
None
,
status
:
Optional
[
SequenceStatus
]
=
None
,
)
->
List
[
Sequence
]:
)
->
List
[
Sequence
]:
return
list
(
self
.
seqs_dict
.
values
())
if
status
is
None
else
[
if
status
is
None
:
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
seq
.
status
==
statu
s
return
self
.
seq
s
]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
status
==
status
]
def
is_encoder_decoder
(
self
)
->
bool
:
def
is_encoder_decoder
(
self
)
->
bool
:
return
self
.
encoder_seq
is
not
None
return
self
.
encoder_seq
is
not
None
...
@@ -559,22 +559,20 @@ class SequenceGroup:
...
@@ -559,22 +559,20 @@ class SequenceGroup:
return
self
.
encoder_seq
return
self
.
encoder_seq
def
get_unfinished_seqs
(
self
)
->
List
[
Sequence
]:
def
get_unfinished_seqs
(
self
)
->
List
[
Sequence
]:
return
[
return
[
seq
for
seq
in
self
.
seqs
if
not
seq
.
is_finished
()]
seq
for
seq
in
self
.
seqs_dict
.
values
()
if
not
seq
.
is_finished
()
]
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
def
get_finished_seqs
(
self
)
->
List
[
Sequence
]:
return
[
seq
for
seq
in
self
.
seqs
_dict
.
values
()
if
seq
.
is_finished
()]
return
[
seq
for
seq
in
self
.
seqs
if
seq
.
is_finished
()]
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
def
update_num_computed_tokens
(
self
,
num_new_computed_tokens
:
int
):
"""Update number of tokens computed so far."""
"""Update number of tokens computed so far."""
for
seq
in
self
.
seqs
_dict
.
values
()
:
for
seq
in
self
.
seqs
:
if
not
seq
.
is_finished
():
if
not
seq
.
is_finished
():
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
seq
.
data
.
update_num_computed_tokens
(
num_new_computed_tokens
)
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
num_uncomputed_tokens
=
0
num_uncomputed_tokens
=
0
for
seq
in
self
.
get_
seqs
()
:
for
seq
in
self
.
seqs
:
if
not
seq
.
is_finished
():
if
not
seq
.
is_finished
():
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
num_uncomputed_tokens
+=
seq
.
data
.
get_num_uncomputed_tokens
()
return
num_uncomputed_tokens
return
num_uncomputed_tokens
...
@@ -583,7 +581,7 @@ class SequenceGroup:
...
@@ -583,7 +581,7 @@ class SequenceGroup:
# Optimization. We don't need to call get_seqs if we don't need to
# Optimization. We don't need to call get_seqs if we don't need to
# filter by states.
# filter by states.
if
status
is
None
:
if
status
is
None
:
return
len
(
self
.
seqs
_dict
)
return
len
(
self
.
seqs
)
return
len
(
self
.
get_seqs
(
status
))
return
len
(
self
.
get_seqs
(
status
))
...
@@ -602,23 +600,25 @@ class SequenceGroup:
...
@@ -602,23 +600,25 @@ class SequenceGroup:
if
seq
.
seq_id
in
self
.
seqs_dict
:
if
seq
.
seq_id
in
self
.
seqs_dict
:
raise
ValueError
(
f
"Sequence
{
seq
.
seq_id
}
already exists."
)
raise
ValueError
(
f
"Sequence
{
seq
.
seq_id
}
already exists."
)
self
.
seqs_dict
[
seq
.
seq_id
]
=
seq
self
.
seqs_dict
[
seq
.
seq_id
]
=
seq
self
.
seqs
.
append
(
seq
)
def
remove
(
self
,
seq_id
:
int
)
->
None
:
def
remove
(
self
,
seq_id
:
int
)
->
None
:
if
seq_id
not
in
self
.
seqs_dict
:
seq
=
self
.
seqs_dict
.
pop
(
seq_id
,
None
)
if
seq
is
None
:
raise
ValueError
(
f
"Sequence
{
seq_id
}
not found."
)
raise
ValueError
(
f
"Sequence
{
seq_id
}
not found."
)
del
self
.
seqs
_dict
[
seq_id
]
self
.
seqs
.
remove
(
seq
)
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
get_
seqs
()
)
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
# Every sequence should be in the same stage.
# Every sequence should be in the same stage.
return
self
.
get_
seqs
()
[
0
].
is_prefill
()
return
self
.
seqs
[
0
].
is_prefill
()
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
return
(
f
"SequenceGroup(request_id=
{
self
.
request_id
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
f
"sampling_params=
{
self
.
sampling_params
}
, "
f
"num_seqs=
{
len
(
self
.
seqs
_dict
)
}
)"
)
f
"num_seqs=
{
len
(
self
.
seqs
)
}
)"
)
class
SequenceGroupMetadata
:
class
SequenceGroupMetadata
:
...
@@ -639,7 +639,6 @@ class SequenceGroupMetadata:
...
@@ -639,7 +639,6 @@ class SequenceGroupMetadata:
lora_request: LoRA request.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
multi_modal_data: Multi modal data.
encoder_seq_data: Optional sequence data for encoder prompt
encoder_seq_data: Optional sequence data for encoder prompt
(SequenceGroup.encoder_seq). Should be None
(SequenceGroup.encoder_seq). Should be None
...
@@ -665,7 +664,6 @@ class SequenceGroupMetadata:
...
@@ -665,7 +664,6 @@ class SequenceGroupMetadata:
token_chunk_size
:
Optional
[
int
]
=
None
,
token_chunk_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
...
@@ -681,7 +679,6 @@ class SequenceGroupMetadata:
...
@@ -681,7 +679,6 @@ class SequenceGroupMetadata:
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
computed_block_nums
=
computed_block_nums
self
.
computed_block_nums
=
computed_block_nums
self
.
multi_modal_data
=
multi_modal_data
self
.
multi_modal_data
=
multi_modal_data
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
self
.
encoder_seq_data
=
encoder_seq_data
self
.
encoder_seq_data
=
encoder_seq_data
self
.
cross_block_table
=
cross_block_table
self
.
cross_block_table
=
cross_block_table
self
.
_token_chunk_size
=
token_chunk_size
self
.
_token_chunk_size
=
token_chunk_size
...
...
vllm/spec_decode/batch_expansion.py
View file @
e661d594
...
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
...
@@ -3,9 +3,9 @@ from typing import Iterator, List, Tuple
import
torch
import
torch
from
vllm
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceGroupState
,
SequenceGroupMetadata
,
get_all_seq_ids
)
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
@@ -16,6 +16,8 @@ SeqId = int
...
@@ -16,6 +16,8 @@ SeqId = int
TargetSeqId
=
int
TargetSeqId
=
int
TokenId
=
int
TokenId
=
int
DEFAULT_SIMPLE_SAMPLING_PARAMS
=
SamplingParams
()
class
BatchExpansionTop1Scorer
(
SpeculativeScorer
):
class
BatchExpansionTop1Scorer
(
SpeculativeScorer
):
"""Implements a speculative scorer that uses batch expansion to get
"""Implements a speculative scorer that uses batch expansion to get
...
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -247,24 +249,39 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_ids_to_score
=
self
.
_get_token_ids_to_score
(
token_ids_to_score
=
self
.
_get_token_ids_to_score
(
proposal_token_ids
[
batch_index
])
proposal_token_ids
[
batch_index
])
# Use simpler sampling parameters apart from for final token
# (in particular don't do seeded sampling) since those sampled tokens
# aren't used.
# We don't replace the sampling_params in the greedy case because
# this also controls whether the probs get modified in the sampler
# (see use of _modify_greedy_probs_inplace there).
sampling_params
=
input_seq_group_metadata
.
sampling_params
non_bonus_sampling_params
=
DEFAULT_SIMPLE_SAMPLING_PARAMS
\
if
sampling_params
.
temperature
else
sampling_params
target_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
target_seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
=
[]
for
token_ids
in
token_ids_to_score
:
last_index
=
len
(
token_ids_to_score
)
-
1
for
i
,
token_ids
in
enumerate
(
token_ids_to_score
):
target_sampling_params
=
sampling_params
if
i
==
last_index
\
else
non_bonus_sampling_params
target_seq_group_metadata_list
.
append
(
target_seq_group_metadata_list
.
append
(
self
.
_create_single_target_seq_group_metadata
(
self
.
_create_single_target_seq_group_metadata
(
input_seq_group_metadata
,
input_seq_group_metadata
,
input_seq_id
,
input_seq_id
,
next
(
target_seq_ids_iter
),
next
(
target_seq_ids_iter
),
token_ids
,
token_ids
,
sampling_params
=
target_sampling_params
,
))
))
return
target_seq_group_metadata_list
return
target_seq_group_metadata_list
@
staticmethod
def
_create_single_target_seq_group_metadata
(
def
_create_single_target_seq_group_metadata
(
self
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_group_metadata
:
SequenceGroupMetadata
,
seq_id
:
SeqId
,
seq_id
:
SeqId
,
target_seq_id
:
TargetSeqId
,
target_seq_id
:
TargetSeqId
,
token_ids
:
List
[
TokenId
],
token_ids
:
List
[
TokenId
],
sampling_params
:
SamplingParams
,
)
->
SequenceGroupMetadata
:
)
->
SequenceGroupMetadata
:
"""Create a single target SequenceGroupMetadata.
"""Create a single target SequenceGroupMetadata.
...
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -293,26 +310,16 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
for
data
in
new_seq_data_dict
.
values
():
for
data
in
new_seq_data_dict
.
values
():
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
data
.
update_num_computed_tokens
(
data
.
get_len
()
-
1
)
if
(
seq_group_metadata
.
state
is
not
None
and
seq_group_metadata
.
state
.
generator
is
not
None
):
generator
=
torch
.
Generator
(
device
=
seq_group_metadata
.
state
.
generator
.
device
)
generator
.
set_state
(
seq_group_metadata
.
state
.
generator
.
get_state
())
state
=
SequenceGroupState
(
generator
=
generator
)
else
:
state
=
None
return
SequenceGroupMetadata
(
return
SequenceGroupMetadata
(
request_id
=
seq_group_metadata
.
request_id
,
request_id
=
seq_group_metadata
.
request_id
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
is_prompt
=
seq_group_metadata
.
is_prompt
,
seq_data
=
new_seq_data_dict
,
seq_data
=
new_seq_data_dict
,
sampling_params
=
seq_group_metadata
.
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
block_tables
=
{
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
target_seq_id
:
seq_group_metadata
.
block_tables
[
seq_id
],
},
},
lora_request
=
None
,
lora_request
=
None
,
token_chunk_size
=
1
,
token_chunk_size
=
1
,
state
=
state
,
)
)
def
_split_scoring_output
(
def
_split_scoring_output
(
...
...
vllm/spec_decode/draft_model_runner.py
View file @
e661d594
...
@@ -11,10 +11,22 @@ except ModuleNotFoundError:
...
@@ -11,10 +11,22 @@ except ModuleNotFoundError:
from
vllm.attention.backends.rocm_flash_attn
import
(
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
try
:
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
from
flashinfer.prefill
import
BatchPrefillWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
256
*
1024
*
1024
except
ImportError
:
BatchDecodeWithPagedKVCacheWrapper
=
None
CUDAGraphBatchDecodeWithPagedKVCacheWrapper
=
None
BatchPrefillWithPagedKVCacheWrapper
=
None
FLASHINFER_WORKSPACE_BUFFER_SIZE
=
0
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalInputs
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SamplerOutput
)
SamplerOutput
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
...
@@ -78,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -78,6 +90,11 @@ class TP1DraftModelRunner(ModelRunner):
return_hidden_states
=
return_hidden_states
,
return_hidden_states
=
return_hidden_states
,
)
)
self
.
flashinfer_decode_workspace_buffer
=
None
self
.
flashinfer_decode_wrapper
=
None
self
.
flashinfer_prefill_workspace_buffer
=
None
self
.
flashinfer_prefill_wrapper
=
None
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
def
_update_flash_attn_metadata
(
self
,
attn_metadata
,
num_seqs
,
num_queries
):
num_queries
):
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
...
@@ -285,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -285,6 +302,37 @@ class TP1DraftModelRunner(ModelRunner):
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_requests
,
model_input
.
prompt_adapter_mapping
)
model_input
.
prompt_adapter_mapping
)
if
self
.
attn_backend
.
get_name
()
==
"flashinfer"
:
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
input_tokens
is
not
None
if
self
.
flashinfer_decode_workspace_buffer
is
None
:
self
.
flashinfer_decode_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_decode_wrapper
=
\
BatchDecodeWithPagedKVCacheWrapper
(
self
.
flashinfer_decode_workspace_buffer
,
"NHD"
)
self
.
flashinfer_prefill_workspace_buffer
=
torch
.
empty
(
FLASHINFER_WORKSPACE_BUFFER_SIZE
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
self
.
flashinfer_prefill_wrapper
=
\
BatchPrefillWithPagedKVCacheWrapper
(
self
.
flashinfer_prefill_workspace_buffer
,
"NHD"
)
model_input
.
attn_metadata
.
prefill_wrapper
=
\
self
.
flashinfer_prefill_wrapper
if
model_input
.
attn_metadata
.
use_cuda_graph
:
batch_size
=
model_input
.
input_tokens
.
shape
[
0
]
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
graph_runners
[
model_input
.
virtual_engine
][
batch_size
].
flashinfer_decode_wrapper
else
:
model_input
.
attn_metadata
.
decode_wrapper
=
\
self
.
flashinfer_decode_wrapper
model_input
.
attn_metadata
.
begin_forward
()
# Detect exec mode
# Detect exec mode
assert
model_input
.
attn_metadata
is
not
None
assert
model_input
.
attn_metadata
is
not
None
use_cuda_graph
=
False
use_cuda_graph
=
False
...
@@ -323,7 +371,8 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -323,7 +371,8 @@ class TP1DraftModelRunner(ModelRunner):
kv_caches
=
kv_caches
,
kv_caches
=
kv_caches
,
attn_metadata
=
model_input
.
attn_metadata
,
attn_metadata
=
model_input
.
attn_metadata
,
intermediate_tensors
=
intermediate_tensors
,
intermediate_tensors
=
intermediate_tensors
,
**
multi_modal_kwargs
,
**
MultiModalInputs
.
as_kwargs
(
multi_modal_kwargs
,
device
=
self
.
device
),
)
)
# Compute the logits.
# Compute the logits.
...
...
vllm/spec_decode/medusa_worker.py
View file @
e661d594
...
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
...
@@ -57,9 +57,11 @@ class MedusaWorker(NonLLMProposerWorkerBase, Worker):
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_lens
,
query_lens
=
self
.
_prepare_input_tensors
(
seq_group_metadata_list
)
seq_group_metadata_list
)
generators
=
self
.
model_runner
.
get_generators
(
execute_model_req
.
finished_requests_ids
)
sampling_metadata
=
SamplingMetadata
.
prepare
(
sampling_metadata
=
SamplingMetadata
.
prepare
(
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
seq_group_metadata_list
,
seq_lens
,
query_lens
,
self
.
device
,
self
.
model_runner
.
pin_memory
)
self
.
model_runner
.
pin_memory
,
generators
)
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
model_outputs
=
self
.
model_runner
.
model
.
generate_proposals
(
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
previous_hidden_states
=
execute_model_req
.
previous_hidden_states
.
...
...
Prev
1
…
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment