Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0640f227
Commit
0640f227
authored
Sep 09, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.0' into v0.6.0-dev
parents
82f1ffdf
32e7db25
Changes
335
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
471 additions
and
485 deletions
+471
-485
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+2
-2
vllm/model_executor/models/qwen2_moe.py
vllm/model_executor/models/qwen2_moe.py
+2
-2
vllm/model_executor/models/siglip.py
vllm/model_executor/models/siglip.py
+73
-182
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+2
-2
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+2
-2
vllm/model_executor/models/ultravox.py
vllm/model_executor/models/ultravox.py
+94
-58
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+78
-26
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+2
-2
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+3
-2
vllm/multimodal/__init__.py
vllm/multimodal/__init__.py
+1
-2
vllm/multimodal/base.py
vllm/multimodal/base.py
+28
-32
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+59
-22
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+6
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+11
-0
vllm/scripts.py
vllm/scripts.py
+9
-0
vllm/sequence.py
vllm/sequence.py
+12
-79
vllm/spec_decode/batch_expansion.py
vllm/spec_decode/batch_expansion.py
+81
-65
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+2
-2
vllm/spec_decode/medusa_worker.py
vllm/spec_decode/medusa_worker.py
+2
-2
vllm/spec_decode/mlp_speculator_worker.py
vllm/spec_decode/mlp_speculator_worker.py
+2
-2
No files found.
vllm/model_executor/models/qwen2.py
View file @
0640f227
...
@@ -42,13 +42,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -42,13 +42,13 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
is_pp_missing_parameter
,
make_layers
...
...
vllm/model_executor/models/qwen2_moe.py
View file @
0640f227
...
@@ -45,12 +45,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -45,12 +45,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
is_pp_missing_parameter
,
make_layers
...
...
vllm/model_executor/models/siglip.py
View file @
0640f227
...
@@ -3,18 +3,16 @@ within a vision language model."""
...
@@ -3,18 +3,16 @@ within a vision language model."""
import
math
import
math
from
array
import
array
from
array
import
array
from
typing
import
Iterable
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
PIL
import
Image
from
PIL
import
Image
from
torch
import
nn
from
torch
import
nn
from
transformers
import
SiglipVisionConfig
from
transformers
import
SiglipVisionConfig
from
transformers.models.siglip.modeling_siglip
import
SiglipAttention
from
transformers.models.siglip.modeling_siglip
import
SiglipSdpaAttention
from
vllm_flash_attn
import
flash_attn_func
from
xformers.ops
import
memory_efficient_attention
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
divide
,
get_tensor_model_parallel_world_size
from
vllm.inputs
import
LLMInputs
from
vllm.inputs
import
LLMInputs
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.activation
import
get_act_fn
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
...
@@ -28,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
...
@@ -28,6 +26,12 @@ from vllm.multimodal.utils import (cached_get_tokenizer,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
try
:
from
xformers
import
ops
as
xops
USE_XFORMERS_OPS
=
True
except
ImportError
:
USE_XFORMERS_OPS
=
False
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
def
get_siglip_patch_grid_length
(
*
,
image_size
:
int
,
patch_size
:
int
)
->
int
:
# Since interpolation is applied, the image size need not be divisible
# Since interpolation is applied, the image size need not be divisible
...
@@ -93,7 +97,7 @@ def input_processor_for_siglip(
...
@@ -93,7 +97,7 @@ def input_processor_for_siglip(
llm_inputs
:
LLMInputs
,
llm_inputs
:
LLMInputs
,
*
,
*
,
image_token_id
:
int
,
image_token_id
:
int
,
image_feature_size_override
:
Optional
[
int
]
=
None
,
image_feature_size_override
:
Optional
[
Union
[
int
,
List
[
int
]]
]
=
None
,
):
):
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
multi_modal_data
=
llm_inputs
.
get
(
"multi_modal_data"
)
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
if
multi_modal_data
is
None
or
"image"
not
in
multi_modal_data
:
...
@@ -221,9 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
...
@@ -221,9 +225,7 @@ class SiglipVisionEmbeddings(nn.Module):
return
embeddings
return
embeddings
# NOTE: Not used - kept for later when we TP the ViT
class
SiglipParallelAttention
(
nn
.
Module
):
# TODO(ChristopherCho): Implement TP version of Attention
class
SiglipTPAttention
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -233,38 +235,30 @@ class SiglipTPAttention(nn.Module):
...
@@ -233,38 +235,30 @@ class SiglipTPAttention(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
self
.
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
head_dim
=
self
.
embed_dim
//
self
.
num_heads
self
.
total_num_heads
=
config
.
num_attention_heads
if
self
.
head_dim
*
self
.
num_heads
!=
self
.
embed_dim
:
if
self
.
total_num_heads
%
tp_size
!=
0
:
raise
ValueError
(
f
"Number of attention heads (
{
self
.
total_num_heads
}
) "
"must be divisible by the tensor model parallel size"
f
" (
{
tp_size
}
)."
)
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
head_dim
=
self
.
embed_dim
//
self
.
total_num_heads
if
self
.
head_dim
*
self
.
total_num_heads
!=
self
.
embed_dim
:
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got "
raise
ValueError
(
f
"embed_dim must be divisible by num_heads (got "
"`embed_dim`: {self.embed_dim} and `num_heads`:"
"`embed_dim`: {self.embed_dim} and `num_heads`:"
f
"
{
self
.
num_heads
}
)."
)
f
"
{
self
.
num_heads
}
)."
)
self
.
qkv_size
=
self
.
num_heads
*
self
.
head_dim
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
scale
=
self
.
head_dim
**-
0.5
self
.
dropout
=
config
.
attention_dropout
self
.
dropout
=
config
.
attention_dropout
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
qkv_proj
=
QKVParallelLinear
(
hidden_size
=
self
.
embed_dim
,
hidden_size
=
self
.
embed_dim
,
head_size
=
self
.
head_dim
,
head_size
=
self
.
head_dim
,
total_num_heads
=
self
.
total_
num_heads
,
total_num_heads
=
self
.
num_heads
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
self
.
out_proj
=
RowParallelLinear
(
self
.
out_proj
=
RowParallelLinear
(
input_size
=
self
.
embed_dim
,
input_size
=
self
.
embed_dim
,
output_size
=
self
.
embed_dim
,
output_size
=
self
.
embed_dim
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
self
.
attn_fn
=
self
.
_basic_attention_forward
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
num_heads_per_partition
=
divide
(
self
.
num_heads
,
self
.
tp_size
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -274,161 +268,27 @@ class SiglipTPAttention(nn.Module):
...
@@ -274,161 +268,27 @@ class SiglipTPAttention(nn.Module):
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
batch_size
,
q_len
,
_
=
hidden_states
.
size
()
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
qkv_states
,
_
=
self
.
qkv_proj
(
hidden_states
)
query_states
,
key_states
,
value_states
=
qkv_states
.
split
(
query_states
,
key_states
,
value_states
=
qkv_states
.
chunk
(
3
,
dim
=-
1
)
[
self
.
qkv_size
]
*
3
,
dim
=-
1
)
attn_output
=
self
.
attn_fn
(
q
=
query_states
,
k
=
key_states
,
v
=
value_states
,
batch_size
=
batch_size
,
q_len
=
q_len
,
)
attn_output
,
_
=
self
.
out_proj
(
attn_output
)
return
attn_output
def
_basic_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
):
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
).
transpose
(
1
,
2
)
k_v_seq_len
=
k
.
shape
[
-
2
]
attn_weights
=
torch
.
matmul
(
q
,
k
.
transpose
(
2
,
3
))
*
self
.
scale
if
attn_weights
.
size
()
!=
(
batch_size
,
self
.
num_heads
,
q_len
,
k_v_seq_len
,
):
raise
ValueError
(
"Attention weights should be of size "
f
"
{
(
batch_size
,
self
.
num_heads
,
q_len
,
k_v_seq_len
)
}
, but is"
f
"
{
attn_weights
.
size
()
}
"
)
# upcast attention to fp32
attn_weights
=
nn
.
functional
.
softmax
(
attn_weights
,
dim
=-
1
,
dtype
=
torch
.
float32
).
to
(
q
.
dtype
)
attn_weights
=
nn
.
functional
.
dropout
(
attn_weights
,
p
=
self
.
dropout
,
training
=
self
.
training
)
attn_output
=
torch
.
matmul
(
attn_weights
,
v
)
if
attn_output
.
size
()
!=
(
batch_size
,
self
.
num_heads
,
q_len
,
self
.
head_dim
,
):
raise
ValueError
(
"`attn_output` should be of size "
f
"
{
(
batch_size
,
self
.
num_heads
,
q_len
,
self
.
head_dim
)
}
, but is"
f
"
{
attn_output
.
size
()
}
"
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
)
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
# It constantly throws a CUDA error.
class
SiglipFlashAttention2
(
SiglipTPAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
attn_fn
=
self
.
_flash_attention_forward
# Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
# and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
def
_flash_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
,
*
args
,
**
kwargs
):
"""Implements the multihead softmax attention.
Arguments
---------
q, k, v: The tensor containing the
query, key, and value. (B, S, H, D)
"""
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
attn_output
=
flash_attn_func
(
q
,
k
,
v
,
dropout_p
=
self
.
dropout
,
causal
=
False
,
)
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
).
contiguous
()
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
class
SiglipSdpaAttention
(
SiglipTPAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
is_causal
=
False
self
.
attn_fn
=
self
.
_sdpa_attention_forward
def
_sdpa_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
):
query_states
=
query_states
.
view
(
batch_size
,
q_len
,
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
num_heads_per_partition
,
self
.
head_dim
).
transpose
(
1
,
2
)
self
.
head_dim
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
key_states
=
key_states
.
view
(
batch_size
,
q_len
,
self
.
head_dim
).
transpose
(
1
,
2
)
self
.
num_heads_per_partition
,
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
self
.
head_dim
).
transpose
(
1
,
2
)
value_states
=
value_states
.
view
(
batch_size
,
q_len
,
self
.
num_heads_per_partition
,
self
.
head_dim
)
attn_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
out
=
xops
.
memory_efficient_attention_forward
(
query_states
,
q
,
k
,
v
,
dropout_p
=
self
.
dropout
,
is_causal
=
False
,
scale
=
self
.
scale
)
key_states
,
value_states
,
p
=
self
.
dropout
,
scale
=
self
.
scale
)
out
=
out
.
view
(
batch_size
,
q_len
,
-
1
)
attn_output
,
_
=
self
.
out_proj
(
out
)
attn_output
=
attn_output
.
transpose
(
1
,
2
).
contiguous
()
return
attn_output
,
None
attn_output
=
attn_output
.
view
(
batch_size
,
q_len
,
self
.
embed_dim
)
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
class
SiglipxFormersAttention
(
SiglipTPAttention
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
attn_fn
=
self
.
_xformers_attention_forward
def
_xformers_attention_forward
(
self
,
q
,
k
,
v
,
batch_size
,
q_len
):
q
=
q
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
k
=
k
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
v
=
v
.
view
(
batch_size
,
q_len
,
self
.
num_heads
,
self
.
head_dim
)
attn_output
=
memory_efficient_attention
(
q
,
k
,
v
,
p
=
0.0
,
scale
=
self
.
scale
)
attn_output
=
attn_output
.
reshape
(
batch_size
,
q_len
,
self
.
embed_dim
).
contiguous
()
return
attn_output
# NOTE: Not used - kept for later when we TP the ViT
SIGLIP_ATTENTION_CLASSES
=
{
"eager"
:
SiglipTPAttention
,
"flash_attention_2"
:
SiglipFlashAttention2
,
"sdpa"
:
SiglipSdpaAttention
,
"xformers"
:
SiglipxFormersAttention
,
}
class
SiglipMLP
(
nn
.
Module
):
class
SiglipMLP
(
nn
.
Module
):
...
@@ -473,8 +333,14 @@ class SiglipEncoderLayer(nn.Module):
...
@@ -473,8 +333,14 @@ class SiglipEncoderLayer(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
embed_dim
=
config
.
hidden_size
self
.
embed_dim
=
config
.
hidden_size
# TODO(ChristopherCho): use TP'ed Attention block
num_heads
=
config
.
num_attention_heads
self
.
self_attn
=
SiglipAttention
(
config
)
tp_size
=
get_tensor_model_parallel_world_size
()
if
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
:
self
.
self_attn
=
SiglipParallelAttention
(
config
,
quant_config
=
quant_config
)
else
:
self
.
self_attn
=
SiglipSdpaAttention
(
config
)
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
self
.
layer_norm1
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_eps
)
eps
=
config
.
layer_norm_eps
)
self
.
mlp
=
SiglipMLP
(
self
.
mlp
=
SiglipMLP
(
...
@@ -577,14 +443,27 @@ class SiglipVisionTransformer(nn.Module):
...
@@ -577,14 +443,27 @@ class SiglipVisionTransformer(nn.Module):
self
.
config
=
config
self
.
config
=
config
embed_dim
=
config
.
hidden_size
embed_dim
=
config
.
hidden_size
if
(
num_hidden_layers_override
is
None
or
num_hidden_layers_override
==
config
.
num_hidden_layers
):
self
.
need_post_layernorm
=
True
elif
num_hidden_layers_override
>
config
.
num_hidden_layers
:
raise
ValueError
(
"num_hidden_layers_override cannot be greater than "
"num_hidden_layers"
)
else
:
self
.
need_post_layernorm
=
False
self
.
embeddings
=
SiglipVisionEmbeddings
(
config
)
self
.
embeddings
=
SiglipVisionEmbeddings
(
config
)
self
.
encoder
=
SiglipEncoder
(
self
.
encoder
=
SiglipEncoder
(
config
,
config
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
num_hidden_layers_override
=
num_hidden_layers_override
,
)
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
if
self
.
need_post_layernorm
:
eps
=
config
.
layer_norm_eps
)
self
.
post_layernorm
=
nn
.
LayerNorm
(
embed_dim
,
eps
=
config
.
layer_norm_eps
)
else
:
self
.
post_layernorm
=
nn
.
Identity
()
self
.
use_head
=
(
True
if
not
hasattr
(
config
,
"vision_use_head"
)
else
self
.
use_head
=
(
True
if
not
hasattr
(
config
,
"vision_use_head"
)
else
config
.
vision_use_head
)
config
.
vision_use_head
)
if
self
.
use_head
:
if
self
.
use_head
:
...
@@ -604,7 +483,6 @@ class SiglipVisionTransformer(nn.Module):
...
@@ -604,7 +483,6 @@ class SiglipVisionTransformer(nn.Module):
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
encoder_outputs
=
self
.
encoder
(
inputs_embeds
=
hidden_states
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
last_hidden_state
=
self
.
post_layernorm
(
encoder_outputs
)
# TODO: add this back when pooled_output is used in inference
# TODO: add this back when pooled_output is used in inference
# if self.use_head:
# if self.use_head:
# pooled_output = self.head(last_hidden_state)
# pooled_output = self.head(last_hidden_state)
...
@@ -623,12 +501,20 @@ class SiglipVisionModel(nn.Module):
...
@@ -623,12 +501,20 @@ class SiglipVisionModel(nn.Module):
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
num_hidden_layers_override
:
Optional
[
int
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
num_heads
=
config
.
num_attention_heads
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
shard_weight
=
USE_XFORMERS_OPS
and
num_heads
%
tp_size
==
0
self
.
vision_model
=
SiglipVisionTransformer
(
self
.
vision_model
=
SiglipVisionTransformer
(
config
,
config
,
quant_config
,
quant_config
,
num_hidden_layers_override
=
num_hidden_layers_override
,
num_hidden_layers_override
=
num_hidden_layers_override
,
)
)
@
property
def
need_post_layernorm
(
self
):
return
self
.
vision_model
.
need_post_layernorm
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
def
get_input_embeddings
(
self
)
->
nn
.
Module
:
return
self
.
vision_model
.
embeddings
.
patch_embedding
return
self
.
vision_model
.
embeddings
.
patch_embedding
...
@@ -647,6 +533,11 @@ class SiglipVisionModel(nn.Module):
...
@@ -647,6 +533,11 @@ class SiglipVisionModel(nn.Module):
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
layer_count
=
len
(
self
.
vision_model
.
encoder
.
layers
)
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
# post_layernorm is optional in SiglipVisionModel
if
(
"vision_model.post_layernorm"
in
name
and
not
self
.
need_post_layernorm
):
continue
# omit layers when num_hidden_layers_override is set
# omit layers when num_hidden_layers_override is set
if
"vision_model.encoder.layers."
in
name
:
if
"vision_model.encoder.layers."
in
name
:
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
layer_idx
=
int
(
name
.
split
(
"."
)[
3
])
...
...
vllm/model_executor/models/stablelm.py
View file @
0640f227
...
@@ -36,12 +36,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -36,12 +36,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
class
StablelmMLP
(
nn
.
Module
):
class
StablelmMLP
(
nn
.
Module
):
...
...
vllm/model_executor/models/starcoder2.py
View file @
0640f227
...
@@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
class
Starcoder2Attention
(
nn
.
Module
):
class
Starcoder2Attention
(
nn
.
Module
):
...
...
vllm/model_executor/models/ultravox.py
View file @
0640f227
...
@@ -8,7 +8,6 @@ from functools import lru_cache
...
@@ -8,7 +8,6 @@ from functools import lru_cache
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
from
typing
import
(
Iterable
,
List
,
Literal
,
Mapping
,
Optional
,
Tuple
,
TypedDict
,
Union
,
cast
)
TypedDict
,
Union
,
cast
)
import
librosa
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.utils.checkpoint
import
torch.utils.checkpoint
...
@@ -27,17 +26,18 @@ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
...
@@ -27,17 +26,18 @@ from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.interfaces
import
SupportsMultiModal
from
vllm.model_executor.models.utils
import
(
filter_weights
,
from
vllm.model_executor.models.utils
import
(
filter_weights
,
flatten_bn
,
init_vllm_registered_model
,
init_vllm_registered_model
,
merge_multimodal_embeddings
)
merge_multimodal_embeddings
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.base
import
MultiModalInputs
,
NestedTensors
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
from
vllm.multimodal.utils
import
(
cached_get_tokenizer
,
repeat_and_pad_placeholder_tokens
)
repeat_and_pad_placeholder_tokens
)
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplerOutput
,
SequenceData
from
vllm.sequence
import
VLLM_TOKEN_ID_ARRAY_TYPE
,
SequenceData
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
from
vllm.transformers_utils.configs.ultravox
import
UltravoxConfig
_AUDIO_PLACEHOLDER_TOKEN
=
128002
_AUDIO_PLACEHOLDER_TOKEN
=
128002
...
@@ -48,13 +48,14 @@ logger = init_logger(__name__)
...
@@ -48,13 +48,14 @@ logger = init_logger(__name__)
class
UltravoxAudioFeatureInputs
(
TypedDict
):
class
UltravoxAudioFeatureInputs
(
TypedDict
):
type
:
Literal
[
"audio_features"
]
type
:
Literal
[
"audio_features"
]
data
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]
data
:
Nested
Tensor
s
"""Shape: `(batch_size, 80, M)"""
"""Shape: `(batch_size,
num_audios,
80, M)"""
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
class
UltravoxAudioEmbeddingInputs
(
TypedDict
):
type
:
Literal
[
"audio_embeds"
]
type
:
Literal
[
"audio_embeds"
]
data
:
torch
.
Tensor
data
:
NestedTensors
"""Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)"""
UltravoxAudioInputs
=
Union
[
UltravoxAudioFeatureInputs
,
UltravoxAudioInputs
=
Union
[
UltravoxAudioFeatureInputs
,
...
@@ -85,27 +86,41 @@ def dummy_data_for_ultravox(
...
@@ -85,27 +86,41 @@ def dummy_data_for_ultravox(
audio_count
=
mm_counts
[
"audio"
]
audio_count
=
mm_counts
[
"audio"
]
audio_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
audio_placeholder
=
array
(
_AUDIO_PLACEHOLDER_TOKEN
VLLM_TOKEN_ID_ARRAY_TYPE
,
])
*
get_ultravox_max_audio_tokens
(
ctx
)
*
audio_count
[
_AUDIO_PLACEHOLDER_TOKEN
])
*
get_ultravox_max_audio_tokens
(
ctx
)
# Add a separator between each chunk.
audio_token_ids
=
(
audio_placeholder
+
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
]))
*
audio_count
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
other_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
[
0
])
*
(
seq_len
-
len
(
audio_token_ids
))
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
audio_and_sr
=
(
np
.
array
([
0.0
]
*
feature_extractor
.
chunk_length
),
1
)
mm_dict
=
{
mm_dict
=
{
"audio"
:
[
audio_and_sr
]
*
audio_count
}
"audio"
:
audio_and_sr
if
audio_count
==
1
else
[
audio_and_sr
]
*
audio_count
}
return
(
SequenceData
(
audio_token_ids
+
other_token_ids
),
mm_dict
)
return
(
SequenceData
(
audio_token_ids
+
other_token_ids
),
mm_dict
)
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
def
input_mapper_for_ultravox
(
ctx
:
InputContext
,
data
:
object
):
if
isinstance
(
data
,
tuple
):
if
not
isinstance
(
data
,
list
):
(
audio
,
sr
)
=
cast
(
Tuple
[
np
.
ndarray
,
Union
[
float
,
int
]],
data
)
data
=
[
data
]
audio_features
=
[]
for
audio_input
in
data
:
if
not
isinstance
(
audio_input
,
tuple
):
raise
NotImplementedError
(
f
"Unsupported data type:
{
type
(
audio_input
)
}
"
)
(
audio
,
sr
)
=
cast
(
Tuple
[
np
.
ndarray
,
Union
[
float
,
int
]],
audio_input
)
feature_extractor
=
whisper_feature_extractor
(
ctx
)
feature_extractor
=
whisper_feature_extractor
(
ctx
)
if
sr
!=
feature_extractor
.
sampling_rate
:
if
sr
!=
feature_extractor
.
sampling_rate
:
try
:
import
librosa
except
ImportError
:
raise
ImportError
(
"Please install vllm[audio] for audio support."
)
from
None
audio
=
librosa
.
resample
(
audio
,
audio
=
librosa
.
resample
(
audio
,
orig_sr
=
sr
,
orig_sr
=
sr
,
target_sr
=
feature_extractor
.
sampling_rate
)
target_sr
=
feature_extractor
.
sampling_rate
)
...
@@ -116,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
...
@@ -116,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object):
# Not enough audio; pad it.
# Not enough audio; pad it.
audio
=
np
.
pad
(
audio
,
(
0
,
minimum_audio_length
-
len
(
audio
)))
audio
=
np
.
pad
(
audio
,
(
0
,
minimum_audio_length
-
len
(
audio
)))
return
MultiModalInputs
({
single_audio_features
=
feature_extractor
(
"audio_features"
:
audio
,
sampling_rate
=
sr
,
padding
=
"longest"
,
feature_extractor
(
audio
,
return_tensors
=
"pt"
)[
"input_features"
]
sampling_rate
=
sr
,
padding
=
"longest"
,
return_tensors
=
"pt"
)[
"input_features"
]
})
raise
NotImplementedError
(
f
"Unsupported data type:
{
type
(
data
)
}
"
)
# Remove the batch dimension because we're wrapping it in a list.
audio_features
.
append
(
single_audio_features
.
squeeze
(
0
))
return
MultiModalInputs
({
"audio_features"
:
audio_features
})
def
input_processor_for_ultravox
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
def
input_processor_for_ultravox
(
ctx
:
InputContext
,
llm_inputs
:
LLMInputs
):
...
@@ -133,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -133,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
return
llm_inputs
return
llm_inputs
feature_extractor
=
whisper_feature_extractor
(
ctx
)
feature_extractor
=
whisper_feature_extractor
(
ctx
)
audio_data
,
sample_rate
=
multi_modal_data
[
"audio"
]
audios
=
multi_modal_data
[
"audio"
]
if
not
isinstance
(
audios
,
list
):
audio_length
=
audio_data
.
shape
[
0
]
audios
=
[
audios
]
if
sample_rate
!=
feature_extractor
.
sampling_rate
:
# Account for resampling.
audio_token_counts
=
[]
adjustment
=
feature_extractor
.
sampling_rate
/
sample_rate
for
audio_data
,
sample_rate
in
audios
:
audio_length
=
math
.
ceil
(
adjustment
*
audio_length
)
audio_length
=
audio_data
.
shape
[
0
]
if
sample_rate
!=
feature_extractor
.
sampling_rate
:
feature_extractor_output_length
=
math
.
ceil
(
# Account for resampling.
(
audio_length
-
adjustment
=
feature_extractor
.
sampling_rate
/
sample_rate
(
feature_extractor
.
hop_length
-
1
))
/
feature_extractor
.
hop_length
)
audio_length
=
math
.
ceil
(
adjustment
*
audio_length
)
uv_config
=
ctx
.
get_hf_config
(
UltravoxConfig
)
feature_extractor_output_length
=
math
.
ceil
(
audio_num_tokens
=
min
(
(
audio_length
-
(
feature_extractor
.
hop_length
-
1
))
/
max
(
feature_extractor
.
hop_length
)
1
,
math
.
ceil
(
feature_extractor_output_length
/
uv_config
=
ctx
.
get_hf_config
(
UltravoxConfig
)
(
uv_config
.
stack_factor
*
2
))),
audio_num_tokens
=
min
(
get_ultravox_max_audio_tokens
(
ctx
))
max
(
1
,
math
.
ceil
(
feature_extractor_output_length
/
(
uv_config
.
stack_factor
*
2
))),
get_ultravox_max_audio_tokens
(
ctx
))
audio_token_counts
.
append
(
audio_num_tokens
)
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
)
tokenizer
=
cached_get_tokenizer
(
ctx
.
model_config
.
tokenizer
)
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
new_prompt
,
new_token_ids
=
repeat_and_pad_placeholder_tokens
(
...
@@ -159,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
...
@@ -159,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs
.
get
(
"prompt"
),
llm_inputs
.
get
(
"prompt"
),
llm_inputs
[
"prompt_token_ids"
],
llm_inputs
[
"prompt_token_ids"
],
placeholder_token_id
=
_AUDIO_PLACEHOLDER_TOKEN
,
placeholder_token_id
=
_AUDIO_PLACEHOLDER_TOKEN
,
repeat_count
=
audio_
num_
tokens
,
repeat_count
=
audio_token
_count
s
,
)
)
# NOTE: Create a defensive copy of the original inputs
# NOTE: Create a defensive copy of the original inputs
...
@@ -337,7 +357,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -337,7 +357,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
data
=
audio_features
)
data
=
audio_features
)
if
audio_embeds
is
not
None
:
if
audio_embeds
is
not
None
:
if
not
isinstance
(
audio_embeds
,
torch
.
Tensor
):
if
not
isinstance
(
audio_embeds
,
(
torch
.
Tensor
,
list
)
):
raise
ValueError
(
"Incorrect type of audio embeds. "
raise
ValueError
(
"Incorrect type of audio embeds. "
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
f
"Got type:
{
type
(
audio_embeds
)
}
"
)
...
@@ -347,22 +367,38 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -347,22 +367,38 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
raise
AssertionError
(
"This line should be unreachable."
)
raise
AssertionError
(
"This line should be unreachable."
)
def
_process_audio_input
(
def
_process_audio_input
(
self
,
audio_input
:
UltravoxAudioInputs
self
,
audio_input
:
UltravoxAudioInputs
)
->
NestedTensors
:
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
if
audio_input
[
"type"
]
==
"audio_embeds"
:
if
audio_input
[
"type"
]
==
"audio_embeds"
:
return
audio_input
[
"data"
]
return
audio_input
[
"data"
]
audio_features
=
audio_input
[
"data"
]
audio_features
=
audio_input
[
"data"
]
if
isinstance
(
audio_features
,
list
):
if
isinstance
(
audio_features
,
torch
.
Tensor
):
# TODO: Batch these through the encoder/projector instead of
# Combine the B and N dimensions for the encoder/projector
# serializing them.
flattened
=
flatten_bn
(
audio_features
)
return
[
flattened_embeddings
=
self
.
_audio_features_to_embeddings
(
self
.
_audio_features_to_embeddings
(
flattened
)
features
.
unsqueeze
(
0
)).
squeeze
(
0
)
for
features
in
audio_features
# Restore the original dimensions
]
embeddings
=
flattened_embeddings
.
unflatten
(
else
:
0
,
audio_features
.
shape
[:
2
])
return
self
.
_audio_features_to_embeddings
(
audio_features
)
return
embeddings
result
=
[]
# TODO: Batch heterogeneous tensors through the encoder/projector
for
audio_features_item
in
audio_features
:
if
isinstance
(
audio_features_item
,
torch
.
Tensor
):
result
.
append
(
self
.
_audio_features_to_embeddings
(
audio_features_item
))
else
:
embeddings
=
[
# Add a batch dimension to embed it, then remove it.
self
.
_audio_features_to_embeddings
(
tensor
.
unsqueeze
(
0
)
).
squeeze
(
0
)
for
tensor
in
audio_features_item
]
result
.
append
(
embeddings
)
return
result
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
...
@@ -379,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
...
@@ -379,7 +415,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
with the `input_ids`.
with the `input_ids`.
Args:
Args:
input
_features: A batch of audio inputs
,
[
1
, 80, M].
audio
_features: A batch of audio inputs [
B, N
, 80, M].
"""
"""
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
audio_input
=
self
.
_parse_and_validate_audio_input
(
**
kwargs
)
if
audio_input
is
not
None
:
if
audio_input
is
not
None
:
...
...
vllm/model_executor/models/utils.py
View file @
0640f227
from
typing
import
Dict
,
Iterable
,
List
,
Optional
,
Protocol
,
Tuple
from
typing
import
(
Dict
,
Iterable
,
List
,
Literal
,
Optional
,
Protocol
,
Tuple
,
Union
,
overload
)
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
...
@@ -10,7 +11,7 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.model_loader.loader
import
build_model
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.multimodal
import
Batch
edTensors
from
vllm.multimodal
.base
import
Nest
edTensors
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -54,9 +55,73 @@ def init_vllm_registered_model(
...
@@ -54,9 +55,73 @@ def init_vllm_registered_model(
)
)
@
overload
def
flatten_bn
(
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@
overload
def
flatten_bn
(
x
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
...
@
overload
def
flatten_bn
(
x
:
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
],
*
,
concat
:
Literal
[
True
],
)
->
torch
.
Tensor
:
...
def
flatten_bn
(
x
:
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
],
*
,
concat
:
bool
=
False
,
)
->
Union
[
List
[
torch
.
Tensor
],
torch
.
Tensor
]:
"""
Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs.
The input tensor should have shape ``(B, N, ...)```.
"""
if
isinstance
(
x
,
torch
.
Tensor
):
return
x
.
flatten
(
0
,
1
)
if
concat
:
return
torch
.
cat
(
x
)
return
[
x_n
for
x_b
in
x
for
x_n
in
x_b
]
def
_flatten_embeddings
(
embeddings
:
NestedTensors
)
->
torch
.
Tensor
:
"""
Recursively flattens and concatenates NestedTensors on all but the last
dimension.
"""
if
isinstance
(
embeddings
,
torch
.
Tensor
):
# Flatten all but the last dimension.
return
embeddings
.
flatten
(
0
,
-
2
)
return
torch
.
cat
(
tuple
(
_flatten_embeddings
(
t
)
for
t
in
embeddings
))
def
_embedding_count_expression
(
embeddings
:
NestedTensors
)
->
str
:
"""
Constructs a debugging representation of the number of embeddings in the
NestedTensors.
"""
if
isinstance
(
embeddings
,
torch
.
Tensor
):
return
" x "
.
join
([
str
(
dim
)
for
dim
in
embeddings
.
shape
[:
-
1
]])
return
" + "
.
join
(
_embedding_count_expression
(
inner
)
for
inner
in
embeddings
)
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
def
merge_multimodal_embeddings
(
input_ids
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
inputs_embeds
:
torch
.
Tensor
,
multimodal_embeddings
:
Batch
edTensors
,
multimodal_embeddings
:
Nest
edTensors
,
placeholder_token_id
:
int
)
->
torch
.
Tensor
:
placeholder_token_id
:
int
)
->
torch
.
Tensor
:
"""
"""
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the
...
@@ -67,30 +132,17 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
...
@@ -67,30 +132,17 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor,
This updates ``inputs_embeds`` in place.
This updates ``inputs_embeds`` in place.
"""
"""
mask
=
(
input_ids
==
placeholder_token_id
)
mask
=
(
input_ids
==
placeholder_token_id
)
num_expected_tokens
=
mask
.
sum
()
num_expected_tokens
=
mask
.
sum
().
item
()
assert
isinstance
(
num_expected_tokens
,
int
)
if
isinstance
(
multimodal_embeddings
,
torch
.
Tensor
):
batch_size
,
batch_tokens
,
*
_
,
embed_dim
=
multimodal_embeddings
.
shape
flattened
=
_flatten_embeddings
(
multimodal_embeddings
)
total_tokens
=
batch_size
*
batch_tokens
if
flattened
.
shape
[
0
]
!=
num_expected_tokens
:
if
num_expected_tokens
!=
total_tokens
:
expr
=
_embedding_count_expression
(
multimodal_embeddings
)
expr
=
f
"
{
batch_size
}
x
{
batch_tokens
}
"
raise
ValueError
(
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
flattened
.
shape
[
0
]
}
"
f
"Attempted to assign
{
expr
}
=
{
total_tokens
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
multimodal_embeddings
.
view
(
total_tokens
,
embed_dim
)
else
:
size_per_batch
=
[
t
.
shape
[
0
]
for
t
in
multimodal_embeddings
]
total_tokens
=
sum
(
size_per_batch
)
if
num_expected_tokens
!=
total_tokens
:
expr
=
' + '
.
join
(
map
(
str
,
size_per_batch
))
raise
ValueError
(
f
"Attempted to assign
{
expr
}
=
{
total_tokens
}
"
f
"multimodal tokens to
{
num_expected_tokens
}
placeholders"
)
inputs_embeds
[
mask
]
=
torch
.
cat
(
multimodal_embeddings
)
inputs_embeds
[
mask
]
=
flattened
return
inputs_embeds
return
inputs_embeds
...
...
vllm/model_executor/models/xverse.py
View file @
0640f227
...
@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
...
@@ -38,12 +38,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
,
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
...
...
vllm/model_executor/parameter.py
View file @
0640f227
from
fractions
import
Fraction
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Callable
,
Optional
,
Union
import
torch
import
torch
...
@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
...
@@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
packed_factor
:
int
,
packed_factor
:
Union
[
int
,
Fraction
]
,
packed_dim
:
int
,
packed_dim
:
int
,
marlin_tile_size
:
Optional
[
int
]
=
None
,
marlin_tile_size
:
Optional
[
int
]
=
None
,
**
kwargs
):
**
kwargs
):
...
@@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
packed_factor
:
int
,
packed_factor
:
Union
[
int
,
Fraction
]
,
packed_dim
:
int
,
packed_dim
:
int
,
marlin_tile_size
:
Optional
[
int
]
=
None
,
marlin_tile_size
:
Optional
[
int
]
=
None
,
**
kwargs
):
**
kwargs
):
...
...
vllm/multimodal/__init__.py
View file @
0640f227
from
.base
import
(
BatchedTensorInputs
,
BatchedTensors
,
MultiModalDataBuiltins
,
from
.base
import
(
BatchedTensorInputs
,
MultiModalDataBuiltins
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
MultiModalDataDict
,
MultiModalInputs
,
MultiModalPlugin
,
NestedTensors
)
NestedTensors
)
from
.registry
import
MultiModalRegistry
from
.registry
import
MultiModalRegistry
...
@@ -14,7 +14,6 @@ See also:
...
@@ -14,7 +14,6 @@ See also:
__all__
=
[
__all__
=
[
"BatchedTensorInputs"
,
"BatchedTensorInputs"
,
"BatchedTensors"
,
"MultiModalDataBuiltins"
,
"MultiModalDataBuiltins"
,
"MultiModalDataDict"
,
"MultiModalDataDict"
,
"MultiModalInputs"
,
"MultiModalInputs"
,
...
...
vllm/multimodal/base.py
View file @
0640f227
import
sys
import
sys
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
UserDict
,
defaultdict
from
collections
import
UserDict
,
defaultdict
from
typing
import
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
(
Callable
,
Dict
,
List
,
Mapping
,
Optional
,
Tuple
,
Type
,
from
typing
import
Sequence
as
GenericSequence
TypedDict
,
TypeVar
,
Union
,
cast
,
final
)
from
typing
import
Tuple
,
Type
,
TypedDict
,
TypeVar
,
Union
,
cast
,
final
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
...
@@ -15,23 +14,16 @@ from typing_extensions import TypeAlias
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.inputs
import
InputContext
from
vllm.inputs
import
InputContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
JSONTree
,
json_map_leaves
from
vllm.utils
import
JSONTree
,
is_list_of
,
json_map_leaves
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
NestedTensors
=
Union
[
GenericSequence
[
torch
.
Tensor
],
torch
.
Tensor
]
NestedTensors
=
Union
[
List
[
"NestedTensors"
],
List
[
torch
.
Tensor
],
torch
.
Tensor
]
"""
"""
Use a list instead of a tensor if the dimensions of each element do not match.
Uses a list instead of a tensor if the dimensions of each element do not match.
Currently only supports up to singly nested list of tensors.
"""
"""
BatchedTensors
:
TypeAlias
=
JSONTree
[
torch
.
Tensor
]
BatchedTensorInputs
:
TypeAlias
=
Dict
[
str
,
NestedTensors
]
"""
A nested JSON structure of tensors which have been batched via
:meth:`MultiModalInputs.batch`.
"""
BatchedTensorInputs
:
TypeAlias
=
Dict
[
str
,
JSONTree
[
torch
.
Tensor
]]
"""
"""
A dictionary containing nested tensors which have been batched via
A dictionary containing nested tensors which have been batched via
:meth:`MultiModalInputs.batch`.
:meth:`MultiModalInputs.batch`.
...
@@ -54,26 +46,24 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -54,26 +46,24 @@ class MultiModalInputs(_MultiModalInputsBase):
"""
"""
@
staticmethod
@
staticmethod
def
_try_
concat
(
tensors
:
List
[
NestedTensors
]
)
->
Batch
edTensors
:
def
_try_
stack
(
nested_
tensors
:
NestedTensors
)
->
Nest
edTensors
:
"""
"""
If each input tensor in the batch has the same shape, return a single
Recursively stacks lists of tensors when they all have the same shape.
batched tensor; otherwise, return a list of :class:`NestedTensors` with
one element per item in the batch.
"""
"""
# may be list rather than tensors
if
isinstance
(
nested_tensors
,
torch
.
Tensor
):
if
isinstance
(
tensors
[
0
],
list
):
return
nested_tensors
return
[[
t
for
t
in
tensor
[
0
]]
for
tensor
in
cast
(
List
[
List
[
torch
.
Tensor
]],
tensors
)]
tensors_
=
cast
(
List
[
torch
.
Tensor
],
tensors
)
stacked
=
[
MultiModalInputs
.
_try_stack
(
t
)
for
t
in
nested_tensors
]
if
not
is_list_of
(
stacked
,
torch
.
Tensor
,
check
=
"all"
):
# Only tensors (not lists) can be stacked.
return
stacked
unbatched_shape
=
tensors_
[
0
].
shape
[
1
:]
tensors_
=
cast
(
List
[
torch
.
Tensor
],
stacked
)
if
any
(
t
.
shape
!=
tensors_
[
0
].
shape
for
t
in
tensors_
):
# The tensors have incompatible shapes and can't be stacked.
return
tensors_
for
tensor
in
tensors_
:
return
torch
.
stack
(
tensors_
)
if
tensor
.
shape
[
1
:]
!=
unbatched_shape
:
return
[
tensor
.
squeeze
(
0
)
for
tensor
in
tensors_
]
return
torch
.
cat
(
tensors_
,
dim
=
0
)
@
staticmethod
@
staticmethod
def
batch
(
inputs_list
:
List
[
"MultiModalInputs"
])
->
BatchedTensorInputs
:
def
batch
(
inputs_list
:
List
[
"MultiModalInputs"
])
->
BatchedTensorInputs
:
...
@@ -102,7 +92,7 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -102,7 +92,7 @@ class MultiModalInputs(_MultiModalInputsBase):
item_lists
[
k
].
append
(
v
)
item_lists
[
k
].
append
(
v
)
return
{
return
{
k
:
MultiModalInputs
.
_try_
concat
(
item_list
)
k
:
MultiModalInputs
.
_try_
stack
(
item_list
)
for
k
,
item_list
in
item_lists
.
items
()
for
k
,
item_list
in
item_lists
.
items
()
}
}
...
@@ -112,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
...
@@ -112,8 +102,14 @@ class MultiModalInputs(_MultiModalInputsBase):
*
,
*
,
device
:
torch
.
types
.
Device
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensorInputs
:
)
->
BatchedTensorInputs
:
return
json_map_leaves
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
True
),
json_inputs
=
cast
(
JSONTree
[
torch
.
Tensor
],
batched_inputs
)
batched_inputs
)
json_mapped
=
json_map_leaves
(
lambda
x
:
x
.
to
(
device
,
non_blocking
=
True
),
json_inputs
,
)
return
cast
(
BatchedTensorInputs
,
json_mapped
)
_T
=
TypeVar
(
"_T"
)
_T
=
TypeVar
(
"_T"
)
...
...
vllm/multimodal/utils.py
View file @
0640f227
import
base64
import
base64
from
functools
import
lru_cache
from
functools
import
lru_cache
from
io
import
BytesIO
from
io
import
BytesIO
from
typing
import
List
,
Optional
,
Tuple
,
TypeVar
,
Union
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
TypeVar
,
Union
import
librosa
import
numpy
as
np
import
numpy
as
np
import
soundfile
from
PIL
import
Image
from
PIL
import
Image
from
vllm.connections
import
global_http_connection
from
vllm.connections
import
global_http_connection
...
@@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str,
...
@@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str,
return
image
.
convert
(
image_mode
)
return
image
.
convert
(
image_mode
)
def
try_import_audio_packages
()
->
Tuple
[
Any
,
Any
]:
try
:
import
librosa
import
soundfile
except
ImportError
:
raise
ImportError
(
"Please install vllm[audio] for audio support."
)
from
None
return
librosa
,
soundfile
def
fetch_audio
(
audio_url
:
str
)
->
Tuple
[
np
.
ndarray
,
Union
[
int
,
float
]]:
def
fetch_audio
(
audio_url
:
str
)
->
Tuple
[
np
.
ndarray
,
Union
[
int
,
float
]]:
"""
"""
Load audio from a URL.
Load audio from a URL.
"""
"""
librosa
,
_
=
try_import_audio_packages
()
if
audio_url
.
startswith
(
"http"
):
if
audio_url
.
startswith
(
"http"
):
audio_bytes
=
global_http_connection
.
get_bytes
(
audio_bytes
=
global_http_connection
.
get_bytes
(
audio_url
,
timeout
=
VLLM_AUDIO_FETCH_TIMEOUT
)
audio_url
,
timeout
=
VLLM_AUDIO_FETCH_TIMEOUT
)
...
@@ -95,6 +105,8 @@ async def async_fetch_audio(
...
@@ -95,6 +105,8 @@ async def async_fetch_audio(
"""
"""
Asynchronously fetch audio from a URL.
Asynchronously fetch audio from a URL.
"""
"""
librosa
,
_
=
try_import_audio_packages
()
if
audio_url
.
startswith
(
"http"
):
if
audio_url
.
startswith
(
"http"
):
audio_bytes
=
await
global_http_connection
.
async_get_bytes
(
audio_bytes
=
await
global_http_connection
.
async_get_bytes
(
audio_url
,
timeout
=
VLLM_AUDIO_FETCH_TIMEOUT
)
audio_url
,
timeout
=
VLLM_AUDIO_FETCH_TIMEOUT
)
...
@@ -108,6 +120,16 @@ async def async_fetch_audio(
...
@@ -108,6 +120,16 @@ async def async_fetch_audio(
return
librosa
.
load
(
BytesIO
(
audio_bytes
),
sr
=
None
)
return
librosa
.
load
(
BytesIO
(
audio_bytes
),
sr
=
None
)
def
get_and_parse_audio
(
audio_url
:
str
)
->
MultiModalDataDict
:
audio
,
sr
=
fetch_audio
(
audio_url
)
return
{
"audio"
:
(
audio
,
sr
)}
def
get_and_parse_image
(
image_url
:
str
)
->
MultiModalDataDict
:
image
=
fetch_image
(
image_url
)
return
{
"image"
:
image
}
async
def
async_get_and_parse_audio
(
audio_url
:
str
)
->
MultiModalDataDict
:
async
def
async_get_and_parse_audio
(
audio_url
:
str
)
->
MultiModalDataDict
:
audio
,
sr
=
await
async_fetch_audio
(
audio_url
)
audio
,
sr
=
await
async_fetch_audio
(
audio_url
)
return
{
"audio"
:
(
audio
,
sr
)}
return
{
"audio"
:
(
audio
,
sr
)}
...
@@ -123,6 +145,8 @@ def encode_audio_base64(
...
@@ -123,6 +145,8 @@ def encode_audio_base64(
sampling_rate
:
int
,
sampling_rate
:
int
,
)
->
str
:
)
->
str
:
"""Encode audio as base64."""
"""Encode audio as base64."""
_
,
soundfile
=
try_import_audio_packages
()
buffered
=
BytesIO
()
buffered
=
BytesIO
()
soundfile
.
write
(
buffered
,
audio
,
sampling_rate
,
format
=
"WAV"
)
soundfile
.
write
(
buffered
,
audio
,
sampling_rate
,
format
=
"WAV"
)
...
@@ -189,10 +213,13 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -189,10 +213,13 @@ def repeat_and_pad_placeholder_tokens(
prompt_token_ids
:
List
[
int
],
prompt_token_ids
:
List
[
int
],
*
,
*
,
placeholder_token_id
:
int
,
placeholder_token_id
:
int
,
repeat_count
:
int
=
1
,
repeat_count
:
Union
[
int
,
List
[
int
]]
,
pad_token_left
:
Optional
[
int
]
=
None
,
pad_token_left
:
Optional
[
int
]
=
None
,
pad_token_right
:
Optional
[
int
]
=
None
,
pad_token_right
:
Optional
[
int
]
=
None
,
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
)
->
Tuple
[
Optional
[
str
],
List
[
int
]]:
if
isinstance
(
repeat_count
,
int
):
repeat_count
=
[
repeat_count
]
if
prompt
is
None
:
if
prompt
is
None
:
new_prompt
=
None
new_prompt
=
None
else
:
else
:
...
@@ -201,13 +228,6 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -201,13 +228,6 @@ def repeat_and_pad_placeholder_tokens(
tokenizer
.
decode
(
pad_token_left
))
tokenizer
.
decode
(
pad_token_left
))
pad_token_str_right
=
(
None
if
pad_token_right
is
None
else
pad_token_str_right
=
(
None
if
pad_token_right
is
None
else
tokenizer
.
decode
(
pad_token_right
))
tokenizer
.
decode
(
pad_token_right
))
replacement_str
=
""
.
join
(
repeat_and_pad_token
(
placeholder_token_str
,
repeat_count
=
repeat_count
,
pad_token_left
=
pad_token_str_left
,
pad_token_right
=
pad_token_str_right
,
))
placeholder_token_count
=
prompt
.
count
(
placeholder_token_str
)
placeholder_token_count
=
prompt
.
count
(
placeholder_token_str
)
# This is an arbitrary number to distinguish between the two cases
# This is an arbitrary number to distinguish between the two cases
...
@@ -216,28 +236,45 @@ def repeat_and_pad_placeholder_tokens(
...
@@ -216,28 +236,45 @@ def repeat_and_pad_placeholder_tokens(
"Please follow the prompt format that is "
"Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"documented on HuggingFace which does not involve "
"repeating %s tokens."
,
placeholder_token_str
)
"repeating %s tokens."
,
placeholder_token_str
)
elif
placeholder_token_count
>
1
:
if
placeholder_token_count
<
len
(
repeat_count
):
logger
.
warning
(
"Multiple multi-modal input is not supported yet, "
logger
.
warning
(
"so any extra placeholder tokens will be treated "
"The number of multi-modal placeholder tokens in the prompt "
"as plain text."
)
"is less than the number of multi-modal inputs. Extra "
"placeholder tokens will be treated as plain text"
)
# The image tokens are removed to be consistent with HuggingFace
repeat_count
=
repeat_count
[:
placeholder_token_count
]
new_prompt
=
prompt
.
replace
(
placeholder_token_str
,
replacement_str
,
1
)
prompt_parts
=
prompt
.
split
(
placeholder_token_str
,
maxsplit
=
len
(
repeat_count
))
new_prompt
=
""
for
i
,
repeat_count_item
in
enumerate
(
repeat_count
):
replacement_str
=
""
.
join
(
repeat_and_pad_token
(
placeholder_token_str
,
repeat_count
=
repeat_count_item
,
pad_token_left
=
pad_token_str_left
,
pad_token_right
=
pad_token_str_right
,
))
# The image tokens are removed to be consistent with HuggingFace
new_prompt
+=
prompt_parts
[
i
]
+
replacement_str
new_prompt
+=
prompt_parts
[
-
1
]
new_token_ids
:
List
[
int
]
=
[]
new_token_ids
:
List
[
int
]
=
[]
placeholder_token_idx
=
0
for
i
,
token
in
enumerate
(
prompt_token_ids
):
for
i
,
token
in
enumerate
(
prompt_token_ids
):
if
token
==
placeholder_token_id
:
if
token
==
placeholder_token_id
:
replacement_ids
=
repeat_and_pad_token
(
replacement_ids
=
repeat_and_pad_token
(
placeholder_token_id
,
placeholder_token_id
,
repeat_count
=
repeat_count
,
repeat_count
=
repeat_count
[
placeholder_token_idx
]
,
pad_token_left
=
pad_token_left
,
pad_token_left
=
pad_token_left
,
pad_token_right
=
pad_token_right
,
pad_token_right
=
pad_token_right
,
)
)
new_token_ids
.
extend
(
replacement_ids
)
new_token_ids
.
extend
(
replacement_ids
)
placeholder_token_idx
+=
1
# No need to further scan the list since we only replace once
# No need to further scan the list since we replaced all tokens
new_token_ids
.
extend
(
prompt_token_ids
[
i
+
1
:])
if
placeholder_token_idx
>=
len
(
repeat_count
):
break
new_token_ids
.
extend
(
prompt_token_ids
[
i
+
1
:])
break
else
:
else
:
new_token_ids
.
append
(
token
)
new_token_ids
.
append
(
token
)
...
...
vllm/platforms/cuda.py
View file @
0640f227
...
@@ -21,7 +21,9 @@ _R = TypeVar("_R")
...
@@ -21,7 +21,9 @@ _R = TypeVar("_R")
if
pynvml
.
__file__
.
endswith
(
"__init__.py"
):
if
pynvml
.
__file__
.
endswith
(
"__init__.py"
):
logger
.
warning
(
logger
.
warning
(
"You are using a deprecated `pynvml` package. Please install"
"You are using a deprecated `pynvml` package. Please install"
" `nvidia-ml-py` instead. See https://pypi.org/project/pynvml "
" `nvidia-ml-py` instead, and make sure to uninstall `pynvml`."
" When both of them are installed, `pynvml` will take precedence"
" and cause errors. See https://pypi.org/project/pynvml "
"for more information."
)
"for more information."
)
# NVML utils
# NVML utils
...
@@ -82,6 +84,9 @@ except ModuleNotFoundError:
...
@@ -82,6 +84,9 @@ except ModuleNotFoundError:
def
device_id_to_physical_device_id
(
device_id
:
int
)
->
int
:
def
device_id_to_physical_device_id
(
device_id
:
int
)
->
int
:
if
"CUDA_VISIBLE_DEVICES"
in
os
.
environ
:
if
"CUDA_VISIBLE_DEVICES"
in
os
.
environ
:
device_ids
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
].
split
(
","
)
device_ids
=
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
].
split
(
","
)
if
device_ids
==
[
""
]:
raise
RuntimeError
(
"CUDA_VISIBLE_DEVICES is set to empty string,"
" which means GPU support is disabled."
)
physical_device_id
=
device_ids
[
device_id
]
physical_device_id
=
device_ids
[
device_id
]
return
int
(
physical_device_id
)
return
int
(
physical_device_id
)
else
:
else
:
...
...
vllm/platforms/rocm.py
View file @
0640f227
import
os
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Tuple
from
typing
import
Tuple
import
torch
import
torch
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
logger
=
init_logger
(
__name__
)
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
None
)
in
[
"fork"
,
None
]:
logger
.
warning
(
"`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
" `spawn` instead."
)
os
.
environ
[
"VLLM_WORKER_MULTIPROC_METHOD"
]
=
"spawn"
class
RocmPlatform
(
Platform
):
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
_enum
=
PlatformEnum
.
ROCM
...
...
vllm/scripts.py
View file @
0640f227
...
@@ -125,6 +125,15 @@ def main():
...
@@ -125,6 +125,15 @@ def main():
serve_parser
.
add_argument
(
"model_tag"
,
serve_parser
.
add_argument
(
"model_tag"
,
type
=
str
,
type
=
str
,
help
=
"The model tag to serve"
)
help
=
"The model tag to serve"
)
serve_parser
.
add_argument
(
"--config"
,
type
=
str
,
default
=
''
,
required
=
False
,
help
=
"Read CLI options from a config file."
"Must be a YAML with the following options:"
"https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server"
)
serve_parser
=
make_arg_parser
(
serve_parser
)
serve_parser
=
make_arg_parser
(
serve_parser
)
serve_parser
.
set_defaults
(
dispatch_function
=
serve
)
serve_parser
.
set_defaults
(
dispatch_function
=
serve
)
...
...
vllm/sequence.py
View file @
0640f227
...
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
...
@@ -5,8 +5,8 @@ from abc import ABC, abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Mapping
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Tuple
,
Union
,
cast
)
Optional
,
Set
,
Tuple
,
Union
,
cast
)
import
msgspec
import
msgspec
import
torch
import
torch
...
@@ -474,11 +474,8 @@ class Sequence:
...
@@ -474,11 +474,8 @@ class Sequence:
"""Reset the sequence states for recomputation."""
"""Reset the sequence states for recomputation."""
self
.
data
.
reset_state_for_recompute
()
self
.
data
.
reset_state_for_recompute
()
def
append_token_id
(
def
append_token_id
(
self
,
token_id
:
int
,
logprobs
:
Dict
[
int
,
self
,
Logprob
])
->
None
:
token_id
:
int
,
logprobs
:
Dict
[
int
,
Logprob
],
)
->
None
:
assert
token_id
in
logprobs
assert
token_id
in
logprobs
self
.
output_logprobs
.
append
(
logprobs
)
self
.
output_logprobs
.
append
(
logprobs
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
self
.
data
.
append_token_id
(
token_id
,
logprobs
[
token_id
].
logprob
)
...
@@ -814,6 +811,9 @@ class SequenceGroup:
...
@@ -814,6 +811,9 @@ class SequenceGroup:
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
self
.
is_single_seq
=
len
(
self
.
seqs
)
==
1
def
is_finished
(
self
)
->
bool
:
def
is_finished
(
self
)
->
bool
:
if
self
.
is_single_seq
:
return
self
.
seqs
[
0
].
is_finished
()
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
return
all
(
seq
.
is_finished
()
for
seq
in
self
.
seqs
)
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
...
@@ -886,7 +886,7 @@ class SequenceGroupMetadata(
...
@@ -886,7 +886,7 @@ class SequenceGroupMetadata(
request_id
:
str
request_id
:
str
is_prompt
:
bool
is_prompt
:
bool
seq_data
:
Dict
[
int
,
SequenceData
]
seq_data
:
Dict
[
int
,
SequenceData
]
sampling_params
:
SamplingParams
sampling_params
:
Optional
[
SamplingParams
]
block_tables
:
Dict
[
int
,
List
[
int
]]
block_tables
:
Dict
[
int
,
List
[
int
]]
do_sample
:
bool
=
True
do_sample
:
bool
=
True
pooling_params
:
Optional
[
PoolingParams
]
=
None
pooling_params
:
Optional
[
PoolingParams
]
=
None
...
@@ -1060,76 +1060,6 @@ class IntermediateTensors(
...
@@ -1060,76 +1060,6 @@ class IntermediateTensors(
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
return
f
"IntermediateTensors(tensors=
{
self
.
tensors
}
)"
class
SamplerOutput
(
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
array_like
=
True
):
# type: ignore[call-arg]
"""For each sequence group, we generate a list of SequenceOutput object,
each of which contains one possible candidate for the next token.
This data structure implements methods, so it can be used like a list, but
also has optional fields for device tensors.
"""
outputs
:
List
[
CompletionSequenceGroupOutput
]
# On-device tensor containing probabilities of each token.
sampled_token_probs
:
Optional
[
torch
.
Tensor
]
=
None
# On-device tensor containing the logprobs of each token.
logprobs
:
Optional
[
"torch.Tensor"
]
=
None
# On-device tensor containing the sampled token ids.
sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# CPU tensor containing the sampled token ids. Used during multi-step to
# return the sampled token ids from last rank to AsyncLLMEngine to be
# 'broadcasted' to all other PP ranks for next step.
sampled_token_ids_cpu
:
Optional
[
torch
.
Tensor
]
=
None
# Spec decode metrics populated by workers.
spec_decode_worker_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
# Optional last hidden states from the model.
hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Optional prefill hidden states from the model
# (used for models like EAGLE).
prefill_hidden_states
:
Optional
[
torch
.
Tensor
]
=
None
# Time taken in the forward pass for this across all workers
model_forward_time
:
Optional
[
float
]
=
None
# Time taken in the model execute function. This will include model forward,
# block/sync across workers, cpu-gpu sync time and sampling time.
model_execute_time
:
Optional
[
float
]
=
None
def
__getitem__
(
self
,
idx
:
int
):
return
self
.
outputs
[
idx
]
def
__setitem__
(
self
,
idx
:
int
,
value
):
self
.
outputs
[
idx
]
=
value
def
__len__
(
self
):
return
len
(
self
.
outputs
)
def
__eq__
(
self
,
other
:
object
):
return
isinstance
(
other
,
self
.
__class__
)
and
self
.
outputs
==
other
.
outputs
def
__repr__
(
self
)
->
str
:
"""Show the shape of a tensor instead of its values to reduce noise.
"""
sampled_token_probs_repr
=
(
"None"
if
self
.
sampled_token_probs
is
None
else
self
.
sampled_token_probs
.
shape
)
sampled_token_ids_repr
=
(
"None"
if
self
.
sampled_token_ids
is
None
else
self
.
sampled_token_ids
.
shape
)
return
(
f
"SamplerOutput(outputs=
{
self
.
outputs
}
, "
f
"sampled_token_probs=
{
sampled_token_probs_repr
}
, "
f
"sampled_token_ids=
{
sampled_token_ids_repr
}
, "
f
"spec_decode_worker_metrics=
{
self
.
spec_decode_worker_metrics
}
)"
)
class
PoolerOutput
(
class
PoolerOutput
(
msgspec
.
Struct
,
msgspec
.
Struct
,
omit_defaults
=
True
,
# type: ignore[call-arg]
omit_defaults
=
True
,
# type: ignore[call-arg]
...
@@ -1293,6 +1223,8 @@ class ExecuteModelRequest(
...
@@ -1293,6 +1223,8 @@ class ExecuteModelRequest(
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
finished_requests_ids
:
List
[
str
]
=
msgspec
.
field
(
default_factory
=
list
)
# The last sampled token ids for multi step decoding.
# The last sampled token ids for multi step decoding.
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
last_sampled_token_ids
:
Optional
[
torch
.
Tensor
]
=
None
# Async callback
async_callback
:
Optional
[
Callable
]
=
None
@
property
@
property
def
is_first_multi_step
(
self
)
->
bool
:
def
is_first_multi_step
(
self
)
->
bool
:
...
@@ -1338,4 +1270,5 @@ class ExecuteModelRequest(
...
@@ -1338,4 +1270,5 @@ class ExecuteModelRequest(
num_steps
=
self
.
num_steps
,
num_steps
=
self
.
num_steps
,
finished_requests_ids
=
self
.
finished_requests_ids
,
finished_requests_ids
=
self
.
finished_requests_ids
,
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
last_sampled_token_ids
=
self
.
last_sampled_token_ids
.
clone
()
if
self
.
last_sampled_token_ids
is
not
None
else
None
)
if
self
.
last_sampled_token_ids
is
not
None
else
None
,
async_callback
=
self
.
async_callback
)
vllm/spec_decode/batch_expansion.py
View file @
0640f227
...
@@ -5,13 +5,13 @@ from typing import Iterator, List, Optional, Tuple
...
@@ -5,13 +5,13 @@ from typing import Iterator, List, Optional, Tuple
import
torch
import
torch
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
ExecuteModelRequest
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
,
SequenceData
,
SequenceGroupMetadata
,
get_all_seq_ids
)
get_all_seq_ids
)
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
from
vllm.spec_decode.interfaces
import
(
SpeculativeProposals
,
SpeculativeScorer
,
SpeculativeScores
)
SpeculativeScorer
,
SpeculativeScores
)
from
vllm.spec_decode.util
import
(
nvtx_range
,
sampler_output_to_torch
,
from
vllm.spec_decode.util
import
nvtx_range
,
split_batch_by_proposal_len
split_batch_by_proposal_len
)
from
vllm.worker.worker_base
import
WorkerBase
from
vllm.worker.worker_base
import
WorkerBase
SeqId
=
int
SeqId
=
int
...
@@ -88,17 +88,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -88,17 +88,25 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
assert
len
(
target_sampler_output
)
==
1
,
"expected single-step output"
target_sampler_output
=
target_sampler_output
[
0
]
target_sampler_output
=
target_sampler_output
[
0
]
(
all_tokens
,
all_probs
,
spec_logprobs
,
if
not
non_spec_indices
:
all_hidden_states
)
=
self
.
_contract_batch
(
# All sequence groups in batch have spec decoding enabled
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
contracted
=
self
.
_contract_batch_all_spec
(
target_sampler_output
=
target_sampler_output
,
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
)
non_spec_indices
=
non_spec_indices
,
else
:
spec_indices
=
spec_indices
,
# Batch has a mix of spec decode enabled and disabled seq groups
k
=
execute_model_req
.
num_lookahead_slots
,
contracted
=
self
.
_contract_batch
(
)
contracted_bs
=
len
(
execute_model_req
.
seq_group_metadata_list
),
target_sampler_output
=
target_sampler_output
,
proposals
=
proposals
,
num_scoring_tokens
=
num_scoring_tokens
,
non_spec_indices
=
non_spec_indices
,
spec_indices
=
spec_indices
,
k
=
execute_model_req
.
num_lookahead_slots
,
)
all_tokens
,
all_probs
,
spec_logprobs
,
all_hidden_states
=
contracted
return
SpeculativeScores
(
return
SpeculativeScores
(
probs
=
all_probs
,
probs
=
all_probs
,
token_ids
=
all_tokens
,
token_ids
=
all_tokens
,
...
@@ -121,14 +129,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -121,14 +129,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# proposal len. This adds some complexity (splitting the batch into spec
# proposal len. This adds some complexity (splitting the batch into spec
# and non spec sequences) and should be removed in the future. It can be
# and non spec sequences) and should be removed in the future. It can be
# done by supporting per-sequence proposal lens.
# done by supporting per-sequence proposal lens.
spec_seqs
,
spec_indices
=
split_batch_by_proposal_len
(
(
spec_seqs
,
spec_indices
),
(
non_spec_seqs
,
non_spec_indices
)
=
\
seq_group_metadata_list
,
split_batch_by_proposal_len
(
proposal_lens_list
,
seq_group_metadata_list
,
proposal_lens_list
)
select_proposal_len_zero
=
False
)
non_spec_seqs
,
non_spec_indices
=
split_batch_by_proposal_len
(
seq_group_metadata_list
,
proposal_lens_list
,
select_proposal_len_zero
=
True
)
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
target_seq_group_metadata_list
=
self
.
_create_scoring_model_input
(
seq_group_metadata_list
=
spec_seqs
,
seq_group_metadata_list
=
spec_seqs
,
...
@@ -171,7 +174,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -171,7 +174,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
# The number of tokens in the expanded batch used for speculation is
# The number of tokens in the expanded batch used for speculation is
# equal to the total expanded batch size minus the number of samples for
# equal to the total expanded batch size minus the number of samples for
# non-speculative sequences.
# non-speculative sequences.
non_spec_expanded_bs
,
_
=
non_spec_target_token_ids
.
shape
non_spec_expanded_bs
=
len
(
non_spec_target_token_ids
)
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
spec_expanded_bs
=
expanded_batch_size
-
non_spec_expanded_bs
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
target_token_ids
=
target_token_ids
.
reshape
(
spec_expanded_bs
,
k
+
1
)
...
@@ -181,7 +184,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -181,7 +184,7 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
if
target_hidden_states
is
not
None
:
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
target_hidden_states
=
target_hidden_states
.
reshape
(
spec_expanded_bs
,
k
+
1
,
target_hidden_states
.
shape
[
-
1
])
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
all_tokens
=
target_token_ids
.
new_full
(
size
=
(
contracted_bs
,
k
+
1
),
fill_value
=-
1
)
fill_value
=-
1
)
...
@@ -196,24 +199,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -196,24 +199,58 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
all_hidden_states
=
None
all_hidden_states
=
None
if
non_spec_indices
:
if
non_spec_indices
:
all_tokens
[
non_spec_indices
,
:
1
]
=
non_spec_target_token_ids
all_tokens
[
non_spec_indices
,
:
1
]
=
\
all_probs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_probs
non_spec_target_token_ids
.
unsqueeze
(
1
)
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
non_spec_target_logprobs
all_probs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_probs
.
unsqueeze
(
1
)
all_logprobs
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_logprobs
.
unsqueeze
(
1
)
if
all_hidden_states
is
not
None
:
if
all_hidden_states
is
not
None
:
all_hidden_states
[
assert
non_spec_target_hidden_states
is
not
None
non_spec_indices
,
:
1
,
:]
=
non_spec_target_hidden_states
all_hidden_states
[
non_spec_indices
,
:
1
,
:]
=
\
non_spec_target_hidden_states
.
unsqueeze
(
1
)
if
spec_indices
:
if
spec_indices
:
all_tokens
[
spec_indices
]
=
target_token_ids
all_tokens
[
spec_indices
]
=
target_token_ids
all_probs
[
spec_indices
]
=
target_probs
all_probs
[
spec_indices
]
=
target_probs
all_logprobs
[
spec_indices
]
=
target_logprobs
all_logprobs
[
spec_indices
]
=
target_logprobs
if
all_hidden_states
is
not
None
:
if
all_hidden_states
is
not
None
:
all_hidden_states
[
spec_indices
]
=
target_hidden_states
all_hidden_states
[
spec_indices
]
=
target_hidden_states
return
all_tokens
,
all_probs
,
all_logprobs
,
all_hidden_states
return
all_tokens
,
all_probs
,
all_logprobs
,
all_hidden_states
def
_contract_batch_all_spec
(
self
,
target_sampler_output
:
SamplerOutput
,
proposals
:
SpeculativeProposals
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
It assumes all sequences in the batch were previously expanded.
"""
# Map distinct sequences used to score each token
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs
,
k
=
proposals
.
proposal_token_ids
.
shape
# Reshape tensors to original batch size
target_token_ids
=
target_sampler_output
.
sampled_token_ids
.
reshape
(
contracted_bs
,
k
+
1
)
target_probs
=
target_sampler_output
.
sampled_token_probs
.
reshape
(
*
target_token_ids
.
shape
,
self
.
_vocab_size
)
target_logprobs
=
target_sampler_output
.
logprobs
.
reshape
(
target_probs
.
shape
)
target_hidden_states
=
target_sampler_output
.
hidden_states
if
target_hidden_states
is
not
None
:
target_hidden_states
=
target_hidden_states
.
reshape
(
*
target_token_ids
.
shape
,
target_hidden_states
.
shape
[
-
1
])
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
def
_create_scoring_model_input
(
def
_create_scoring_model_input
(
self
,
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
...
@@ -345,8 +382,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -345,8 +382,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
token_chunk_size
=
1
,
token_chunk_size
=
1
,
)
)
@
staticmethod
def
_split_scoring_output
(
def
_split_scoring_output
(
self
,
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
sampler_output
:
SamplerOutput
,
num_scoring_tokens
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
...
@@ -361,10 +399,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -361,10 +399,9 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
#
#
# First samples are from speculative scoring, latter samples are non-
# First samples are from speculative scoring, latter samples are non-
# speculative samples.
# speculative samples.
split_sizes
=
[
split_sizes
=
(
num_scoring_tokens
,
num_scoring_tokens
,
sampler_output
.
sampled_token_ids
.
numel
()
-
sampler_output
.
sampled_token_ids
.
numel
()
-
num_scoring_tokens
num_scoring_tokens
)
]
(
spec_probs
,
non_spec_probs
(
spec_probs
,
non_spec_probs
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
)
=
sampler_output
.
sampled_token_probs
.
split
(
split_sizes
)
(
spec_sampled_tokens
,
non_spec_sampled_tokens
(
spec_sampled_tokens
,
non_spec_sampled_tokens
...
@@ -382,32 +419,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -382,32 +419,13 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
else
:
else
:
spec_hidden_states
,
non_spec_hidden_states
=
None
,
None
spec_hidden_states
,
non_spec_hidden_states
=
None
,
None
# Convert scores to tensors.
return
(
spec_sampled_tokens
,
spec_probs
,
spec_logprobs
,
sampler_output
.
sampled_token_probs
=
spec_probs
spec_hidden_states
,
non_spec_sampled_tokens
,
non_spec_probs
,
sampler_output
.
sampled_token_ids
=
spec_sampled_tokens
non_spec_logprobs
,
non_spec_hidden_states
)
sampler_output
.
logprobs
=
spec_logprobs
sampler_output
.
hidden_states
=
spec_hidden_states
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
)
=
sampler_output_to_torch
([
sampler_output
],
True
)
# Convert non-speculative output tokens to tensors.
sampler_output
.
sampled_token_probs
=
non_spec_probs
sampler_output
.
sampled_token_ids
=
non_spec_sampled_tokens
sampler_output
.
logprobs
=
non_spec_logprobs
sampler_output
.
hidden_states
=
non_spec_hidden_states
(
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
=
sampler_output_to_torch
(
[
sampler_output
],
True
)
return
(
target_token_ids
,
target_probs
,
target_logprobs
,
target_hidden_states
,
non_spec_target_token_ids
,
non_spec_target_probs
,
non_spec_target_logprobs
,
non_spec_target_hidden_states
)
@
staticmethod
def
_create_target_seq_id_iterator
(
def
_create_target_seq_id_iterator
(
self
,
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
seq_ids
:
List
[
SeqId
])
->
Iterator
[
TargetSeqId
]:
"""Create an iterator for creating target sequence ids.
"""Create an iterator for creating target sequence ids.
Target sequence ids are distinct from sequence ids because we create a
Target sequence ids are distinct from sequence ids because we create a
distinct target sequence id for each proposal token to be scored.
distinct target sequence id for each proposal token to be scored.
...
@@ -417,8 +435,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -417,8 +435,8 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
"""
"""
return
count
(
start
=
max
(
seq_ids
)
+
1
)
return
count
(
start
=
max
(
seq_ids
)
+
1
)
@
staticmethod
def
_get_token_ids_to_score
(
def
_get_token_ids_to_score
(
self
,
full_spec_token_ids
:
List
[
TokenId
]
# shape: [k]
full_spec_token_ids
:
List
[
TokenId
]
# shape: [k]
)
->
List
[
List
[
TokenId
]]:
)
->
List
[
List
[
TokenId
]]:
"""Given an int tensor of proposal token ids, return a list of
"""Given an int tensor of proposal token ids, return a list of
...
@@ -439,8 +457,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
...
@@ -439,8 +457,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer):
empty_token_ids
:
List
[
TokenId
]
=
[]
empty_token_ids
:
List
[
TokenId
]
=
[]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
=
[
empty_token_ids
]
token_ids_to_score
.
extend
([
token_ids_to_score
.
extend
(
full_spec_token_ids
[:
i
+
1
]
full_spec_token_ids
[:
i
+
1
]
for
i
in
range
(
len
(
full_spec_token_ids
)))
for
i
in
range
(
len
(
full_spec_token_ids
))
])
return
token_ids_to_score
return
token_ids_to_score
vllm/spec_decode/draft_model_runner.py
View file @
0640f227
...
@@ -3,6 +3,7 @@ from typing import List, Optional
...
@@ -3,6 +3,7 @@ from typing import List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.sampler
import
SamplerOutput
try
:
try
:
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.attention.backends.flash_attn
import
FlashAttentionMetadata
...
@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -16,8 +17,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
PromptAdapterConfig
,
SchedulerConfig
)
PromptAdapterConfig
,
SchedulerConfig
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
MultiModalInputs
from
vllm.multimodal
import
MultiModalInputs
from
vllm.sequence
import
(
ExecuteModelRequest
,
IntermediateTensors
,
from
vllm.sequence
import
ExecuteModelRequest
,
IntermediateTensors
SamplerOutput
)
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUWithSamplingMetadata
,
ModelRunner
)
ModelRunner
)
...
...
vllm/spec_decode/medusa_worker.py
View file @
0640f227
...
@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
...
@@ -4,8 +4,8 @@ from typing import List, Optional, Set, Tuple
import
torch
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.
sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.
model_executor.layers.sampler
import
SamplerOutput
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.interfaces
import
SpeculativeProposals
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
from
vllm.spec_decode.top1_proposer
import
Top1Proposer
...
...
vllm/spec_decode/mlp_speculator_worker.py
View file @
0640f227
...
@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
...
@@ -3,8 +3,8 @@ from typing import List, Optional, Set, Tuple
import
torch
import
torch
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.
sequence
import
(
ExecuteModelRequest
,
SamplerOutput
,
from
vllm.
model_executor.layers.sampler
import
SamplerOutput
SequenceGroupMetadata
)
from
vllm.sequence
import
ExecuteModelRequest
,
SequenceGroupMetadata
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.multi_step_worker
import
MultiStepWorker
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
from
vllm.spec_decode.proposer_worker_base
import
NonLLMProposerWorkerBase
...
...
Prev
1
…
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment