Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff7ec82c
Unverified
Commit
ff7ec82c
authored
Aug 18, 2024
by
SangBin Cho
Committed by
GitHub
Aug 18, 2024
Browse files
[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
parent
200a2ffa
Changes
36
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
471 additions
and
267 deletions
+471
-267
vllm/lora/request.py
vllm/lora/request.py
+9
-5
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+6
-3
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+7
-3
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+7
-3
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+6
-3
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+9
-4
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+4
-2
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+6
-3
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+6
-3
vllm/pooling_params.py
vllm/pooling_params.py
+7
-4
vllm/prompt_adapter/request.py
vllm/prompt_adapter/request.py
+7
-3
vllm/sampling_params.py
vllm/sampling_params.py
+72
-75
vllm/sequence.py
vllm/sequence.py
+250
-146
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+8
-5
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+5
-3
vllm/worker/worker.py
vllm/worker/worker.py
+62
-2
No files found.
vllm/lora/request.py
View file @
ff7ec82c
import
warnings
import
warnings
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
import
msgspec
from
vllm.adapter_commons.request
import
AdapterRequest
from
vllm.adapter_commons.request
import
AdapterRequest
@
dataclass
class
LoRARequest
(
class
LoRARequest
(
AdapterRequest
):
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""
"""
Request for a LoRA adapter.
Request for a LoRA adapter.
...
@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
...
@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
lora_int_id must be globally unique for a given adapter.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
This is currently not enforced in vLLM.
"""
"""
__metaclass__
=
AdapterRequest
lora_name
:
str
lora_name
:
str
lora_int_id
:
int
lora_int_id
:
int
lora_path
:
str
=
""
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
field
(
default
=
None
,
repr
=
False
)
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__
dict
__
:
if
'lora_local_path'
in
self
.
__
struct_fields
__
:
warnings
.
warn
(
warnings
.
warn
(
"The 'lora_local_path' attribute is deprecated "
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
"and will be removed in a future version. "
...
...
vllm/model_executor/models/blip.py
View file @
ff7ec82c
"""Minimal implementation of BlipVisionModel intended to be only used
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
within a vision language model."""
from
array
import
array
from
typing
import
Optional
,
Union
from
typing
import
Optional
,
Union
import
torch
import
torch
...
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
...
@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
[
image_token_id
])
*
image_feature_size
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/blip2.py
View file @
ff7ec82c
from
array
import
array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
...
@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
...
@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
get_max_blip_image_tokens
)
get_max_blip_image_tokens
)
...
@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
...
@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/chameleon.py
View file @
ff7ec82c
from
array
import
array
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
)
Tuple
,
TypedDict
)
...
@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
...
@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
...
@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/clip.py
View file @
ff7ec82c
"""Minimal implementation of CLIPVisionModel intended to be only used
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
within a vision language model."""
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
...
@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
...
@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/fuyu.py
View file @
ff7ec82c
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
# limitations under the License.
# limitations under the License.
""" PyTorch Fuyu model."""
""" PyTorch Fuyu model."""
import
math
import
math
from
array
import
array
from
typing
import
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
from
typing
import
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
import
torch
import
torch
...
@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.interfaces
import
SupportsMultiModal
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
from
.utils
import
merge_multimodal_embeddings
...
@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
...
@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol
,
nrow
=
get_max_fuyu_image_feature_size
()
ncol
,
nrow
=
get_max_fuyu_image_feature_size
()
image_feature_size
=
get_max_fuyu_image_tokens
(
ctx
)
image_feature_size
=
get_max_fuyu_image_tokens
(
ctx
)
image_token_ids
=
([
_IMAGE_TOKEN_ID
]
*
ncol
+
[
_NEWLINE_TOKEN_ID
])
*
nrow
image_token_ids
=
(
token_ids
=
image_token_ids
*
num_images
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_IMAGE_TOKEN_ID
])
*
ncol
+
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_NEWLINE_TOKEN_ID
]))
*
nrow
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
image_token_ids
)
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
ff7ec82c
...
@@ -23,6 +23,7 @@
...
@@ -23,6 +23,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import
math
import
math
import
re
import
re
from
array
import
array
from
functools
import
partial
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
TypedDict
,
Union
)
...
@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
...
@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.idefics2_vision_model
import
Idefics2VisionTransformer
from
.idefics2_vision_model
import
Idefics2VisionTransformer
...
@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
...
@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
token_ids
=
[
0
]
*
seq_len
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
)
*
seq_len
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/siglip.py
View file @
ff7ec82c
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
within a vision language model."""
within a vision language model."""
import
math
import
math
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
import
torch
...
@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
...
@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
...
@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
else
:
else
:
image_feature_size
=
image_feature_size_override
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
[
image_token_id
])
*
image_feature_size
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/sampling_metadata.py
View file @
ff7ec82c
...
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
...
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
is_pin_memory_available
,
make_tensor_with_pad
,
...
@@ -505,9 +506,11 @@ class SamplingTensors:
...
@@ -505,9 +506,11 @@ class SamplingTensors:
and
sampling_params
.
prompt_logprobs
is
not
None
):
and
sampling_params
.
prompt_logprobs
is
not
None
):
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prompt_tokens
.
extend
(
prompt_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
)
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
(
output_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
)
for
_
in
range
(
prefill_len
))
if
seq_group
.
do_sample
:
if
seq_group
.
do_sample
:
for
seq_id
in
seq_ids
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
seq_data
=
seq_group
.
seq_data
[
seq_id
]
...
...
vllm/pooling_params.py
View file @
ff7ec82c
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
import
msgspec
class
PoolingParams
:
class
PoolingParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""Pooling parameters for pooling.
"""Pooling parameters for pooling.
Attributes:
Attributes:
additional_data: Any additional data needed for pooling.
additional_data: Any additional data needed for pooling.
"""
"""
additional_data
:
Optional
[
Any
]
=
None
def
__init__
(
self
,
additional_data
:
Optional
[
Any
]
=
None
):
self
.
additional_data
=
additional_data
def
clone
(
self
)
->
"PoolingParams"
:
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
"""Returns a deep copy of the PoolingParams instance."""
...
...
vllm/prompt_adapter/request.py
View file @
ff7ec82c
from
dataclasses
import
dataclass
import
msgspec
from
vllm.adapter_commons.request
import
AdapterRequest
from
vllm.adapter_commons.request
import
AdapterRequest
@
dataclass
class
PromptAdapterRequest
(
class
PromptAdapterRequest
(
AdapterRequest
):
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
frozen
=
True
):
# type: ignore[call-arg]
"""
"""
Request for a Prompt adapter.
Request for a Prompt adapter.
"""
"""
__metaclass__
=
AdapterRequest
prompt_adapter_name
:
str
prompt_adapter_name
:
str
prompt_adapter_id
:
int
prompt_adapter_id
:
int
...
...
vllm/sampling_params.py
View file @
ff7ec82c
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
import
copy
import
copy
from
enum
import
IntEnum
from
enum
import
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
import
msgspec
import
torch
import
torch
from
pydantic
import
Field
from
typing_extensions
import
Annotated
from
typing_extensions
import
Annotated
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
...
@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
to sample from."""
to sample from."""
class
SamplingParams
:
class
SamplingParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
# required for @cached_property.
dict
=
True
):
# type: ignore[call-arg]
"""Sampling parameters for text generation.
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
Overall, we follow the sampling parameters from the OpenAI text completion
...
@@ -112,87 +116,73 @@ class SamplingParams:
...
@@ -112,87 +116,73 @@ class SamplingParams:
(i.e., no truncation).
(i.e., no truncation).
"""
"""
def
__init__
(
n
:
int
=
1
self
,
best_of
:
Optional
[
int
]
=
None
n
:
int
=
1
,
presence_penalty
:
float
=
0.0
best_of
:
Optional
[
int
]
=
None
,
frequency_penalty
:
float
=
0.0
presence_penalty
:
float
=
0.0
,
repetition_penalty
:
float
=
1.0
frequency_penalty
:
float
=
0.0
,
temperature
:
float
=
1.0
repetition_penalty
:
float
=
1.0
,
top_p
:
float
=
1.0
temperature
:
float
=
1.0
,
top_k
:
int
=
-
1
top_p
:
float
=
1.0
,
min_p
:
float
=
0.0
top_k
:
int
=
-
1
,
seed
:
Optional
[
int
]
=
None
min_p
:
float
=
0.0
,
use_beam_search
:
bool
=
False
seed
:
Optional
[
int
]
=
None
,
length_penalty
:
float
=
1.0
use_beam_search
:
bool
=
False
,
early_stopping
:
Union
[
bool
,
str
]
=
False
length_penalty
:
float
=
1.0
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
early_stopping
:
Union
[
bool
,
str
]
=
False
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
ignore_eos
:
bool
=
False
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
max_tokens
:
Optional
[
int
]
=
16
include_stop_str_in_output
:
bool
=
False
,
min_tokens
:
int
=
0
ignore_eos
:
bool
=
False
,
logprobs
:
Optional
[
int
]
=
None
max_tokens
:
Optional
[
int
]
=
16
,
prompt_logprobs
:
Optional
[
int
]
=
None
min_tokens
:
int
=
0
,
# NOTE: This parameter is only exposed at the engine level for now.
logprobs
:
Optional
[
int
]
=
None
,
# It is not exposed in the OpenAI API server, as the OpenAI API does
prompt_logprobs
:
Optional
[
int
]
=
None
,
# not support returning only a list of token IDs.
detokenize
:
bool
=
True
,
detokenize
:
bool
=
True
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
# Optional[List[LogitsProcessor]] type. We use Any here because
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
# Optional[List[LogitsProcessor]] type is not supported by msgspec.
)
->
None
:
logits_processors
:
Optional
[
Any
]
=
None
self
.
n
=
n
include_stop_str_in_output
:
bool
=
False
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
self
.
presence_penalty
=
presence_penalty
self
.
frequency_penalty
=
frequency_penalty
# The below fields are not supposed to be used as an input.
self
.
repetition_penalty
=
repetition_penalty
# They are set in post_init.
if
0
<
temperature
<
_MAX_TEMP
:
output_text_buffer_length
:
int
=
0
_all_stop_token_ids
:
Set
[
int
]
=
msgspec
.
field
(
default_factory
=
set
)
def
__post_init__
(
self
)
->
None
:
self
.
best_of
=
self
.
best_of
or
self
.
n
if
0
<
self
.
temperature
<
_MAX_TEMP
:
logger
.
warning
(
logger
.
warning
(
"temperature %s is less than %s, which may cause numerical "
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s."
,
"errors nan or inf in tensors. We have maxed it out to %s."
,
temperature
,
_MAX_TEMP
,
_MAX_TEMP
)
self
.
temperature
,
_MAX_TEMP
,
_MAX_TEMP
)
temperature
=
max
(
temperature
,
_MAX_TEMP
)
self
.
temperature
=
max
(
self
.
temperature
,
_MAX_TEMP
)
self
.
temperature
=
temperature
if
self
.
seed
==
-
1
:
self
.
top_p
=
top_p
self
.
top_k
=
top_k
self
.
min_p
=
min_p
if
seed
==
-
1
:
self
.
seed
=
None
self
.
seed
=
None
else
:
else
:
self
.
seed
=
seed
self
.
seed
=
self
.
seed
self
.
use_beam_search
=
use_beam_search
if
self
.
stop
is
None
:
self
.
length_penalty
=
length_penalty
self
.
early_stopping
=
early_stopping
if
stop
is
None
:
self
.
stop
=
[]
self
.
stop
=
[]
elif
isinstance
(
stop
,
str
):
elif
isinstance
(
self
.
stop
,
str
):
self
.
stop
=
[
stop
]
self
.
stop
=
[
self
.
stop
]
else
:
else
:
self
.
stop
=
list
(
stop
)
self
.
stop
=
list
(
self
.
stop
)
if
stop_token_ids
is
None
:
if
self
.
stop_token_ids
is
None
:
self
.
stop_token_ids
=
[]
self
.
stop_token_ids
=
[]
else
:
else
:
self
.
stop_token_ids
=
list
(
stop_token_ids
)
self
.
stop_token_ids
=
list
(
self
.
stop_token_ids
)
self
.
ignore_eos
=
ignore_eos
self
.
logprobs
=
1
if
self
.
logprobs
is
True
else
self
.
logprobs
self
.
max_tokens
=
max_tokens
self
.
prompt_logprobs
=
(
1
if
self
.
prompt_logprobs
is
True
else
self
.
min_tokens
=
min_tokens
self
.
prompt_logprobs
)
self
.
logprobs
=
1
if
logprobs
is
True
else
logprobs
self
.
prompt_logprobs
=
1
if
prompt_logprobs
is
True
else
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self
.
detokenize
=
detokenize
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
truncate_prompt_tokens
=
truncate_prompt_tokens
# Number of characters to hold back for stop string evaluation
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
# until sequence is finished.
if
self
.
stop
and
not
include_stop_str_in_output
:
if
self
.
stop
and
not
self
.
include_stop_str_in_output
:
self
.
output_text_buffer_length
=
max
(
len
(
s
)
for
s
in
self
.
stop
)
-
1
self
.
output_text_buffer_length
=
max
(
len
(
s
)
for
s
in
self
.
stop
)
-
1
else
:
self
.
output_text_buffer_length
=
0
self
.
_verify_args
()
self
.
_verify_args
()
if
self
.
use_beam_search
:
if
self
.
use_beam_search
:
...
@@ -206,11 +196,12 @@ class SamplingParams:
...
@@ -206,11 +196,12 @@ class SamplingParams:
self
.
min_p
=
0.0
self
.
min_p
=
0.0
self
.
_verify_greedy_sampling
()
self
.
_verify_greedy_sampling
()
# eos_token_id is added to this by the engine
# eos_token_id is added to this by the engine
self
.
all_stop_token_ids
=
set
(
self
.
stop_token_ids
)
self
.
_
all_stop_token_ids
=
set
(
self
.
stop_token_ids
)
def
_verify_args
(
self
)
->
None
:
def
_verify_args
(
self
)
->
None
:
if
self
.
n
<
1
:
if
self
.
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
self
.
n
}
."
)
raise
ValueError
(
f
"n must be at least 1, got
{
self
.
n
}
."
)
assert
isinstance
(
self
.
best_of
,
int
)
if
self
.
best_of
<
self
.
n
:
if
self
.
best_of
<
self
.
n
:
raise
ValueError
(
f
"best_of must be greater than or equal to n, "
raise
ValueError
(
f
"best_of must be greater than or equal to n, "
f
"got n=
{
self
.
n
}
and best_of=
{
self
.
best_of
}
."
)
f
"got n=
{
self
.
n
}
and best_of=
{
self
.
best_of
}
."
)
...
@@ -257,6 +248,7 @@ class SamplingParams:
...
@@ -257,6 +248,7 @@ class SamplingParams:
and
self
.
truncate_prompt_tokens
<
1
):
and
self
.
truncate_prompt_tokens
<
1
):
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
assert
isinstance
(
self
.
stop
,
list
)
if
any
(
not
stop_str
for
stop_str
in
self
.
stop
):
if
any
(
not
stop_str
for
stop_str
in
self
.
stop
):
raise
ValueError
(
"stop cannot contain an empty string."
)
raise
ValueError
(
"stop cannot contain an empty string."
)
if
self
.
stop
and
not
self
.
detokenize
:
if
self
.
stop
and
not
self
.
detokenize
:
...
@@ -290,6 +282,7 @@ class SamplingParams:
...
@@ -290,6 +282,7 @@ class SamplingParams:
"default value of 1.0 when not using beam search."
)
"default value of 1.0 when not using beam search."
)
def
_verify_greedy_sampling
(
self
)
->
None
:
def
_verify_greedy_sampling
(
self
)
->
None
:
assert
isinstance
(
self
.
best_of
,
int
)
if
self
.
best_of
>
1
:
if
self
.
best_of
>
1
:
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
f
"Got
{
self
.
best_of
}
."
)
f
"Got
{
self
.
best_of
}
."
)
...
@@ -303,7 +296,7 @@ class SamplingParams:
...
@@ -303,7 +296,7 @@ class SamplingParams:
if
model_eos_token_id
is
not
None
:
if
model_eos_token_id
is
not
None
:
# Add the eos token id into the sampling_params to support
# Add the eos token id into the sampling_params to support
# min_tokens processing.
# min_tokens processing.
self
.
all_stop_token_ids
.
add
(
model_eos_token_id
)
self
.
_
all_stop_token_ids
.
add
(
model_eos_token_id
)
# Update eos_token_id for generation
# Update eos_token_id for generation
if
(
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
))
is
not
None
:
if
(
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
))
is
not
None
:
...
@@ -315,7 +308,7 @@ class SamplingParams:
...
@@ -315,7 +308,7 @@ class SamplingParams:
# purposes.
# purposes.
eos_ids
.
discard
(
model_eos_token_id
)
eos_ids
.
discard
(
model_eos_token_id
)
if
eos_ids
:
if
eos_ids
:
self
.
all_stop_token_ids
.
update
(
eos_ids
)
self
.
_
all_stop_token_ids
.
update
(
eos_ids
)
if
not
self
.
ignore_eos
:
if
not
self
.
ignore_eos
:
eos_ids
.
update
(
self
.
stop_token_ids
)
eos_ids
.
update
(
self
.
stop_token_ids
)
self
.
stop_token_ids
=
list
(
eos_ids
)
self
.
stop_token_ids
=
list
(
eos_ids
)
...
@@ -330,6 +323,10 @@ class SamplingParams:
...
@@ -330,6 +323,10 @@ class SamplingParams:
return
SamplingType
.
RANDOM_SEED
return
SamplingType
.
RANDOM_SEED
return
SamplingType
.
RANDOM
return
SamplingType
.
RANDOM
@
property
def
all_stop_token_ids
(
self
)
->
Set
[
int
]:
return
self
.
_all_stop_token_ids
def
clone
(
self
)
->
"SamplingParams"
:
def
clone
(
self
)
->
"SamplingParams"
:
"""Deep copy excluding LogitsProcessor objects.
"""Deep copy excluding LogitsProcessor objects.
...
...
vllm/sequence.py
View file @
ff7ec82c
...
@@ -4,10 +4,11 @@ import enum
...
@@ -4,10 +4,11 @@ import enum
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
Union
,
cast
)
Tuple
,
Union
,
cast
)
import
msgspec
import
numpy
import
numpy
import
torch
import
torch
...
@@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
...
@@ -16,13 +17,18 @@ from vllm.lora.request import LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.inputs
import
LLMInputs
from
vllm.inputs
import
LLMInputs
from
vllm.multimodal
import
MultiModalDataDict
from
vllm.multimodal.base
import
MultiModalDataDict
from
vllm.spec_decode.metrics
import
SpecDecodeWorkerMetrics
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@
dataclass
@
dataclass
class
Logprob
:
class
Logprob
:
"""Infos for supporting OpenAI compatible logprobs and token ranks.
"""Infos for supporting OpenAI compatible logprobs and token ranks.
...
@@ -112,7 +118,23 @@ class RequestMetrics:
...
@@ -112,7 +118,23 @@ class RequestMetrics:
model_execute_time
:
Optional
[
float
]
=
None
model_execute_time
:
Optional
[
float
]
=
None
class
SequenceData
:
class
SequenceDataDelta
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Delta SequenceData to send to workers per step."""
# A new token to be appended to existing SequenceData.
new_output_token_ids
:
List
[
int
]
# Overwriting existing `cumulative_logprob`
new_cumulative_logprob
:
float
# Overwriting existing `num_computed_tokens`.
new_num_computed_tokens
:
int
# Overwriting existing `stage`.
new_stage
:
SequenceStage
class
SequenceData
(
msgspec
.
Struct
,
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Data associated with a sequence.
"""Data associated with a sequence.
Args:
Args:
...
@@ -125,40 +147,57 @@ class SequenceData:
...
@@ -125,40 +147,57 @@ class SequenceData:
output_token_ids: The token IDs of the output.
output_token_ids: The token IDs of the output.
cumulative_logprob: The cumulative log probability of the output.
cumulative_logprob: The cumulative log probability of the output.
"""
"""
# NOTE: we cannot use Union[List, array] because msgspec cannot support
def
__init__
(
# union of 2 list types.
self
,
_prompt_token_ids
:
array
prompt_token_ids
:
List
[
int
],
_output_token_ids
:
array
=
msgspec
.
field
(
output_token_ids
:
Optional
[
List
[
int
]]
=
None
,
default_factory
=
lambda
:
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[]))
)
->
None
:
self
.
_prompt_token_ids
=
array
(
'l'
,
prompt_token_ids
)
### The below fields should not be passed as an argument ###
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
prompt_token_ids
)
_cumulative_logprob
:
float
=
0.0
self
.
_output_token_ids
=
array
(
_prompt_token_ids_tuple
:
Tuple
[
int
,
'l'
,
output_token_ids
if
output_token_ids
is
not
None
else
[])
...]
=
msgspec
.
field
(
default_factory
=
tuple
)
# The number of tokens that are computed (that run against the model).
self
.
cumulative_logprob
=
0.0
_num_computed_tokens
:
int
=
0
# The number of tokens that are computed (that run against the model).
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
self
.
_num_computed_tokens
=
0
_cached_all_token_ids
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
self
.
_stage
:
SequenceStage
=
SequenceStage
.
PREFILL
# It is used to get delta input. It is reset when `get_delta_and_reset`
# is called.
_new_appended_tokens
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
def
__post_init__
(
self
)
->
None
:
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
self
.
_prompt_token_ids_tuple
:
Tuple
[
int
,
...]
=
tuple
(
self
.
_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
self
.
_update_cached_all_tokens
()
def
_update_cached_all_tokens
(
self
):
def
_update_cached_all_tokens
(
self
):
assert
isinstance
(
self
.
_prompt_token_ids
,
array
)
assert
isinstance
(
self
.
_output_token_ids
,
array
)
self
.
_cached_all_token_ids
:
List
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_cached_all_token_ids
:
List
[
int
]
=
list
(
self
.
_prompt_token_ids
+
self
.
_output_token_ids
)
self
.
_output_token_ids
)
@
property
def
cumulative_logprob
(
self
)
->
float
:
return
self
.
_cumulative_logprob
@
property
@
property
def
prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
prompt_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
_prompt_token_ids_tuple
return
self
.
_prompt_token_ids_tuple
@
prompt_token_ids
.
setter
@
prompt_token_ids
.
setter
def
prompt_token_ids
(
self
,
new_prompt_token_ids
)
->
None
:
def
prompt_token_ids
(
self
,
new_prompt_token_ids
)
->
None
:
self
.
_prompt_token_ids
=
array
(
'l'
,
new_prompt_token_ids
)
raise
NotImplementedError
self
.
_prompt_token_ids_tuple
=
tuple
(
new_prompt_token_ids
)
self
.
_update_cached_all_tokens
()
@
property
@
property
def
prompt_token_ids_array
(
self
)
->
array
:
def
prompt_token_ids_array
(
self
)
->
array
:
"""Return the prompt token ids in array type.
Note that the array is in "I" type, and it is not compatible
with torch.long (2 bytes vs 4 bytes). So beware of the usage.
"""
return
self
.
_prompt_token_ids
return
self
.
_prompt_token_ids
@
property
@
property
...
@@ -166,18 +205,26 @@ class SequenceData:
...
@@ -166,18 +205,26 @@ class SequenceData:
return
tuple
(
self
.
_output_token_ids
)
return
tuple
(
self
.
_output_token_ids
)
@
output_token_ids
.
setter
@
output_token_ids
.
setter
def
output_token_ids
(
self
,
new_output_token_ids
)
->
None
:
def
output_token_ids
(
self
,
new_output_token_ids
:
List
[
int
])
->
None
:
self
.
_output_token_ids
=
array
(
'l'
,
new_output_token_ids
)
self
.
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
new_output_token_ids
)
self
.
_update_cached_all_tokens
()
self
.
_update_cached_all_tokens
()
@
property
@
property
def
output_token_ids_array
(
self
)
->
array
:
def
output_token_ids_array
(
self
)
->
array
:
"""Return the prompt token ids in array type.
Note that the array is in "I" type, and it is not compatible
with torch.long (2 bytes vs 4 bytes). So beware of the usage.
"""
assert
isinstance
(
self
.
_output_token_ids
,
array
)
return
self
.
_output_token_ids
return
self
.
_output_token_ids
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
def
append_token_id
(
self
,
token_id
:
int
,
logprob
:
float
)
->
None
:
self
.
_output_token_ids
.
append
(
token_id
)
self
.
_output_token_ids
.
append
(
token_id
)
self
.
_new_appended_tokens
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
_cached_all_token_ids
.
append
(
token_id
)
self
.
cumulative_logprob
+=
logprob
self
.
_
cumulative_logprob
+=
logprob
def
get_len
(
self
)
->
int
:
def
get_len
(
self
)
->
int
:
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
return
len
(
self
.
_output_token_ids
)
+
len
(
self
.
_prompt_token_ids
)
...
@@ -222,6 +269,7 @@ class SequenceData:
...
@@ -222,6 +269,7 @@ class SequenceData:
"""
"""
self
.
_num_computed_tokens
=
0
self
.
_num_computed_tokens
=
0
self
.
_stage
=
SequenceStage
.
PREFILL
self
.
_stage
=
SequenceStage
.
PREFILL
self
.
_new_appended_tokens
=
[]
def
get_num_uncomputed_tokens
(
self
)
->
int
:
def
get_num_uncomputed_tokens
(
self
)
->
int
:
"""Return the number of prefill tokens that are not computed."""
"""Return the number of prefill tokens that are not computed."""
...
@@ -241,6 +289,21 @@ class SequenceData:
...
@@ -241,6 +289,21 @@ class SequenceData:
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
def
get_output_token_ids
(
self
)
->
Tuple
[
int
,
...]:
return
self
.
output_token_ids
return
self
.
output_token_ids
def
get_delta_and_reset
(
self
)
->
SequenceDataDelta
:
delta
=
SequenceDataDelta
(
self
.
_new_appended_tokens
,
self
.
_cumulative_logprob
,
self
.
get_num_computed_tokens
(),
self
.
stage
)
# Reset delta state.
self
.
_new_appended_tokens
=
[]
return
delta
def
apply_delta
(
self
,
delta
:
SequenceDataDelta
):
self
.
_num_computed_tokens
=
delta
.
new_num_computed_tokens
self
.
_cumulative_logprob
=
delta
.
new_cumulative_logprob
self
.
_stage
=
delta
.
new_stage
self
.
_output_token_ids
.
extend
(
delta
.
new_output_token_ids
)
self
.
_cached_all_token_ids
.
extend
(
delta
.
new_output_token_ids
)
@
property
@
property
def
stage
(
self
)
->
SequenceStage
:
def
stage
(
self
)
->
SequenceStage
:
return
self
.
_stage
return
self
.
_stage
...
@@ -248,8 +311,9 @@ class SequenceData:
...
@@ -248,8 +311,9 @@ class SequenceData:
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceData("
return
(
f
"SequenceData("
f
"prompt_token_ids=
{
self
.
_prompt_token_ids
}
, "
f
"prompt_token_ids=
{
self
.
_prompt_token_ids
}
, "
f
"output_token_ids=
{
self
.
_output_token_ids
}
, "
f
"output_token_ids=
{
self
.
output_token_ids
}
, "
f
"cumulative_logprob=
{
self
.
cumulative_logprob
}
)"
)
f
"cumulative_logprob=
{
self
.
cumulative_logprob
}
, "
f
"get_num_computed_tokens=
{
self
.
get_num_computed_tokens
()
}
"
)
class
Sequence
:
class
Sequence
:
...
@@ -325,7 +389,8 @@ class Sequence:
...
@@ -325,7 +389,8 @@ class Sequence:
f
"invalid input
{
inputs
}
; did you forget the "
f
"invalid input
{
inputs
}
; did you forget the "
"encoder input prompt fields?"
)
"encoder input prompt fields?"
)
self
.
data
=
SequenceData
(
self
.
prompt_token_ids
)
self
.
data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
self
.
prompt_token_ids
))
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_text
=
""
self
.
output_text
=
""
...
@@ -490,8 +555,8 @@ class Sequence:
...
@@ -490,8 +555,8 @@ class Sequence:
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
f
"num_blocks=
{
self
.
n_blocks
}
, "
)
@
data
class
class
SequenceGroupState
(
msgspec
.
Struct
,
class
SequenceGroupState
:
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Mutable state tied to a specific sequence group"""
"""Mutable state tied to a specific sequence group"""
# for multi-step decoding
# for multi-step decoding
...
@@ -647,14 +712,19 @@ class SequenceGroup:
...
@@ -647,14 +712,19 @@ class SequenceGroup:
if
self
.
sampling_params
and
self
.
sampling_params
.
use_beam_search
:
if
self
.
sampling_params
and
self
.
sampling_params
.
use_beam_search
:
# For beam search, maximally there will always be `best_of` beam
# For beam search, maximally there will always be `best_of` beam
# candidates running in the future.
# candidates running in the future.
return
self
.
sampling_params
.
best_of
best_of
=
self
.
sampling_params
.
best_of
assert
isinstance
(
best_of
,
int
)
return
best_of
else
:
else
:
if
(
self
.
sampling_params
if
self
.
sampling_params
:
and
self
.
sampling_params
.
best_of
>
self
.
num_seqs
()):
best_of
=
self
.
sampling_params
.
best_of
# At prompt stage, the sequence group is not yet filled up
assert
isinstance
(
best_of
,
int
)
# and only have one sequence running. However, in the
if
best_of
>
self
.
num_seqs
():
# generation stage, we will have `best_of` sequences running.
# At prompt stage, the sequence group is not yet filled up
return
self
.
sampling_params
.
best_of
# and only have one sequence running. However, in the
# generation stage, we will have `best_of` sequences
# running.
return
best_of
# At sampling stages, return the number of actual sequences
# At sampling stages, return the number of actual sequences
# that are not finished yet.
# that are not finished yet.
return
self
.
num_unfinished_seqs
()
return
self
.
num_unfinished_seqs
()
...
@@ -757,7 +827,32 @@ class SequenceGroup:
...
@@ -757,7 +827,32 @@ class SequenceGroup:
f
"num_seqs=
{
len
(
self
.
seqs
)
}
)"
)
f
"num_seqs=
{
len
(
self
.
seqs
)
}
)"
)
class
SequenceGroupMetadata
:
class
SequenceGroupMetadataDelta
(
msgspec
.
Struct
,
tag
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Delta of SequenceGroupMetadata.
After sending the first SequenceGroupMetadata, vLLM scheduler
only sends delta to reduce the data payload size.
"""
seq_data_delta
:
Dict
[
int
,
SequenceDataDelta
]
request_id
:
str
block_tables
:
Dict
[
int
,
List
[
int
]]
is_prompt
:
bool
do_sample
:
bool
=
True
token_chunk_size
:
Optional
[
int
]
=
None
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
default_factory
=
lambda
:
SequenceGroupState
())
class
SequenceGroupMetadata
(
msgspec
.
Struct
,
tag
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Metadata for a sequence group. Used to create `AttentionMetadata`.
"""Metadata for a sequence group. Used to create `AttentionMetadata`.
Args:
Args:
...
@@ -789,52 +884,39 @@ class SequenceGroupMetadata:
...
@@ -789,52 +884,39 @@ class SequenceGroupMetadata:
prompt_adapter_request: Prompt Adapter request.
prompt_adapter_request: Prompt Adapter request.
"""
"""
def
__init__
(
request_id
:
str
self
,
is_prompt
:
bool
request_id
:
str
,
seq_data
:
Dict
[
int
,
SequenceData
]
is_prompt
:
bool
,
sampling_params
:
SamplingParams
seq_data
:
Dict
[
int
,
SequenceData
],
block_tables
:
Dict
[
int
,
List
[
int
]]
sampling_params
:
SamplingParams
,
do_sample
:
bool
=
True
block_tables
:
Dict
[
int
,
List
[
int
]],
pooling_params
:
Optional
[
PoolingParams
]
=
None
do_sample
:
bool
=
True
,
lora_request
:
Optional
[
LoRARequest
]
=
None
pooling_params
:
Optional
[
PoolingParams
]
=
None
,
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
token_chunk_size
:
Optional
[
int
]
=
None
,
state
:
Optional
[
SequenceGroupState
]
=
msgspec
.
field
(
lora_request
:
Optional
[
LoRARequest
]
=
None
,
default_factory
=
lambda
:
SequenceGroupState
())
computed_block_nums
:
Optional
[
List
[
int
]]
=
None
,
# "MultiModalDataDict" types. We have to use Any due to msgspec
state
:
Optional
[
SequenceGroupState
]
=
None
,
# doesn't allow to have union of 2 different dicts.
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
Any
]
=
None
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
,
encoder_seq_data
:
Optional
[
SequenceData
]
=
None
cross_block_table
:
Optional
[
List
[
int
]]
=
None
,
cross_block_table
:
Optional
[
List
[
int
]]
=
None
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
)
->
None
:
token_chunk_size
:
Optional
[
int
]
=
None
self
.
request_id
=
request_id
self
.
is_prompt
=
is_prompt
### Stateful fields that are lazily defined. ###
self
.
seq_data
=
seq_data
# The number of speculative tokens adopted in this request.
self
.
sampling_params
=
sampling_params
# None means specuative decoding is not used.
self
.
block_tables
=
block_tables
# Zero means speculative decoding is disabled for some reasons.
self
.
pooling_params
=
pooling_params
# TODO: We should maintain this states out of the sequence group.
self
.
lora_request
=
lora_request
num_speculative_tokens
:
Optional
[
int
]
=
None
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
computed_block_nums
=
computed_block_nums
def
__post_init__
(
self
):
self
.
multi_modal_data
=
multi_modal_data
if
self
.
seq_data
is
not
None
and
self
.
token_chunk_size
is
None
:
self
.
state
=
SequenceGroupState
()
if
state
is
None
else
state
if
self
.
is_prompt
:
self
.
encoder_seq_data
=
encoder_seq_data
self
.
token_chunk_size
=
next
(
iter
(
self
.
cross_block_table
=
cross_block_table
self
.
seq_data
.
values
())).
get_len
()
self
.
_token_chunk_size
=
token_chunk_size
self
.
do_sample
=
do_sample
# The number of speculative tokens adopted in this request.
# None means specuative decoding is not used.
# Zero means speculative decoding is disabled for some reasons.
# TODO: We should maintain this states out of the sequence group.
self
.
num_speculative_tokens
=
None
if
seq_data
is
not
None
and
self
.
_token_chunk_size
is
None
:
if
is_prompt
:
self
.
_token_chunk_size
=
next
(
iter
(
seq_data
.
values
())).
get_len
()
else
:
else
:
self
.
_
token_chunk_size
=
1
self
.
token_chunk_size
=
1
@
property
@
property
def
lora_int_id
(
self
)
->
int
:
def
lora_int_id
(
self
)
->
int
:
...
@@ -850,18 +932,26 @@ class SequenceGroupMetadata:
...
@@ -850,18 +932,26 @@ class SequenceGroupMetadata:
return
self
.
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
\
return
self
.
prompt_adapter_request
.
prompt_adapter_num_virtual_tokens
\
if
self
.
prompt_adapter_request
else
0
if
self
.
prompt_adapter_request
else
0
@
property
def
apply_delta
(
self
,
def
token_chunk_size
(
self
)
->
int
:
sequence_group_metadata_delta
:
SequenceGroupMetadataDelta
):
"""Return the number of tokens to be processed (chunk size)."""
for
id
,
delta
in
sequence_group_metadata_delta
.
seq_data_delta
.
items
():
assert
self
.
_token_chunk_size
is
not
None
self
.
seq_data
[
id
].
apply_delta
(
delta
)
return
self
.
_token_chunk_size
assert
self
.
request_id
==
sequence_group_metadata_delta
.
request_id
self
.
block_tables
=
sequence_group_metadata_delta
.
block_tables
self
.
token_chunk_size
=
sequence_group_metadata_delta
.
token_chunk_size
self
.
do_sample
=
sequence_group_metadata_delta
.
do_sample
self
.
is_prompt
=
sequence_group_metadata_delta
.
is_prompt
def
finish_step
(
self
)
->
None
:
def
finish_step
(
self
)
->
None
:
assert
self
.
state
is
not
None
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
assert
self
.
state
.
current_step
<
self
.
state
.
num_steps
self
.
state
.
current_step
+=
1
self
.
state
.
current_step
+=
1
class
SequenceOutput
:
class
SequenceOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The model output associated with a sequence.
"""The model output associated with a sequence.
Args:
Args:
...
@@ -871,16 +961,9 @@ class SequenceOutput:
...
@@ -871,16 +961,9 @@ class SequenceOutput:
logprobs: The logprobs of the output token.
logprobs: The logprobs of the output token.
(Token id -> logP(x_i+1 | x_0, ..., x_i))
(Token id -> logP(x_i+1 | x_0, ..., x_i))
"""
"""
parent_seq_id
:
int
def
__init__
(
output_token
:
int
self
,
logprobs
:
Dict
[
int
,
Logprob
]
parent_seq_id
:
int
,
output_token
:
int
,
logprobs
:
Dict
[
int
,
Logprob
],
)
->
None
:
self
.
parent_seq_id
=
parent_seq_id
self
.
output_token
=
output_token
self
.
logprobs
=
logprobs
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceOutput(parent_seq_id=
{
self
.
parent_seq_id
}
, "
return
(
f
"SequenceOutput(parent_seq_id=
{
self
.
parent_seq_id
}
, "
...
@@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
...
@@ -908,17 +991,15 @@ class SequenceGroupOutput(ABC):
pass
pass
class
CompletionSequenceGroupOutput
(
SequenceGroupOutput
):
class
CompletionSequenceGroupOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
__metaclass__
=
SequenceGroupOutput
"""The model output associated with a completion sequence group."""
"""The model output associated with a completion sequence group."""
samples
:
List
[
SequenceOutput
]
def
__init__
(
# Prompt logprob for each prompt query token.
self
,
prompt_logprobs
:
Optional
[
PromptLogprobs
]
samples
:
List
[
SequenceOutput
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
)
->
None
:
self
.
samples
=
samples
# Prompt logprob for each prompt query token.
self
.
prompt_logprobs
=
prompt_logprobs
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"CompletionSequenceGroupOutput(samples=
{
self
.
samples
}
, "
return
(
f
"CompletionSequenceGroupOutput(samples=
{
self
.
samples
}
, "
...
@@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
...
@@ -931,14 +1012,14 @@ class CompletionSequenceGroupOutput(SequenceGroupOutput):
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
class
EmbeddingSequenceGroupOutput
(
SequenceGroupOutput
):
class
EmbeddingSequenceGroupOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
,
# type: ignore[call-arg]
):
"""The model output associated with an embedding sequence group."""
"""The model output associated with an embedding sequence group."""
__metaclass__
=
SequenceGroupOutput
def
__init__
(
embeddings
:
List
[
int
]
self
,
embeddings
:
List
[
float
],
)
->
None
:
self
.
embeddings
=
embeddings
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"EmbeddingSequenceGroupOutput("
return
(
f
"EmbeddingSequenceGroupOutput("
...
@@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
...
@@ -950,8 +1031,10 @@ class EmbeddingSequenceGroupOutput(SequenceGroupOutput):
return
self
.
embeddings
==
other
.
embeddings
return
self
.
embeddings
==
other
.
embeddings
@
dataclass
class
IntermediateTensors
(
class
IntermediateTensors
:
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""For all pipeline stages except the last, we need to return the hidden
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
contains the hidden states and residuals for a request.
...
@@ -978,8 +1061,10 @@ class IntermediateTensors:
...
@@ -978,8 +1061,10 @@ class IntermediateTensors:
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
@
dataclass
class
SamplerOutput
(
class
SamplerOutput
:
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
each of which contains one possible candidate for the next token.
...
@@ -1000,7 +1085,7 @@ class SamplerOutput:
...
@@ -1000,7 +1085,7 @@ class SamplerOutput:
sampled_token_ids_numpy
:
Optional
[
numpy
.
ndarray
]
=
None
sampled_token_ids_numpy
:
Optional
[
numpy
.
ndarray
]
=
None
# Spec decode metrics populated by workers.
# Spec decode metrics populated by workers.
spec_decode_worker_metrics
:
Optional
[
"
SpecDecodeWorkerMetrics
"
]
=
None
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
# Optional last hidden states from the model.
# Optional last hidden states from the model.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -1039,12 +1124,14 @@ class SamplerOutput:
...
@@ -1039,12 +1124,14 @@ class SamplerOutput:
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
@
dataclass
class
PoolerOutput
(
class
PoolerOutput
:
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""The output from a pooling operation in the embedding model."""
"""The output from a pooling operation in the embedding model."""
outputs
:
List
[
EmbeddingSequenceGroupOutput
]
outputs
:
List
[
EmbeddingSequenceGroupOutput
]
spec_decode_worker_metrics
:
Optional
[
"
SpecDecodeWorkerMetrics
"
]
=
None
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
def
__getitem__
(
self
,
idx
:
int
):
return
self
.
outputs
[
idx
]
return
self
.
outputs
[
idx
]
...
@@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
...
@@ -1083,7 +1170,8 @@ def get_all_seq_ids_and_request_ids(
return
seq_ids
,
request_id_seq_ids_mapping
return
seq_ids
,
request_id_seq_ids_mapping
class
HiddenStates
:
class
HiddenStates
(
msgspec
.
Struct
,
array_like
=
True
,
omit_defaults
=
True
):
# type: ignore[call-arg]
"""Hidden states corresponding to in-progress sequences.
"""Hidden states corresponding to in-progress sequences.
Used in speculative decoding to pass hidden states from
Used in speculative decoding to pass hidden states from
the target model to the proposer model in the subsequent step.
the target model to the proposer model in the subsequent step.
...
@@ -1091,42 +1179,53 @@ class HiddenStates:
...
@@ -1091,42 +1179,53 @@ class HiddenStates:
seq_ids are the sequence ids of each entry of the batch
seq_ids are the sequence ids of each entry of the batch
dimension of the hidden_states tensor"""
dimension of the hidden_states tensor"""
def
__init__
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
hidden_states
:
torch
.
Tensor
):
hidden_states
:
torch
.
Tensor
assert
len
(
seq_group_metadata_list
)
==
len
(
hidden_states
)
_seq_ids
:
List
[
int
]
=
msgspec
.
field
(
default_factory
=
list
)
self
.
seq_ids
:
List
[
int
]
=
get_all_seq_ids
(
seq_group_metadata_list
)
self
.
hidden_states
:
torch
.
Tensor
=
hidden_states
def
__post_init__
(
self
):
self
.
_seq_ids
=
get_all_seq_ids
(
self
.
seq_group_metadata_list
)
assert
len
(
self
.
seq_group_metadata_list
)
==
len
(
self
.
hidden_states
)
@
property
def
seq_ids
(
self
)
->
List
[
int
]:
return
self
.
_seq_ids
def
update
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
def
update
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
hidden_states
:
torch
.
Tensor
)
->
None
:
hidden_states
:
torch
.
Tensor
)
->
None
:
"""Update hidden states from target model invocation."""
"""Update hidden states from target model invocation."""
assert
len
(
seq_group_metadata_list
)
==
len
(
hidden_states
)
assert
len
(
seq_group_metadata_list
)
==
len
(
hidden_states
)
self
.
seq_ids
.
extend
(
get_all_seq_ids
(
seq_group_metadata_list
))
self
.
_
seq_ids
.
extend
(
get_all_seq_ids
(
seq_group_metadata_list
))
self
.
hidden_states
=
torch
.
cat
([
self
.
hidden_states
,
hidden_states
])
self
.
hidden_states
=
torch
.
cat
([
self
.
hidden_states
,
hidden_states
])
def
prune
(
self
,
def
prune
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
])
->
None
:
"""Prune to provided list of sequence ids."""
"""Prune to provided list of sequence ids."""
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
seq_ids
=
get_all_seq_ids
(
seq_group_metadata_list
)
if
seq_ids
!=
self
.
seq_ids
:
if
seq_ids
!=
self
.
_
seq_ids
:
# Batch contents changed - prune removed sequences.
# Batch contents changed - prune removed sequences.
index
=
[
self
.
seq_ids
.
index
(
seq_id
)
for
seq_id
in
seq_ids
]
index
=
[
self
.
_
seq_ids
.
index
(
seq_id
)
for
seq_id
in
seq_ids
]
self
.
hidden_states
=
self
.
hidden_states
[
index
]
self
.
hidden_states
=
self
.
hidden_states
[
index
]
self
.
seq_ids
=
seq_ids
self
.
_
seq_ids
=
seq_ids
@
dataclass
class
ExecuteModelRequest
(
class
ExecuteModelRequest
:
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
):
# type: ignore[call-arg]
"""The model execution request, containing CPU metadata only. The LLM
"""The model execution request, containing CPU metadata only. The LLM
engine should create an instance of this class for each request batch."""
engine should create an instance of this class for each request batch."""
# The sequence group metadata list.
# The sequence group metadata list.
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
seq_group_metadata_list
:
List
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]]
# Blocks to swap in. List of CPU -> GPU block number.
# Blocks to swap in. List of CPU -> GPU block number.
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]]
=
field
(
default_factory
=
list
)
blocks_to_swap_in
:
List
[
Tuple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Blocks to swap out. List of GPU -> CPU block number.
# Blocks to swap out. List of GPU -> CPU block number.
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
field
(
default_factory
=
list
)
blocks_to_swap_out
:
List
[
Tuple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Blocks to copy. Source to dest block.
# Blocks to copy. Source to dest block.
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
field
(
default_factory
=
list
)
blocks_to_copy
:
List
[
Tuple
[
int
,
int
]]
=
msgspec
.
field
(
default_factory
=
list
)
# Virtual engine ID for pipeline parallel.
# Virtual engine ID for pipeline parallel.
virtual_engine
:
int
=
0
virtual_engine
:
int
=
0
# The number of slots for lookahead decoding.
# The number of slots for lookahead decoding.
...
@@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
...
@@ -1138,7 +1237,7 @@ class ExecuteModelRequest:
# The number of forward steps to run.
# The number of forward steps to run.
num_steps
:
int
=
1
num_steps
:
int
=
1
# Finished request ids since last step.
# Finished request ids since last step.
finished_requests_ids
:
List
[
str
]
=
field
(
default_factory
=
list
)
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
...
@@ -1148,6 +1247,7 @@ class ExecuteModelRequest:
# steps
# steps
assert
len
(
self
.
seq_group_metadata_list
)
>
0
assert
len
(
self
.
seq_group_metadata_list
)
>
0
first_seq_group
=
self
.
seq_group_metadata_list
[
0
]
first_seq_group
=
self
.
seq_group_metadata_list
[
0
]
assert
first_seq_group
.
state
is
not
None
return
first_seq_group
.
state
.
current_step
==
0
return
first_seq_group
.
state
.
current_step
==
0
@
property
@
property
...
@@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
...
@@ -1156,6 +1256,7 @@ class ExecuteModelRequest:
# steps
# steps
assert
len
(
self
.
seq_group_metadata_list
)
>
0
assert
len
(
self
.
seq_group_metadata_list
)
>
0
first_seq_group
=
self
.
seq_group_metadata_list
[
0
]
first_seq_group
=
self
.
seq_group_metadata_list
[
0
]
assert
first_seq_group
.
state
is
not
None
num_steps
=
first_seq_group
.
state
.
num_steps
num_steps
=
first_seq_group
.
state
.
num_steps
current_step
=
first_seq_group
.
state
.
current_step
current_step
=
first_seq_group
.
state
.
current_step
return
num_steps
-
current_step
==
1
return
num_steps
-
current_step
==
1
...
@@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
...
@@ -1165,10 +1266,13 @@ class ExecuteModelRequest:
# TODO(will) make this be able to handle batches with variable number of
# TODO(will) make this be able to handle batches with variable number of
# steps
# steps
assert
len
(
self
.
seq_group_metadata_list
)
>
0
assert
len
(
self
.
seq_group_metadata_list
)
>
0
return
self
.
seq_group_metadata_list
[
0
].
state
.
current_step
state
=
self
.
seq_group_metadata_list
[
0
].
state
assert
state
is
not
None
return
state
.
current_step
def
clone
(
def
clone
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
]
self
,
seq_group_metadata_list
:
List
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]]
)
->
"ExecuteModelRequest"
:
)
->
"ExecuteModelRequest"
:
"""Clone the request with a new sequence group metadata list."""
"""Clone the request with a new sequence group metadata list."""
return
ExecuteModelRequest
(
return
ExecuteModelRequest
(
...
...
vllm/spec_decode/batch_expansion.py
View file @
ff7ec82c
from
array
import
array
from
itertools
import
chain
,
count
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Tuple
from
typing
import
Iterator
,
List
,
Tuple
import
torch
import
torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SequenceGroupMetadata
,
get_all_seq_ids
)
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
@@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
input sequence.
input sequence.
"""
"""
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_token_ids
=
seq_data
.
get_
prompt_token_ids
()
prompt_token_ids
=
seq_data
.
prompt_token_ids
_array
new_output_token_ids
=
[
*
seq_data
.
get_output_token_ids
(),
*
token_ids
]
new_output_token_ids
=
[
*
seq_data
.
get_output_token_ids
(),
*
token_ids
]
new_seq_data_dict
=
{
new_seq_data_dict
=
{
target_seq_id
:
target_seq_id
:
SequenceData
(
SequenceData
(
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
,
output_token_ids
=
new_output_token_ids
,
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
new_output_token_ids
),
),
),
}
}
# This is a hack. Technically, spec decoding should compute
# This is a hack. Technically, spec decoding should compute
...
...
vllm/spec_decode/metrics.py
View file @
ff7ec82c
import
time
import
time
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Optional
from
typing
import
Callable
,
Optional
import
msgspec
import
torch
import
torch
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
...
@@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
...
@@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
@
dataclass
class
SpecDecodeWorkerMetrics
(
class
SpecDecodeWorkerMetrics
:
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""Dataclass holding metrics emitted from the spec decode worker.
"""
"""
...
...
vllm/worker/worker.py
View file @
ff7ec82c
"""A GPU worker class."""
"""A GPU worker class."""
import
gc
import
gc
import
os
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
import
torch
import
torch.distributed
import
torch.distributed
...
@@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
...
@@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
...
@@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
self
.
cache_engine
:
List
[
CacheEngine
]
self
.
cache_engine
:
List
[
CacheEngine
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
Tensor
]]]
=
None
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
Tensor
]]]
=
None
self
.
_seq_group_metadata_cache
:
Dict
[
str
,
SequenceGroupMetadata
]
=
{}
def
_is_encoder_decoder_model
(
self
):
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
return
self
.
model_config
.
is_encoder_decoder_model
...
@@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
...
@@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
and
worker_input
.
blocks_to_copy
.
numel
()
>
0
):
and
worker_input
.
blocks_to_copy
.
numel
()
>
0
):
self
.
cache_engine
[
virtual_engine
].
copy
(
worker_input
.
blocks_to_copy
)
self
.
cache_engine
[
virtual_engine
].
copy
(
worker_input
.
blocks_to_copy
)
def
_get_cached_seq_group_metadata
(
self
,
seq_group_metadata_list
:
List
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]],
finished_request_ids
:
List
[
str
])
->
List
[
SequenceGroupMetadata
]:
"""Return a list of cached Sequence Group Metadata after updating its
state.
It is used because scheduler only sends delta to workers to reduce
the data payload size. The function also cleans up cache based on
a given `finished_request_ids`.
"""
new_seq_group_metadata_list
=
[]
for
metadata_or_delta
in
seq_group_metadata_list
:
request_id
=
metadata_or_delta
.
request_id
if
request_id
not
in
self
.
_seq_group_metadata_cache
:
# The first prefill.
assert
isinstance
(
metadata_or_delta
,
SequenceGroupMetadata
)
self
.
_seq_group_metadata_cache
[
request_id
]
=
metadata_or_delta
else
:
# The first prefill is already cached.
if
isinstance
(
metadata_or_delta
,
SequenceGroupMetadataDelta
):
self
.
_seq_group_metadata_cache
[
request_id
].
apply_delta
(
metadata_or_delta
)
else
:
# If metadata snapshot is sent again, it is
# preempted. Reset the cache because we need to start
# from scratch.
assert
isinstance
(
metadata_or_delta
,
SequenceGroupMetadata
)
self
.
_seq_group_metadata_cache
[
request_id
]
=
metadata_or_delta
new_seq_group_metadata_list
.
append
(
self
.
_seq_group_metadata_cache
[
request_id
])
# Clean up finished ids
for
finished_id
in
finished_request_ids
:
del
self
.
_seq_group_metadata_cache
[
finished_id
]
return
new_seq_group_metadata_list
def
_execute_model_spmd
(
self
,
execute_model_req
:
ExecuteModelRequest
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
execute_model_req
is
not
None
:
new_seq_group_metadata_list
=
self
.
_get_cached_seq_group_metadata
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
finished_requests_ids
)
execute_model_req
.
seq_group_metadata_list
=
(
new_seq_group_metadata_list
)
output
=
super
().
_execute_model_spmd
(
execute_model_req
,
intermediate_tensors
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
return
self
.
model_runner
.
add_lora
(
lora_request
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment