Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff7ec82c
Unverified
Commit
ff7ec82c
authored
Aug 18, 2024
by
SangBin Cho
Committed by
GitHub
Aug 18, 2024
Browse files
[Core] Optimize SPMD architecture with delta + serialization optimization (#7109)
parent
200a2ffa
Changes
36
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
471 additions
and
267 deletions
+471
-267
vllm/lora/request.py
vllm/lora/request.py
+9
-5
vllm/model_executor/models/blip.py
vllm/model_executor/models/blip.py
+6
-3
vllm/model_executor/models/blip2.py
vllm/model_executor/models/blip2.py
+7
-3
vllm/model_executor/models/chameleon.py
vllm/model_executor/models/chameleon.py
+7
-3
vllm/model_executor/models/clip.py
vllm/model_executor/models/clip.py
+6
-3
vllm/model_executor/models/fuyu.py
vllm/model_executor/models/fuyu.py
+9
-4
vllm/model_executor/models/minicpmv.py
vllm/model_executor/models/minicpmv.py
+4
-2
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+6
-3
vllm/model_executor/sampling_metadata.py
vllm/model_executor/sampling_metadata.py
+6
-3
vllm/pooling_params.py
vllm/pooling_params.py
+7
-4
vllm/prompt_adapter/request.py
vllm/prompt_adapter/request.py
+7
-3
vllm/sampling_params.py
vllm/sampling_params.py
+72
-75
vllm/sequence.py
vllm/sequence.py
+250
-146
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+8
-5
vllm/spec_decode/metrics.py
vllm/spec_decode/metrics.py
+5
-3
vllm/worker/worker.py
vllm/worker/worker.py
+62
-2
No files found.
vllm/lora/request.py
View file @
ff7ec82c
import
warnings
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
import
msgspec
from
vllm.adapter_commons.request
import
AdapterRequest
@
dataclass
class
LoRARequest
(
AdapterRequest
):
class
LoRARequest
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""
Request for a LoRA adapter.
...
...
@@ -18,16 +21,17 @@ class LoRARequest(AdapterRequest):
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
__metaclass__
=
AdapterRequest
lora_name
:
str
lora_int_id
:
int
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
field
(
default
=
None
,
repr
=
False
)
lora_local_path
:
Optional
[
str
]
=
msgspec
.
field
(
default
=
None
)
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__
dict
__
:
if
'lora_local_path'
in
self
.
__
struct_fields
__
:
warnings
.
warn
(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
...
...
vllm/model_executor/models/blip.py
View file @
ff7ec82c
"""Minimal implementation of BlipVisionModel intended to be only used
within a vision language model."""
from
array
import
array
from
typing
import
Optional
,
Union
import
torch
...
...
@@ -16,7 +17,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_blip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -53,8 +54,10 @@ def dummy_seq_data_for_blip(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/blip2.py
View file @
ff7ec82c
from
array
import
array
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -17,7 +18,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.models.opt
import
OPTModel
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.blip
import
(
BlipVisionModel
,
dummy_image_for_blip
,
get_max_blip_image_tokens
)
...
...
@@ -427,8 +429,10 @@ def dummy_seq_data_for_blip2(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/chameleon.py
View file @
ff7ec82c
from
array
import
array
from
functools
import
cached_property
from
typing
import
(
Any
,
Dict
,
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
)
...
...
@@ -31,7 +32,8 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsMultiModal
...
...
@@ -70,8 +72,10 @@ def dummy_seq_data_for_chameleon(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/clip.py
View file @
ff7ec82c
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
...
...
@@ -17,7 +18,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_clip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -53,8 +54,10 @@ def dummy_seq_data_for_clip(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/fuyu.py
View file @
ff7ec82c
...
...
@@ -16,6 +16,7 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import
math
from
array
import
array
from
typing
import
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
import
torch
...
...
@@ -37,7 +38,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.interfaces
import
SupportsMultiModal
from
.utils
import
merge_multimodal_embeddings
...
...
@@ -97,9 +99,12 @@ def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol
,
nrow
=
get_max_fuyu_image_feature_size
()
image_feature_size
=
get_max_fuyu_image_tokens
(
ctx
)
image_token_ids
=
([
_IMAGE_TOKEN_ID
]
*
ncol
+
[
_NEWLINE_TOKEN_ID
])
*
nrow
token_ids
=
image_token_ids
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
image_token_ids
=
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_IMAGE_TOKEN_ID
])
*
ncol
+
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
_NEWLINE_TOKEN_ID
]))
*
nrow
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
image_token_ids
)
*
num_images
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
*
num_images
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/minicpmv.py
View file @
ff7ec82c
...
...
@@ -23,6 +23,7 @@
"""Inference-only MiniCPM-V model compatible with HuggingFace weights."""
import
math
import
re
from
array
import
array
from
functools
import
partial
from
typing
import
(
Any
,
Callable
,
Iterable
,
List
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
)
...
...
@@ -55,7 +56,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.image
import
(
cached_get_image_processor
,
cached_get_tokenizer
)
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SamplerOutput
,
SequenceData
)
from
.idefics2_vision_model
import
Idefics2VisionTransformer
...
...
@@ -408,7 +410,7 @@ def get_max_minicpmv_image_tokens(ctx: InputContext):
def
dummy_seq_data_for_minicpmv
(
seq_len
:
int
,
num_images
:
int
):
token_ids
=
[
0
]
*
seq_len
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]
)
*
seq_len
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/models/siglip.py
View file @
ff7ec82c
...
...
@@ -2,6 +2,7 @@
within a vision language model."""
import
math
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
import
torch
...
...
@@ -25,7 +26,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.multimodal.image
import
(
cached_get_tokenizer
,
repeat_and_pad_image_tokens
)
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
...
...
@@ -62,8 +63,10 @@ def dummy_seq_data_for_siglip(
else
:
image_feature_size
=
image_feature_size_override
token_ids
=
[
image_token_id
]
*
image_feature_size
*
num_images
token_ids
+=
[
0
]
*
(
seq_len
-
image_feature_size
*
num_images
)
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
image_feature_size
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
image_feature_size
)
return
SequenceData
(
token_ids
)
...
...
vllm/model_executor/sampling_metadata.py
View file @
ff7ec82c
...
...
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional, Tuple
import
torch
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.triton_utils.sample
import
get_num_triton_sampler_splits
from
vllm.utils
import
(
PyObjectCache
,
async_tensor_h2d
,
is_pin_memory_available
,
make_tensor_with_pad
,
...
...
@@ -505,9 +506,11 @@ class SamplingTensors:
and
sampling_params
.
prompt_logprobs
is
not
None
):
prefill_len
=
len
(
seq_group
.
prompt_logprob_indices
)
prompt_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
)
for
_
in
range
(
prefill_len
))
output_tokens
.
extend
(
array
(
'l'
)
for
_
in
range
(
prefill_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
)
for
_
in
range
(
prefill_len
))
if
seq_group
.
do_sample
:
for
seq_id
in
seq_ids
:
seq_data
=
seq_group
.
seq_data
[
seq_id
]
...
...
vllm/pooling_params.py
View file @
ff7ec82c
from
typing
import
Any
,
Optional
import
msgspec
class
PoolingParams
:
class
PoolingParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""Pooling parameters for pooling.
Attributes:
additional_data: Any additional data needed for pooling.
"""
def
__init__
(
self
,
additional_data
:
Optional
[
Any
]
=
None
):
self
.
additional_data
=
additional_data
additional_data
:
Optional
[
Any
]
=
None
def
clone
(
self
)
->
"PoolingParams"
:
"""Returns a deep copy of the PoolingParams instance."""
...
...
vllm/prompt_adapter/request.py
View file @
ff7ec82c
from
dataclasses
import
dataclass
import
msgspec
from
vllm.adapter_commons.request
import
AdapterRequest
@
dataclass
class
PromptAdapterRequest
(
AdapterRequest
):
class
PromptAdapterRequest
(
msgspec
.
Struct
,
array_like
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
frozen
=
True
):
# type: ignore[call-arg]
"""
Request for a Prompt adapter.
"""
__metaclass__
=
AdapterRequest
prompt_adapter_name
:
str
prompt_adapter_id
:
int
...
...
vllm/sampling_params.py
View file @
ff7ec82c
...
...
@@ -2,10 +2,10 @@
import
copy
from
enum
import
IntEnum
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
import
msgspec
import
torch
from
pydantic
import
Field
from
typing_extensions
import
Annotated
from
vllm.logger
import
init_logger
...
...
@@ -33,7 +33,11 @@ first argument, and returns a modified tensor of logits
to sample from."""
class
SamplingParams
:
class
SamplingParams
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
# required for @cached_property.
dict
=
True
):
# type: ignore[call-arg]
"""Sampling parameters for text generation.
Overall, we follow the sampling parameters from the OpenAI text completion
...
...
@@ -112,87 +116,73 @@ class SamplingParams:
(i.e., no truncation).
"""
def
__init__
(
self
,
n
:
int
=
1
,
best_of
:
Optional
[
int
]
=
None
,
presence_penalty
:
float
=
0.0
,
frequency_penalty
:
float
=
0.0
,
repetition_penalty
:
float
=
1.0
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
min_p
:
float
=
0.0
,
seed
:
Optional
[
int
]
=
None
,
use_beam_search
:
bool
=
False
,
length_penalty
:
float
=
1.0
,
early_stopping
:
Union
[
bool
,
str
]
=
False
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
,
include_stop_str_in_output
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
max_tokens
:
Optional
[
int
]
=
16
,
min_tokens
:
int
=
0
,
logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
detokenize
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
spaces_between_special_tokens
:
bool
=
True
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
Field
(
ge
=
1
)]]
=
None
,
)
->
None
:
self
.
n
=
n
self
.
best_of
=
best_of
if
best_of
is
not
None
else
n
self
.
presence_penalty
=
presence_penalty
self
.
frequency_penalty
=
frequency_penalty
self
.
repetition_penalty
=
repetition_penalty
if
0
<
temperature
<
_MAX_TEMP
:
n
:
int
=
1
best_of
:
Optional
[
int
]
=
None
presence_penalty
:
float
=
0.0
frequency_penalty
:
float
=
0.0
repetition_penalty
:
float
=
1.0
temperature
:
float
=
1.0
top_p
:
float
=
1.0
top_k
:
int
=
-
1
min_p
:
float
=
0.0
seed
:
Optional
[
int
]
=
None
use_beam_search
:
bool
=
False
length_penalty
:
float
=
1.0
early_stopping
:
Union
[
bool
,
str
]
=
False
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
ignore_eos
:
bool
=
False
max_tokens
:
Optional
[
int
]
=
16
min_tokens
:
int
=
0
logprobs
:
Optional
[
int
]
=
None
prompt_logprobs
:
Optional
[
int
]
=
None
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
detokenize
:
bool
=
True
skip_special_tokens
:
bool
=
True
spaces_between_special_tokens
:
bool
=
True
# Optional[List[LogitsProcessor]] type. We use Any here because
# Optional[List[LogitsProcessor]] type is not supported by msgspec.
logits_processors
:
Optional
[
Any
]
=
None
include_stop_str_in_output
:
bool
=
False
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
# The below fields are not supposed to be used as an input.
# They are set in post_init.
output_text_buffer_length
:
int
=
0
_all_stop_token_ids
:
Set
[
int
]
=
msgspec
.
field
(
default_factory
=
set
)
def
__post_init__
(
self
)
->
None
:
self
.
best_of
=
self
.
best_of
or
self
.
n
if
0
<
self
.
temperature
<
_MAX_TEMP
:
logger
.
warning
(
"temperature %s is less than %s, which may cause numerical "
"errors nan or inf in tensors. We have maxed it out to %s."
,
temperature
,
_MAX_TEMP
,
_MAX_TEMP
)
temperature
=
max
(
temperature
,
_MAX_TEMP
)
self
.
temperature
=
temperature
self
.
top_p
=
top_p
self
.
top_k
=
top_k
self
.
min_p
=
min_p
if
seed
==
-
1
:
self
.
temperature
,
_MAX_TEMP
,
_MAX_TEMP
)
self
.
temperature
=
max
(
self
.
temperature
,
_MAX_TEMP
)
if
self
.
seed
==
-
1
:
self
.
seed
=
None
else
:
self
.
seed
=
seed
self
.
use_beam_search
=
use_beam_search
self
.
length_penalty
=
length_penalty
self
.
early_stopping
=
early_stopping
if
stop
is
None
:
self
.
seed
=
self
.
seed
if
self
.
stop
is
None
:
self
.
stop
=
[]
elif
isinstance
(
stop
,
str
):
self
.
stop
=
[
stop
]
elif
isinstance
(
self
.
stop
,
str
):
self
.
stop
=
[
self
.
stop
]
else
:
self
.
stop
=
list
(
stop
)
if
stop_token_ids
is
None
:
self
.
stop
=
list
(
self
.
stop
)
if
self
.
stop_token_ids
is
None
:
self
.
stop_token_ids
=
[]
else
:
self
.
stop_token_ids
=
list
(
stop_token_ids
)
self
.
ignore_eos
=
ignore_eos
self
.
max_tokens
=
max_tokens
self
.
min_tokens
=
min_tokens
self
.
logprobs
=
1
if
logprobs
is
True
else
logprobs
self
.
prompt_logprobs
=
1
if
prompt_logprobs
is
True
else
prompt_logprobs
# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
self
.
detokenize
=
detokenize
self
.
skip_special_tokens
=
skip_special_tokens
self
.
spaces_between_special_tokens
=
spaces_between_special_tokens
self
.
logits_processors
=
logits_processors
self
.
include_stop_str_in_output
=
include_stop_str_in_output
self
.
truncate_prompt_tokens
=
truncate_prompt_tokens
self
.
stop_token_ids
=
list
(
self
.
stop_token_ids
)
self
.
logprobs
=
1
if
self
.
logprobs
is
True
else
self
.
logprobs
self
.
prompt_logprobs
=
(
1
if
self
.
prompt_logprobs
is
True
else
self
.
prompt_logprobs
)
# Number of characters to hold back for stop string evaluation
# until sequence is finished.
if
self
.
stop
and
not
include_stop_str_in_output
:
if
self
.
stop
and
not
self
.
include_stop_str_in_output
:
self
.
output_text_buffer_length
=
max
(
len
(
s
)
for
s
in
self
.
stop
)
-
1
else
:
self
.
output_text_buffer_length
=
0
self
.
_verify_args
()
if
self
.
use_beam_search
:
...
...
@@ -206,11 +196,12 @@ class SamplingParams:
self
.
min_p
=
0.0
self
.
_verify_greedy_sampling
()
# eos_token_id is added to this by the engine
self
.
all_stop_token_ids
=
set
(
self
.
stop_token_ids
)
self
.
_
all_stop_token_ids
=
set
(
self
.
stop_token_ids
)
def
_verify_args
(
self
)
->
None
:
if
self
.
n
<
1
:
raise
ValueError
(
f
"n must be at least 1, got
{
self
.
n
}
."
)
assert
isinstance
(
self
.
best_of
,
int
)
if
self
.
best_of
<
self
.
n
:
raise
ValueError
(
f
"best_of must be greater than or equal to n, "
f
"got n=
{
self
.
n
}
and best_of=
{
self
.
best_of
}
."
)
...
...
@@ -257,6 +248,7 @@ class SamplingParams:
and
self
.
truncate_prompt_tokens
<
1
):
raise
ValueError
(
f
"truncate_prompt_tokens must be >= 1, "
f
"got
{
self
.
truncate_prompt_tokens
}
"
)
assert
isinstance
(
self
.
stop
,
list
)
if
any
(
not
stop_str
for
stop_str
in
self
.
stop
):
raise
ValueError
(
"stop cannot contain an empty string."
)
if
self
.
stop
and
not
self
.
detokenize
:
...
...
@@ -290,6 +282,7 @@ class SamplingParams:
"default value of 1.0 when not using beam search."
)
def
_verify_greedy_sampling
(
self
)
->
None
:
assert
isinstance
(
self
.
best_of
,
int
)
if
self
.
best_of
>
1
:
raise
ValueError
(
"best_of must be 1 when using greedy sampling."
f
"Got
{
self
.
best_of
}
."
)
...
...
@@ -303,7 +296,7 @@ class SamplingParams:
if
model_eos_token_id
is
not
None
:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self
.
all_stop_token_ids
.
add
(
model_eos_token_id
)
self
.
_
all_stop_token_ids
.
add
(
model_eos_token_id
)
# Update eos_token_id for generation
if
(
eos_ids
:
=
generation_config
.
get
(
"eos_token_id"
))
is
not
None
:
...
...
@@ -315,7 +308,7 @@ class SamplingParams:
# purposes.
eos_ids
.
discard
(
model_eos_token_id
)
if
eos_ids
:
self
.
all_stop_token_ids
.
update
(
eos_ids
)
self
.
_
all_stop_token_ids
.
update
(
eos_ids
)
if
not
self
.
ignore_eos
:
eos_ids
.
update
(
self
.
stop_token_ids
)
self
.
stop_token_ids
=
list
(
eos_ids
)
...
...
@@ -330,6 +323,10 @@ class SamplingParams:
return
SamplingType
.
RANDOM_SEED
return
SamplingType
.
RANDOM
@
property
def
all_stop_token_ids
(
self
)
->
Set
[
int
]:
return
self
.
_all_stop_token_ids
def
clone
(
self
)
->
"SamplingParams"
:
"""Deep copy excluding LogitsProcessor objects.
...
...
vllm/sequence.py
View file @
ff7ec82c
This diff is collapsed.
Click to expand it.
vllm/spec_decode/batch_expansion.py
View file @
ff7ec82c
from
array
import
array
from
itertools
import
chain
,
count
from
typing
import
Iterator
,
List
,
Tuple
import
torch
from
vllm
import
SamplingParams
from
vllm.sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
...
...
@@ -293,14 +295,15 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
input sequence.
"""
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_token_ids
=
seq_data
.
get_
prompt_token_ids
()
prompt_token_ids
=
seq_data
.
prompt_token_ids
_array
new_output_token_ids
=
[
*
seq_data
.
get_output_token_ids
(),
*
token_ids
]
new_seq_data_dict
=
{
target_seq_id
:
SequenceData
(
prompt_token_ids
=
prompt_token_ids
,
output_token_ids
=
new_output_token_ids
,
prompt_token_ids
,
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
new_output_token_ids
),
),
}
# This is a hack. Technically, spec decoding should compute
...
...
vllm/spec_decode/metrics.py
View file @
ff7ec82c
import
time
from
dataclasses
import
dataclass
from
typing
import
Callable
,
Optional
import
msgspec
import
torch
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
...
...
@@ -9,8 +9,10 @@ from vllm.model_executor.layers.spec_decode_base_sampler import (
from
vllm.utils
import
is_pin_memory_available
@
dataclass
class
SpecDecodeWorkerMetrics
:
class
SpecDecodeWorkerMetrics
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""Dataclass holding metrics emitted from the spec decode worker.
"""
...
...
vllm/worker/worker.py
View file @
ff7ec82c
"""A GPU worker class."""
import
gc
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
torch
import
torch.distributed
...
...
@@ -18,7 +18,9 @@ from vllm.model_executor import set_random_seed
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.platforms
import
current_platform
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
SamplerOutput
,
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
)
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.embedding_model_runner
import
EmbeddingModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
...
...
@@ -109,6 +111,7 @@ class Worker(LocalOrDistributedWorkerBase):
self
.
cache_engine
:
List
[
CacheEngine
]
# Initialize gpu_cache as embedding models don't initialize kv_caches
self
.
gpu_cache
:
Optional
[
List
[
List
[
torch
.
Tensor
]]]
=
None
self
.
_seq_group_metadata_cache
:
Dict
[
str
,
SequenceGroupMetadata
]
=
{}
def
_is_encoder_decoder_model
(
self
):
return
self
.
model_config
.
is_encoder_decoder_model
...
...
@@ -303,6 +306,63 @@ class Worker(LocalOrDistributedWorkerBase):
and
worker_input
.
blocks_to_copy
.
numel
()
>
0
):
self
.
cache_engine
[
virtual_engine
].
copy
(
worker_input
.
blocks_to_copy
)
def
_get_cached_seq_group_metadata
(
self
,
seq_group_metadata_list
:
List
[
Union
[
SequenceGroupMetadata
,
SequenceGroupMetadataDelta
]],
finished_request_ids
:
List
[
str
])
->
List
[
SequenceGroupMetadata
]:
"""Return a list of cached Sequence Group Metadata after updating its
state.
It is used because scheduler only sends delta to workers to reduce
the data payload size. The function also cleans up cache based on
a given `finished_request_ids`.
"""
new_seq_group_metadata_list
=
[]
for
metadata_or_delta
in
seq_group_metadata_list
:
request_id
=
metadata_or_delta
.
request_id
if
request_id
not
in
self
.
_seq_group_metadata_cache
:
# The first prefill.
assert
isinstance
(
metadata_or_delta
,
SequenceGroupMetadata
)
self
.
_seq_group_metadata_cache
[
request_id
]
=
metadata_or_delta
else
:
# The first prefill is already cached.
if
isinstance
(
metadata_or_delta
,
SequenceGroupMetadataDelta
):
self
.
_seq_group_metadata_cache
[
request_id
].
apply_delta
(
metadata_or_delta
)
else
:
# If metadata snapshot is sent again, it is
# preempted. Reset the cache because we need to start
# from scratch.
assert
isinstance
(
metadata_or_delta
,
SequenceGroupMetadata
)
self
.
_seq_group_metadata_cache
[
request_id
]
=
metadata_or_delta
new_seq_group_metadata_list
.
append
(
self
.
_seq_group_metadata_cache
[
request_id
])
# Clean up finished ids
for
finished_id
in
finished_request_ids
:
del
self
.
_seq_group_metadata_cache
[
finished_id
]
return
new_seq_group_metadata_list
def
_execute_model_spmd
(
self
,
execute_model_req
:
ExecuteModelRequest
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Optional
[
List
[
SamplerOutput
]]:
if
execute_model_req
is
not
None
:
new_seq_group_metadata_list
=
self
.
_get_cached_seq_group_metadata
(
execute_model_req
.
seq_group_metadata_list
,
execute_model_req
.
finished_requests_ids
)
execute_model_req
.
seq_group_metadata_list
=
(
new_seq_group_metadata_list
)
output
=
super
().
_execute_model_spmd
(
execute_model_req
,
intermediate_tensors
)
return
output
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment