Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ad58e9b3
"vllm/vscode:/vscode.git/clone" did not exist on "978a4462bbc529ff204647543526e4caa08ed974"
Commit
ad58e9b3
authored
Sep 18, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.1.post2' into v0.6.1.post2-dev
parents
408f663a
9ba0817f
Changes
118
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
300 additions
and
104 deletions
+300
-104
vllm/model_executor/models/__init__.py
vllm/model_executor/models/__init__.py
+2
-2
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+8
-0
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+4
-3
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+1
-1
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+49
-34
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+9
-5
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+6
-0
vllm/outputs.py
vllm/outputs.py
+55
-24
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+13
-0
vllm/sampling_params.py
vllm/sampling_params.py
+16
-1
vllm/sequence.py
vllm/sequence.py
+36
-5
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+30
-2
vllm/utils.py
vllm/utils.py
+4
-0
vllm/version.py
vllm/version.py
+2
-1
vllm/worker/cpu_model_runner.py
vllm/worker/cpu_model_runner.py
+5
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+5
-2
vllm/worker/model_runner_base.py
vllm/worker/model_runner_base.py
+34
-0
vllm/worker/multi_step_model_runner.py
vllm/worker/multi_step_model_runner.py
+21
-23
No files found.
vllm/model_executor/models/__init__.py
View file @
ad58e9b3
...
@@ -90,12 +90,12 @@ _MULTIMODAL_MODELS = {
...
@@ -90,12 +90,12 @@ _MULTIMODAL_MODELS = {
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
:
(
"paligemma"
,
"PaliGemmaForConditionalGeneration"
),
"PaliGemmaForConditionalGeneration"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"Phi3VForCausalLM"
:
(
"phi3v"
,
"Phi3VForCausalLM"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"PixtralForConditionalGeneration"
:
(
"pixtral"
,
"PixtralForConditionalGeneration"
:
(
"pixtral"
,
"PixtralForConditionalGeneration"
),
"PixtralForConditionalGeneration"
),
"QWenLMHeadModel"
:
(
"qwen"
,
"QWenLMHeadModel"
),
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
:
(
"qwen2_vl"
,
"Qwen2VLForConditionalGeneration"
),
"Qwen2VLForConditionalGeneration"
),
"UltravoxModel"
:
(
"ultravox"
,
"UltravoxModel"
),
}
}
_CONDITIONAL_GENERATION_MODELS
=
{
_CONDITIONAL_GENERATION_MODELS
=
{
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
"BartModel"
:
(
"bart"
,
"BartForConditionalGeneration"
),
...
...
vllm/model_executor/models/gemma2.py
View file @
ad58e9b3
...
@@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -312,6 +312,14 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
# Gemma does not apply LoRA to the embedding layer.
# Gemma does not apply LoRA to the embedding layer.
embedding_modules
=
{}
embedding_modules
=
{}
embedding_padding_modules
=
[]
embedding_padding_modules
=
[]
bitsandbytes_stacked_params_mapping
=
{
# shard_name, weight_name, index
"q_proj"
:
(
"qkv_proj"
,
0
),
"k_proj"
:
(
"qkv_proj"
,
1
),
"v_proj"
:
(
"qkv_proj"
,
2
),
"gate_proj"
:
(
"gate_up_proj"
,
0
),
"up_proj"
:
(
"gate_up_proj"
,
1
),
}
def
__init__
(
def
__init__
(
self
,
self
,
...
...
vllm/model_executor/models/internvl.py
View file @
ad58e9b3
...
@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
...
@@ -270,6 +270,7 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
# Add an N dimension for number of images per prompt (currently 1).
# Add an N dimension for number of images per prompt (currently 1).
data
=
data
.
unsqueeze
(
0
)
data
=
data
.
unsqueeze
(
0
)
elif
is_list_of
(
data
,
Image
.
Image
):
elif
is_list_of
(
data
,
Image
.
Image
):
# we can't stack here because the images may have different num_patches
data
=
[
data
=
[
image_to_pixel_values
(
img
,
image_to_pixel_values
(
img
,
image_size
,
image_size
,
...
@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
...
@@ -277,7 +278,6 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
max_num
,
max_num
,
use_thumbnail
=
use_thumbnail
)
for
img
in
data
use_thumbnail
=
use_thumbnail
)
for
img
in
data
]
]
data
=
torch
.
stack
(
data
)
model_config
=
ctx
.
model_config
model_config
=
ctx
.
model_config
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
tokenizer
=
cached_get_tokenizer
(
model_config
.
tokenizer
,
trust_remote_code
=
True
)
trust_remote_code
=
True
)
...
@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
...
@@ -449,11 +449,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
pixel_values
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
"Incorrect type of pixel values. "
raise
ValueError
(
"Incorrect type of pixel values. "
f
"Got type:
{
type
(
pixel_values
)
}
"
)
f
"Got type:
{
type
(
pixel_values
)
}
"
)
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.
return
InternVLImagePixelInputs
(
return
InternVLImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
self
.
_validate_pixel_values
(
data
=
self
.
_validate_pixel_values
(
flatten_bn
(
pixel_values
,
concat
=
True
)
.
flatten
(
0
,
1
)
),
flatten_bn
(
flatten_bn
(
pixel_values
)
,
concat
=
True
)),
)
)
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
...
...
vllm/model_executor/models/phimoe.py
View file @
ad58e9b3
...
@@ -600,7 +600,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
...
@@ -600,7 +600,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA):
weight_loader
(
weight_loader
(
param
,
param
,
loaded_weight
,
loaded_weight
,
weight_
name
,
name
,
shard_id
=
shard_id
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
expert_id
=
expert_id
,
)
)
...
...
vllm/model_executor/models/pixtral.py
View file @
ad58e9b3
import
math
from
array
import
array
from
array
import
array
from
dataclasses
import
dataclass
,
fields
from
dataclasses
import
dataclass
,
fields
from
itertools
import
tee
from
itertools
import
tee
...
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
...
@@ -15,11 +14,12 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
from
vllm.attention
import
AttentionMetadata
from
vllm.attention
import
AttentionMetadata
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.config
import
CacheConfig
,
MultiModalConfig
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
from
vllm.inputs
import
INPUT_REGISTRY
,
InputContext
,
LLMInputs
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.utils
import
merge_multimodal_embeddings
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
...
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
...
@@ -48,23 +48,29 @@ def dummy_data_for_pixtral(ctx: InputContext, seq_len: int,
tokenizer
=
cached_get_tokenizer
(
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
,
ctx
.
model_config
.
tokenizer
,
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
mm_encoder
=
tokenizer
.
instruct
.
mm_encoder
mm_config
=
ctx
.
model_config
.
multimodal_config
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
max_num_images_per_request
=
mm_config
.
limit_per_prompt
.
get
(
"image"
,
1
)
patch_size
=
mm_encoder
.
mm_config
.
image_patch_size
image_token_id
=
mm_encoder
.
special_ids
.
img
# approximate image size
mm_config
=
ctx
.
model_config
.
multimodal_config
size
=
int
(
math
.
sqrt
(
seq_len
)
*
mm_encoder
.
mm_config
.
im
age_patch_size
)
num_images
=
mm_config
.
l
im
it_per_prompt
.
get
(
"image"
,
1
)
# dummy size
size
=
256
image
=
Image
.
new
(
"RGB"
,
(
size
,
size
),
color
=
0
)
image
=
Image
.
new
(
"RGB"
,
(
size
,
size
),
color
=
0
)
img_chunk
=
ImageChunk
(
image
=
image
)
tokens
=
mm_encoder
(
img_chunk
).
tokens
image_feature_size
=
(
size
**
2
)
//
(
patch_size
**
2
)
token_ids
=
max_num_images_per_request
*
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
tokens
)
num_image_tokens
=
image_feature_size
*
num_images
token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
image_token_id
])
*
num_image_tokens
token_ids
+=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
num_image_tokens
)
seq_data
=
SequenceData
(
token_ids
)
seq_data
=
SequenceData
(
token_ids
)
mm_data
=
{
"image"
:
max_
num_images
_per_request
*
[
image
]}
mm_data
=
{
"image"
:
num_images
*
[
image
]}
return
seq_data
,
mm_data
return
seq_data
,
mm_data
...
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
...
@@ -99,32 +105,31 @@ def input_mapper_for_pixtral(ctx: InputContext,
return
MultiModalInputs
({
"images"
:
images
})
return
MultiModalInputs
({
"images"
:
images
})
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
def
input_processor_for_pixtral
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
inputs_embeds
:
torch
.
Tensor
,
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
image_features
:
Optional
[
List
[
torch
.
Tensor
]],
if
multi_modal_data
is
not
None
and
"image"
in
multi_modal_data
:
image_id
:
int
)
->
torch
.
Tensor
:
tokenizer
=
cached_get_tokenizer
(
text_locations
=
input_ids
!=
image_id
ctx
.
model_config
.
tokenizer
,
image_locations
=
input_ids
==
image_id
tokenizer_mode
=
ctx
.
model_config
.
tokenizer_mode
)
seq_len
=
input_ids
.
shape
[
0
]
N_txt
=
text_locations
.
sum
().
item
()
mm_encoder
=
tokenizer
.
mistral
.
instruct_tokenizer
.
mm_encoder
_
,
D_txt
=
inputs_embeds
.
shape
image_token_id
=
mm_encoder
.
special_ids
.
img
N_img
,
D_img
=
image_features
.
shape
assert
(
D_txt
==
D_img
),
(
f
"Text features dim
{
D_txt
}
should be equal "
if
image_token_id
not
in
llm_inputs
[
'prompt_token_ids'
]:
"to image features dim {D_img}"
)
raise
ValueError
(
assert
(
seq_len
==
N_txt
+
(
f
"You've passed
{
llm_inputs
=
}
without
{
image_token_id
=
}
"
N_img
),
(
f
"seq_len
{
seq_len
}
should be equal to N_txt + N_img "
" Make sure to process your input via mistral_common's"
f
"
{
(
N_txt
,
N_img
,
image_locations
.
sum
().
item
())
}
"
)
" tokenizer or pass a chat completion request. For more"
" For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
))
inputs_embeds
[
image_locations
,
:]
=
image_features
return
llm_inputs
return
inputs_embeds
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_pixtral
)
@
MULTIMODAL_REGISTRY
.
register_image_input_mapper
(
input_mapper_for_pixtral
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_pixtral_image_tokens
)
@
MULTIMODAL_REGISTRY
.
register_max_image_tokens
(
get_max_pixtral_image_tokens
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_pixtral
)
@
INPUT_REGISTRY
.
register_dummy_data
(
dummy_data_for_pixtral
)
@
INPUT_REGISTRY
.
register_input_processor
(
input_processor_for_pixtral
)
class
PixtralForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
class
PixtralForConditionalGeneration
(
nn
.
Module
,
SupportsMultiModal
):
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -201,11 +206,21 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal):
return
None
return
None
if
isinstance
(
images
,
torch
.
Tensor
):
if
isinstance
(
images
,
torch
.
Tensor
):
# always take last images
# if passed as batch take all images
images
=
[
images
[
-
1
][
i
]
for
i
in
range
(
images
.
size
(
1
))]
N
,
B
,
C
,
W
,
H
=
images
.
shape
images
=
images
.
reshape
(
N
*
B
,
C
,
W
,
H
)
images
=
[
images
[
i
]
for
i
in
range
(
images
.
size
(
0
))]
elif
isinstance
(
images
,
list
):
elif
isinstance
(
images
,
list
):
# always take last images
# if passed as list flatten lists of tensors
images
=
[
images
[
-
1
][
i
]
for
i
in
range
(
len
(
images
[
0
]))]
flatten_images
=
[]
for
imgs_per_req
in
images
:
imgs_per_req
=
[
imgs_per_req
[
i
]
for
i
in
range
(
imgs_per_req
.
size
(
0
))
]
if
isinstance
(
imgs_per_req
,
torch
.
Tensor
)
else
imgs_per_req
flatten_images
.
extend
(
imgs_per_req
)
images
=
flatten_images
return
images
return
images
...
...
vllm/model_executor/models/qwen.py
View file @
ad58e9b3
...
@@ -50,6 +50,7 @@ from vllm.multimodal.base import MultiModalInputs
...
@@ -50,6 +50,7 @@ from vllm.multimodal.base import MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
IntermediateTensors
,
SequenceData
)
SequenceData
)
from
vllm.utils
import
is_list_of
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -697,9 +698,12 @@ def input_processor_for_qwen(ctx: InputContext,
...
@@ -697,9 +698,12 @@ def input_processor_for_qwen(ctx: InputContext,
raise
ValueError
(
raise
ValueError
(
f
"Expected img embeds to be have 3 dimensions, got
{
num_dims
}
"
)
f
"Expected img embeds to be have 3 dimensions, got
{
num_dims
}
"
)
num_images
=
1
if
num_dims
==
2
else
image_data
.
shape
[
0
]
num_images
=
1
if
num_dims
==
2
else
image_data
.
shape
[
0
]
else
:
elif
isinstance
(
image_data
,
Image
.
Image
):
# TODO - handle multiple image inputs once the API is solidified
num_images
=
1
num_images
=
1
elif
is_list_of
(
image_data
,
Image
.
Image
):
num_images
=
len
(
image_data
)
else
:
raise
TypeError
(
f
"Invalid image type:
{
type
(
image_data
)
}
"
)
if
prompt
is
None
:
if
prompt
is
None
:
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
prompt
=
tokenizer
.
decode
(
prompt_token_ids
)
...
@@ -780,11 +784,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
...
@@ -780,11 +784,11 @@ def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs:
f
"[# images,
{
MAX_QWEN_IMG_TOKENS
}
,
{
img_emb_size
}
], but "
f
"[# images,
{
MAX_QWEN_IMG_TOKENS
}
,
{
img_emb_size
}
], but "
f
"received shape [
{
data
.
shape
}
]"
)
f
"received shape [
{
data
.
shape
}
]"
)
pixel_values
=
data
pixel_values
=
data
else
:
else
:
transform
=
build_normalization_transform
(
image_size
)
transform
=
build_normalization_transform
(
image_size
)
# TODO - handle multiple image inputs once the API is solidified
if
not
isinstance
(
data
,
(
list
,
tuple
)):
transformed_images
=
[
transform
(
data
)]
data
=
[
data
]
transformed_images
=
[
transform
(
datum
)
for
datum
in
data
]
pixel_values
=
torch
.
stack
(
transformed_images
,
dim
=
0
)
pixel_values
=
torch
.
stack
(
transformed_images
,
dim
=
0
)
return
MultiModalInputs
({
"pixel_values"
:
pixel_values
})
return
MultiModalInputs
({
"pixel_values"
:
pixel_values
})
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
ad58e9b3
...
@@ -1055,6 +1055,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1055,6 +1055,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
if
weight_name
not
in
name
:
if
weight_name
not
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
param
.
weight_loader
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
...
@@ -1078,6 +1081,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
...
@@ -1078,6 +1081,9 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
loaded_weight
=
loaded_weight
.
reshape
(
-
1
)
try
:
try
:
# Skip loading extra bias for GPTQ models.
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
except
KeyError
:
except
KeyError
:
print
(
params_dict
.
keys
())
print
(
params_dict
.
keys
())
...
...
vllm/outputs.py
View file @
ad58e9b3
...
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
...
@@ -5,6 +5,7 @@ from typing import Sequence as GenericSequence
from
typing
import
Union
from
typing
import
Union
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sampling_params
import
RequestOutputKind
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
from
vllm.sequence
import
(
PromptLogprobs
,
RequestMetrics
,
SampleLogprobs
,
SequenceGroup
,
SequenceStatus
)
SequenceGroup
,
SequenceStatus
)
...
@@ -92,7 +93,7 @@ class RequestOutput:
...
@@ -92,7 +93,7 @@ class RequestOutput:
self
,
self
,
request_id
:
str
,
request_id
:
str
,
prompt
:
Optional
[
str
],
prompt
:
Optional
[
str
],
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
Optional
[
List
[
int
]
]
,
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
outputs
:
List
[
CompletionOutput
],
outputs
:
List
[
CompletionOutput
],
finished
:
bool
,
finished
:
bool
,
...
@@ -113,19 +114,26 @@ class RequestOutput:
...
@@ -113,19 +114,26 @@ class RequestOutput:
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
self
.
encoder_prompt_token_ids
=
encoder_prompt_token_ids
@
classmethod
@
classmethod
def
from_seq_group
(
cls
,
seq_group
:
SequenceGroup
)
->
"RequestOutput"
:
def
from_seq_group
(
cls
,
if
seq_group
.
sampling_params
is
None
:
seq_group
:
SequenceGroup
)
->
Optional
[
"RequestOutput"
]:
sampling_params
=
seq_group
.
sampling_params
if
sampling_params
is
None
:
raise
ValueError
(
raise
ValueError
(
"Sampling parameters are missing for a CompletionRequest."
)
"Sampling parameters are missing for a CompletionRequest."
)
finished
=
seq_group
.
is_finished
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
and
(
not
finished
):
return
None
seqs
=
seq_group
.
get_seqs
()
seqs
=
seq_group
.
get_seqs
()
if
len
(
seqs
)
==
1
:
if
len
(
seqs
)
==
1
:
top_n_seqs
=
seqs
top_n_seqs
=
seqs
else
:
else
:
# Get the top-n sequences.
# Get the top-n sequences.
n
=
seq_group
.
sampling_params
.
n
n
=
sampling_params
.
n
if
seq_group
.
sampling_params
.
use_beam_search
:
if
sampling_params
.
use_beam_search
:
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
sorting_key
=
lambda
seq
:
seq
.
get_beam_search_score
(
seq_group
.
sampling_params
.
length_penalty
)
sampling_params
.
length_penalty
)
else
:
else
:
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorting_key
=
lambda
seq
:
seq
.
get_cumulative_logprob
()
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
sorted_seqs
=
sorted
(
seqs
,
key
=
sorting_key
,
reverse
=
True
)
...
@@ -135,26 +143,49 @@ class RequestOutput:
...
@@ -135,26 +143,49 @@ class RequestOutput:
# NOTE: We need omit logprobs here explicitly because the sequence
# NOTE: We need omit logprobs here explicitly because the sequence
# always has the logprobs of the sampled tokens even if the
# always has the logprobs of the sampled tokens even if the
# logprobs are not requested.
# logprobs are not requested.
include_logprobs
=
seq_group
.
sampling_params
.
logprobs
is
not
None
include_logprobs
=
sampling_params
.
logprobs
is
not
None
text_buffer_length
=
seq_group
.
sampling_params
.
output_text_buffer_length
text_buffer_length
=
sampling_params
.
output_text_buffer_length
outputs
=
[
delta
=
sampling_params
.
output_kind
==
RequestOutputKind
.
DELTA
CompletionOutput
(
seqs
.
index
(
seq
),
outputs
=
[]
seq
.
get_output_text_to_return
(
text_buffer_length
),
include_prompt
=
True
seq
.
data
.
_output_token_ids
,
for
seq
in
top_n_seqs
:
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_text
=
seq
.
get_output_text_to_return
(
seq
.
output_logprobs
if
include_logprobs
else
None
,
text_buffer_length
,
delta
)
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
output_token_ids
=
seq
.
get_output_token_ids_to_return
(
delta
)
seq
.
stop_reason
)
for
seq
in
top_n_seqs
output_logprobs
=
seq
.
output_logprobs
if
include_logprobs
else
None
]
if
delta
:
# Slice logprobs delta if applicable
if
output_logprobs
:
output_logprobs
=
output_logprobs
[
-
len
(
output_token_ids
):]
# Don't include prompt if this is after the first output
# containing decode token ids
if
include_prompt
and
seq
.
get_output_len
()
>
len
(
output_token_ids
):
include_prompt
=
False
outputs
.
append
(
CompletionOutput
(
seqs
.
index
(
seq
),
output_text
,
output_token_ids
,
seq
.
get_cumulative_logprob
()
if
include_logprobs
else
None
,
output_logprobs
,
SequenceStatus
.
get_finished_reason
(
seq
.
status
),
seq
.
stop_reason
))
# Every sequence in the sequence group should have the same prompt.
# Every sequence in the sequence group should have the same prompt.
prompt
=
seq_group
.
prompt
if
include_prompt
:
prompt_token_ids
=
seq_group
.
prompt_token_ids
prompt
=
seq_group
.
prompt
encoder_prompt
=
seq_group
.
encoder_prompt
prompt_token_ids
=
seq_group
.
prompt_token_ids
encoder_prompt_token_ids
=
seq_group
.
encoder_prompt_token_ids
encoder_prompt
=
seq_group
.
encoder_prompt
prompt_logprobs
=
seq_group
.
prompt_logprobs
encoder_prompt_token_ids
=
seq_group
.
encoder_prompt_token_ids
finished
=
seq_group
.
is_finished
()
prompt_logprobs
=
seq_group
.
prompt_logprobs
else
:
prompt
=
None
prompt_token_ids
=
None
encoder_prompt
=
None
encoder_prompt_token_ids
=
None
prompt_logprobs
=
None
finished_time
=
time
.
time
()
if
finished
else
None
finished_time
=
time
.
time
()
if
finished
else
None
seq_group
.
set_finished_time
(
finished_time
)
seq_group
.
set_finished_time
(
finished_time
)
return
cls
(
seq_group
.
request_id
,
return
cls
(
seq_group
.
request_id
,
...
...
vllm/plugins/__init__.py
View file @
ad58e9b3
import
logging
import
logging
from
typing
import
Callable
,
Optional
,
Union
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -29,3 +30,15 @@ def load_general_plugins():
...
@@ -29,3 +30,15 @@ def load_general_plugins():
except
Exception
:
except
Exception
:
logger
.
exception
(
"Failed to load general plugin: %s"
,
logger
.
exception
(
"Failed to load general plugin: %s"
,
plugin
.
name
)
plugin
.
name
)
_torch_compile_backend
:
Optional
[
Union
[
Callable
,
str
]]
=
None
def
set_torch_compile_backend
(
backend
:
Union
[
Callable
,
str
]):
global
_torch_compile_backend
_torch_compile_backend
=
backend
def
get_torch_compile_backend
()
->
Optional
[
Union
[
Callable
,
str
]]:
return
_torch_compile_backend
vllm/sampling_params.py
View file @
ad58e9b3
"""Sampling parameters for text generation."""
"""Sampling parameters for text generation."""
import
copy
import
copy
from
enum
import
IntEnum
from
enum
import
Enum
,
IntEnum
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Set
,
Union
...
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
...
@@ -33,6 +33,15 @@ first argument, and returns a modified tensor of logits
to sample from."""
to sample from."""
class
RequestOutputKind
(
Enum
):
# Return entire output so far in every RequestOutput
CUMULATIVE
=
0
# Return only deltas in each RequestOutput
DELTA
=
1
# Do not return intermediate RequestOuputs
FINAL_ONLY
=
2
class
SamplingParams
(
class
SamplingParams
(
msgspec
.
Struct
,
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
...
@@ -147,6 +156,7 @@ class SamplingParams(
...
@@ -147,6 +156,7 @@ class SamplingParams(
logits_processors
:
Optional
[
Any
]
=
None
logits_processors
:
Optional
[
Any
]
=
None
include_stop_str_in_output
:
bool
=
False
include_stop_str_in_output
:
bool
=
False
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
# The below fields are not supposed to be used as an input.
# The below fields are not supposed to be used as an input.
# They are set in post_init.
# They are set in post_init.
...
@@ -182,6 +192,7 @@ class SamplingParams(
...
@@ -182,6 +192,7 @@ class SamplingParams(
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
logits_processors
:
Optional
[
List
[
LogitsProcessor
]]
=
None
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
truncate_prompt_tokens
:
Optional
[
Annotated
[
int
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
msgspec
.
Meta
(
ge
=
1
)]]
=
None
,
output_kind
:
RequestOutputKind
=
RequestOutputKind
.
CUMULATIVE
,
)
->
"SamplingParams"
:
)
->
"SamplingParams"
:
return
SamplingParams
(
return
SamplingParams
(
n
=
1
if
n
is
None
else
n
,
n
=
1
if
n
is
None
else
n
,
...
@@ -213,6 +224,7 @@ class SamplingParams(
...
@@ -213,6 +224,7 @@ class SamplingParams(
spaces_between_special_tokens
=
spaces_between_special_tokens
,
spaces_between_special_tokens
=
spaces_between_special_tokens
,
logits_processors
=
logits_processors
,
logits_processors
=
logits_processors
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
truncate_prompt_tokens
=
truncate_prompt_tokens
,
output_kind
=
output_kind
,
)
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
...
@@ -317,6 +329,9 @@ class SamplingParams(
...
@@ -317,6 +329,9 @@ class SamplingParams(
raise
ValueError
(
raise
ValueError
(
"stop strings are only supported when detokenize is True. "
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop."
)
"Set detokenize=True to use stop."
)
if
self
.
best_of
!=
self
.
n
and
self
.
output_kind
==
(
RequestOutputKind
.
DELTA
):
raise
ValueError
(
"best_of must equal n to use output_kind=DELTA"
)
def
_verify_beam_search
(
self
)
->
None
:
def
_verify_beam_search
(
self
)
->
None
:
if
self
.
best_of
==
1
:
if
self
.
best_of
==
1
:
...
...
vllm/sequence.py
View file @
ad58e9b3
...
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
...
@@ -5,8 +5,9 @@ from abc import ABC, abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
Optional
,
Set
,
Tuple
,
Union
,
cast
)
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
,
cast
import
msgspec
import
msgspec
import
torch
import
torch
...
@@ -407,6 +408,10 @@ class Sequence:
...
@@ -407,6 +408,10 @@ class Sequence:
self
.
status
=
SequenceStatus
.
WAITING
self
.
status
=
SequenceStatus
.
WAITING
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
self
.
stop_reason
:
Union
[
int
,
str
,
None
]
=
None
# These are used to keep track of delta outputs
self
.
_last_token_ids_offset
:
int
=
0
self
.
_last_output_text_offset
:
int
=
0
# Used for incremental detokenization
# Used for incremental detokenization
self
.
prefix_offset
=
0
self
.
prefix_offset
=
0
self
.
read_offset
=
0
self
.
read_offset
=
0
...
@@ -462,11 +467,37 @@ class Sequence:
...
@@ -462,11 +467,37 @@ class Sequence:
return
self
.
prompt_adapter_request
.
prompt_adapter_id
\
return
self
.
prompt_adapter_request
.
prompt_adapter_id
\
if
self
.
prompt_adapter_request
else
0
if
self
.
prompt_adapter_request
else
0
def
get_output_text_to_return
(
self
,
buffer_length
:
int
):
def
get_output_text_to_return
(
self
,
buffer_length
:
int
,
delta
:
bool
)
->
str
:
"""If delta is True, only new text since the last call to
this method is returned"""
# We return the full output text if the sequence is finished.
# We return the full output text if the sequence is finished.
truncate
=
buffer_length
and
not
self
.
is_finished
()
truncate
=
buffer_length
and
not
self
.
is_finished
()
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
if
not
delta
:
self
.
output_text
)
return
self
.
output_text
[:
-
buffer_length
]
if
truncate
else
(
self
.
output_text
)
length
=
len
(
self
.
output_text
)
if
truncate
:
length
-=
buffer_length
last_offset
=
self
.
_last_output_text_offset
if
last_offset
<
length
:
self
.
_last_output_text_offset
=
length
return
self
.
output_text
[
last_offset
:
length
]
return
""
def
get_output_token_ids_to_return
(
self
,
delta
:
bool
)
->
GenericSequence
[
int
]:
"""If delta is True, only new tokens since the last call to
this method are returned"""
if
not
delta
:
return
self
.
get_output_token_ids
()
length
=
self
.
get_output_len
()
last_offset
=
self
.
_last_token_ids_offset
if
last_offset
<
length
:
self
.
_last_token_ids_offset
=
length
return
self
.
data
.
_output_token_ids
[
last_offset
:]
return
()
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
def
hash_of_block
(
self
,
logical_idx
:
int
)
->
int
:
# TODO This can produce incorrect hash when block size > prompt size
# TODO This can produce incorrect hash when block size > prompt size
...
...
vllm/transformers_utils/config.py
View file @
ad58e9b3
...
@@ -4,7 +4,9 @@ import json
...
@@ -4,7 +4,9 @@ import json
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Type
,
Union
from
huggingface_hub
import
file_exists
,
hf_hub_download
import
huggingface_hub
from
huggingface_hub
import
(
file_exists
,
hf_hub_download
,
try_to_load_from_cache
)
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers.models.auto.image_processing_auto
import
(
from
transformers.models.auto.image_processing_auto
import
(
get_image_processor_config
)
get_image_processor_config
)
...
@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
...
@@ -70,7 +72,22 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
if
Path
(
model
).
exists
():
if
Path
(
model
).
exists
():
return
(
Path
(
model
)
/
config_name
).
is_file
()
return
(
Path
(
model
)
/
config_name
).
is_file
()
return
file_exists
(
model
,
config_name
,
revision
=
revision
,
token
=
token
)
# Offline mode support: Check if config file is cached already
cached_filepath
=
try_to_load_from_cache
(
repo_id
=
model
,
filename
=
config_name
,
revision
=
revision
)
if
isinstance
(
cached_filepath
,
str
):
# The config file exists in cache- we can continue trying to load
return
True
# NB: file_exists will only check for the existence of the config file on
# hf_hub. This will fail in offline mode.
try
:
return
file_exists
(
model
,
config_name
,
revision
=
revision
,
token
=
token
)
except
huggingface_hub
.
errors
.
OfflineModeIsEnabled
:
# Don't raise in offline mode, all we know is that we don't have this
# file cached.
return
False
def
get_config
(
def
get_config
(
...
@@ -102,6 +119,15 @@ def get_config(
...
@@ -102,6 +119,15 @@ def get_config(
token
=
kwargs
.
get
(
"token"
)):
token
=
kwargs
.
get
(
"token"
)):
config_format
=
ConfigFormat
.
MISTRAL
config_format
=
ConfigFormat
.
MISTRAL
else
:
else
:
# If we're in offline mode and found no valid config format, then
# raise an offline mode error to indicate to the user that they
# don't have files cached and may need to go online.
# This is conveniently triggered by calling file_exists().
file_exists
(
model
,
HF_CONFIG_NAME
,
revision
=
revision
,
token
=
kwargs
.
get
(
"token"
))
raise
ValueError
(
f
"No supported config format found in
{
model
}
"
)
raise
ValueError
(
f
"No supported config format found in
{
model
}
"
)
if
config_format
==
ConfigFormat
.
HF
:
if
config_format
==
ConfigFormat
.
HF
:
...
@@ -206,6 +232,8 @@ def load_params_config(model, revision) -> PretrainedConfig:
...
@@ -206,6 +232,8 @@ def load_params_config(model, revision) -> PretrainedConfig:
config_dict
[
"tie_word_embeddings"
]
=
config_dict
.
get
(
config_dict
[
"tie_word_embeddings"
]
=
config_dict
.
get
(
"tie_embeddings"
,
False
)
"tie_embeddings"
,
False
)
config_dict
[
"max_seq_len"
]
=
config_dict
.
get
(
"max_seq_len"
,
128_000
)
config_dict
[
"max_seq_len"
]
=
config_dict
.
get
(
"max_seq_len"
,
128_000
)
config_dict
[
"max_position_embeddings"
]
=
config_dict
.
get
(
"max_position_embeddings"
,
128_000
)
if
config_dict
.
get
(
"moe"
)
is
not
None
:
if
config_dict
.
get
(
"moe"
)
is
not
None
:
config_dict
[
"architectures"
]
=
[
"MixtralForCausalLM"
]
config_dict
[
"architectures"
]
=
[
"MixtralForCausalLM"
]
...
...
vllm/utils.py
View file @
ad58e9b3
...
@@ -82,6 +82,9 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
...
@@ -82,6 +82,9 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
"currently supported with encoder/"
"currently supported with encoder/"
"decoder models."
)
"decoder models."
)
STR_NOT_IMPL_ENC_DEC_CPU
=
(
"CPU is not currently supported with "
"encoder/decoder models."
)
# Efficiently import all enc/dec error strings
# Efficiently import all enc/dec error strings
# rather than having to import all of the above
# rather than having to import all of the above
STR_NOT_IMPL_ENC_DEC_ERR_STRS
=
{
STR_NOT_IMPL_ENC_DEC_ERR_STRS
=
{
...
@@ -97,6 +100,7 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
...
@@ -97,6 +100,7 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH"
:
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
"STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH"
:
STR_NOT_IMPL_ENC_DEC_CUDAGRAPH
,
"STR_NOT_IMPL_ENC_DEC_BACKEND"
:
STR_NOT_IMPL_ENC_DEC_BACKEND
,
"STR_NOT_IMPL_ENC_DEC_BACKEND"
:
STR_NOT_IMPL_ENC_DEC_BACKEND
,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER"
:
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
,
"STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER"
:
STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER
,
"STR_NOT_IMPL_ENC_DEC_CPU"
:
STR_NOT_IMPL_ENC_DEC_CPU
}
}
# Constants related to forcing the attention backend selection
# Constants related to forcing the attention backend selection
...
...
vllm/version.py
View file @
ad58e9b3
...
@@ -2,6 +2,7 @@ import warnings
...
@@ -2,6 +2,7 @@ import warnings
try
:
try
:
import
vllm.commit_id
import
vllm.commit_id
__commit__
=
vllm
.
commit_id
.
__commit__
__commit__
=
vllm
.
commit_id
.
__commit__
except
Exception
as
e
:
except
Exception
as
e
:
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
warnings
.
warn
(
f
"Failed to read commit hash:
\n
{
e
}
"
,
...
@@ -9,4 +10,4 @@ except Exception as e:
...
@@ -9,4 +10,4 @@ except Exception as e:
stacklevel
=
2
)
stacklevel
=
2
)
__commit__
=
"COMMIT_HASH_PLACEHOLDER"
__commit__
=
"COMMIT_HASH_PLACEHOLDER"
__version__
=
"0.6.1"
__version__
=
"0.6.1
.post2
"
vllm/worker/cpu_model_runner.py
View file @
ad58e9b3
...
@@ -15,7 +15,7 @@ from vllm.model_executor.model_loader import get_model
...
@@ -15,7 +15,7 @@ from vllm.model_executor.model_loader import get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalInputs
)
MultiModalInputs
)
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.sequence
import
IntermediateTensors
,
SequenceGroupMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.utils
import
STR_NOT_IMPL_ENC_DEC_ERR_STRS
,
make_tensor_with_pad
from
vllm.worker.model_runner_base
import
(
from
vllm.worker.model_runner_base
import
(
ModelRunnerBase
,
ModelRunnerInputBase
,
ModelRunnerBase
,
ModelRunnerInputBase
,
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
...
@@ -121,6 +121,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
...
@@ -121,6 +121,10 @@ class CPUModelRunner(ModelRunnerBase[CPUModelInput]):
# Lazy initialization.
# Lazy initialization.
self
.
model
:
nn
.
Module
# Set after init_Model
self
.
model
:
nn
.
Module
# Set after init_Model
if
self
.
model_config
.
is_encoder_decoder_model
:
raise
NotImplementedError
(
STR_NOT_IMPL_ENC_DEC_ERR_STRS
[
'STR_NOT_IMPL_ENC_DEC_CPU'
])
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
self
.
model
=
get_model
(
model_config
=
self
.
model_config
,
load_config
=
self
.
load_config
,
load_config
=
self
.
load_config
,
...
...
vllm/worker/model_runner.py
View file @
ad58e9b3
...
@@ -53,7 +53,7 @@ from vllm.worker.model_runner_base import (
...
@@ -53,7 +53,7 @@ from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict
,
_add_attn_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_add_sampling_metadata_broadcastable_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_attn_metadata_from_tensor_dict
,
_init_sampling_metadata_from_tensor_dict
)
_init_sampling_metadata_from_tensor_dict
,
dump_input_when_exception
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.attention.backends.abstract
import
AttentionBackend
...
@@ -1064,10 +1064,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1064,10 +1064,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"This may lead to less accurate results!"
)
"This may lead to less accurate results!"
)
if
envs
.
VLLM_TEST_DYNAMO_GRAPH_CAPTURE
and
supports_dynamo
():
if
envs
.
VLLM_TEST_DYNAMO_GRAPH_CAPTURE
and
supports_dynamo
():
from
vllm.plugins
import
get_torch_compile_backend
backend
=
get_torch_compile_backend
()
or
"eager"
self
.
model
=
torch
.
compile
(
self
.
model
=
torch
.
compile
(
self
.
model
,
self
.
model
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
backend
=
"eager"
)
backend
=
backend
)
def
save_sharded_state
(
def
save_sharded_state
(
self
,
self
,
...
@@ -1489,6 +1491,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1489,6 +1491,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
virtual_engine
=
virtual_engine
)
virtual_engine
=
virtual_engine
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
dump_input_when_exception
(
exclude_args
=
[
0
],
exclude_kwargs
=
[
"self"
])
def
execute_model
(
def
execute_model
(
self
,
self
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
model_input
:
ModelInputForGPUWithSamplingMetadata
,
...
...
vllm/worker/model_runner_base.py
View file @
ad58e9b3
import
dataclasses
import
dataclasses
import
pickle
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
datetime
import
datetime
from
functools
import
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Type
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
Generic
,
List
,
Optional
,
Type
,
TypeVar
)
TypeVar
)
...
@@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict(
...
@@ -98,6 +101,37 @@ def _init_frozen_model_input_from_tensor_dict(
return
tensor_dict
return
tensor_dict
def
dump_input_when_exception
(
exclude_args
:
Optional
[
List
[
int
]]
=
None
,
exclude_kwargs
:
Optional
[
List
[
str
]]
=
None
):
def
_inner
(
func
):
@
wraps
(
func
)
def
_wrapper
(
*
args
,
**
kwargs
):
try
:
return
func
(
*
args
,
**
kwargs
)
except
Exception
as
err
:
timestamp
=
datetime
.
now
().
strftime
(
"%Y%m%d-%H%M%S"
)
filename
=
f
"/tmp/err_
{
func
.
__name__
}
_input_
{
timestamp
}
.pkl"
with
open
(
filename
,
"wb"
)
as
filep
:
dumped_inputs
=
{
k
:
v
for
k
,
v
in
kwargs
.
items
()
if
k
not
in
(
exclude_kwargs
or
[])
}
for
i
,
arg
in
enumerate
(
args
):
if
i
not
in
(
exclude_args
or
[]):
dumped_inputs
[
f
"arg_
{
i
}
"
]
=
arg
pickle
.
dump
(
dumped_inputs
,
filep
)
raise
type
(
err
)(
f
"Error in model execution (input dumped to
{
filename
}
): "
f
"
{
str
(
err
)
}
"
)
from
err
return
_wrapper
return
_inner
class
BroadcastableModelInput
(
ABC
):
class
BroadcastableModelInput
(
ABC
):
@
abstractmethod
@
abstractmethod
...
...
vllm/worker/multi_step_model_runner.py
View file @
ad58e9b3
...
@@ -4,13 +4,6 @@ from dataclasses import dataclass, field
...
@@ -4,13 +4,6 @@ from dataclasses import dataclass, field
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
)
Union
)
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
except
ModuleNotFoundError
:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from
vllm.attention.backends.rocm_flash_attn
import
(
ROCmFlashAttentionMetadata
as
FlashAttentionMetadata
)
import
torch
import
torch
from
vllm.distributed
import
get_pp_group
from
vllm.distributed
import
get_pp_group
...
@@ -36,6 +29,8 @@ if TYPE_CHECKING:
...
@@ -36,6 +29,8 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
MULTI_STEP_ATTENTION_BACKENDS
=
[
"flash-attn"
,
"flashinfer"
]
def
seq_output_builder
():
def
seq_output_builder
():
return
SequenceOutput
(
return
SequenceOutput
(
...
@@ -230,12 +225,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -230,12 +225,15 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
self
.
_base_model_runner
:
GPUModelRunnerBase
=
base_model_runner
self
.
_base_model_runner
:
GPUModelRunnerBase
=
base_model_runner
self
.
is_multi_step
=
self
.
scheduler_config
.
is_multi_step
self
.
is_multi_step
=
self
.
scheduler_config
.
is_multi_step
# used to copy tensors from GPU to CPU asynchronously
self
.
_copy_stream
=
torch
.
cuda
.
Stream
()
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pinned_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
self
.
pythonization_cache
=
PythonizationCache
()
self
.
pythonization_cache
=
PythonizationCache
()
@
functools
.
cached_property
def
_copy_stream
(
self
):
# used to copy tensors from GPU to CPU asynchronously
return
torch
.
cuda
.
Stream
()
def
make_model_input_from_broadcasted_tensor_dict
(
def
make_model_input_from_broadcasted_tensor_dict
(
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
self
,
tensor_dict
:
Dict
[
str
,
Any
])
->
StatefulModelInput
:
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
model_input
=
(
StatefulModelInput
.
from_broadcasted_tensor_dict
(
...
@@ -486,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
...
@@ -486,27 +484,27 @@ class MultiStepModelRunner(GPUModelRunnerBase[StatefulModelInput]):
def
_advance_step
(
self
,
model_input
:
StatefulModelInput
,
def
_advance_step
(
self
,
model_input
:
StatefulModelInput
,
out
:
SamplerOutput
)
->
StatefulModelInput
:
out
:
SamplerOutput
)
->
StatefulModelInput
:
frozen_model_input
=
model_input
.
frozen_model_input
if
self
.
attn_backend
.
get_name
()
not
in
MULTI_STEP_ATTENTION_BACKENDS
:
assert
frozen_model_input
is
not
None
raise
ValueError
(
assert
frozen_model_input
.
attn_metadata
is
not
None
f
"Multi-step not supported for attention backend: "
f
"
{
self
.
attn_backend
.
get_name
()
}
. Set VLLM_ATTENTION_BACKEND "
f
"to a value from
{
MULTI_STEP_ATTENTION_BACKENDS
}
."
)
sampled_token_ids
=
model_input
.
cached_outputs
[
-
1
].
sampled_token_ids
num_seqs
=
model_input
.
num_seqs
num_seqs
=
model_input
.
num_seqs
num_queries
=
model_input
.
num_queries
num_queries
=
model_input
.
num_queries
assert
num_seqs
>
0
frozen_model_input
=
model_input
.
frozen_model_input
assert
num_queries
>
0
assert
frozen_model_input
is
not
None
assert
num_seqs
>=
num_queries
attn_metadata
=
frozen_model_input
.
attn_metadata
attn_metadata
=
frozen_model_input
.
attn_metadata
assert
isinstance
(
attn_metadata
,
FlashAttentionMetadata
)
assert
attn_metadata
is
not
None
attn_metadata
.
advance_step
(
attn_metadata
.
advance_step
(
frozen_model_input
,
frozen_model_input
,
model_input
.
cached_outputs
[
-
1
].
sampled_token_ids
,
self
.
block_size
,
sampled_token_ids
,
num_seqs
,
num_queries
)
self
.
block_size
,
num_seqs
,
if
frozen_model_input
.
seq_lens
is
not
None
:
num_queries
,
for
i
in
range
(
num_queries
):
)
frozen_model_input
.
seq_lens
[
i
]
=
attn_metadata
.
seq_lens
[
i
]
return
model_input
return
model_input
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment