Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
31f6b24f
Commit
31f6b24f
authored
Mar 26, 2025
by
zhuwenwen
Browse files
Merge remote-tracking branch 'mirror/v0.8.2' into v0.8.2-ori
parents
89d1dd57
25f560a6
Changes
88
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
498 additions
and
218 deletions
+498
-218
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+47
-0
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+230
-86
vllm/model_executor/models/utils.py
vllm/model_executor/models/utils.py
+11
-1
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+1
-1
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+3
-5
vllm/sampling_params.py
vllm/sampling_params.py
+2
-1
vllm/utils.py
vllm/utils.py
+3
-3
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+6
-4
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+1
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+4
-2
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+14
-20
vllm/v1/engine/logprobs.py
vllm/v1/engine/logprobs.py
+0
-1
vllm/v1/engine/output_processor.py
vllm/v1/engine/output_processor.py
+47
-22
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+32
-7
vllm/v1/metrics/stats.py
vllm/v1/metrics/stats.py
+4
-15
vllm/v1/outputs.py
vllm/v1/outputs.py
+19
-0
vllm/v1/sample/ops/utils.py
vllm/v1/sample/ops/utils.py
+0
-30
vllm/v1/sample/rejection_sampler.py
vllm/v1/sample/rejection_sampler.py
+72
-14
vllm/v1/sample/sampler.py
vllm/v1/sample/sampler.py
+1
-1
vllm/v1/spec_decode/utils.py
vllm/v1/spec_decode/utils.py
+1
-4
No files found.
vllm/model_executor/model_loader/weight_utils.py
View file @
31f6b24f
...
...
@@ -38,6 +38,14 @@ except (ImportError, OSError):
SafetensorsStreamer
=
runai_model_streamer
.
placeholder_attr
(
"SafetensorsStreamer"
)
try
:
from
fastsafetensors
import
SafeTensorsFileLoader
,
SingleGroup
except
ImportError
:
fastsafetensors
=
PlaceholderModule
(
"fastsafetensors"
)
SafeTensorsFileLoader
=
fastsafetensors
.
placeholder_attr
(
"SafeTensorsFileLoader"
)
SingleGroup
=
fastsafetensors
.
placeholder_attr
(
"SingleGroup"
)
logger
=
init_logger
(
__name__
)
# use system-level temp directory for file locks, so that multiple users
...
...
@@ -452,6 +460,45 @@ def runai_safetensors_weights_iterator(
yield
from
streamer
.
get_tensors
()
def fastsafetensors_weights_iterator(
    hf_weights_files: List[str],
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files
    using fastsafetensor library.

    The files are processed in consecutive batches of ``pg.size()`` files.
    Within a batch, the i-th file is assigned to rank ``i`` of the process
    group, the batch is copied to the device, and every tensor it contains
    is yielded as a ``(name, tensor)`` pair before the batch is released.

    Args:
        hf_weights_files: Paths of the safetensors files to load.
        use_tqdm_on_load: Forwarded to ``enable_tqdm`` to decide whether a
            progress bar is displayed.

    Yields:
        ``(key, tensor)`` for every key reported by the loaded file buffer.
    """
    # Use the global process group when torch.distributed is initialised,
    # otherwise fall back to fastsafetensors' single-process group.
    if torch.distributed.is_initialized():
        pg = torch.distributed.group.WORLD
    else:
        pg = SingleGroup()

    # NOTE(review): the group rank is used directly as the CUDA device index;
    # confirm this matches the local device numbering in multi-node setups.
    device = torch.device(f'cuda:{pg.rank()}')

    group_size = pg.size()
    batches = [
        hf_weights_files[start:start + group_size]
        for start in range(0, len(hf_weights_files), group_size)
    ]

    progress = tqdm(
        batches,
        desc="Loading safetensors using Fastsafetensor loader",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    )
    for batch in progress:
        loader = SafeTensorsFileLoader(pg, device)
        # One file per rank: rank index -> single-element list of paths.
        loader.add_filenames(
            {rank: [path]
             for rank, path in enumerate(batch)})
        try:
            file_buffer = loader.copy_files_to_device()
            try:
                # Snapshot the keys before yielding so iteration does not
                # depend on the buffer's mapping while the consumer runs.
                for key in list(file_buffer.key_to_rank_lidx.keys()):
                    yield key, file_buffer.get_tensor(key)
            finally:
                file_buffer.close()
        finally:
            # Always release the loader, even if the consumer stops early
            # or copying to the device fails.
            loader.close()
def
pt_weights_iterator
(
hf_weights_files
:
List
[
str
],
use_tqdm_on_load
:
bool
,
...
...
vllm/model_executor/models/transformers.py
View file @
31f6b24f
...
...
@@ -15,21 +15,25 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import
re
from
itertools
import
chain
from
typing
import
Iterable
,
Literal
,
Optional
,
Union
import
torch
from
torch
import
nn
from
transformers
import
AutoModel
,
PreTrainedModel
from
transformers
import
AutoModel
,
PretrainedConfig
,
PreTrainedModel
from
transformers.modeling_utils
import
ALL_ATTENTION_FUNCTIONS
from
vllm.attention
import
Attention
from
vllm.config
import
VllmConfig
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
ModelConfig
,
ParallelConfig
,
VllmConfig
)
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed.utils
import
get_pp_indices
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
...
...
@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
.interfaces
import
SupportsLoRA
,
SupportsQuant
from
.utils
import
maybe_prefix
from
.interfaces
import
SupportsLoRA
,
SupportsPP
,
SupportsQuant
from
.utils
import
(
PPMissingLayer
,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
maybe_prefix
)
logger
=
init_logger
(
__name__
)
...
...
@@ -53,7 +58,7 @@ def vllm_flash_attention_forward(
# Transformers kwargs
scaling
:
Optional
[
float
]
=
None
,
# vLLM kwargs
attention_instances
:
Optional
[
lis
t
[
Attention
]]
=
None
,
attention_instances
:
Optional
[
dic
t
[
Attention
]]
=
None
,
**
kwargs
):
self_attn
=
attention_instances
[
module
.
layer_idx
]
if
scaling
is
not
None
:
...
...
@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
def
replace_linear_class
(
linear
:
nn
.
Linear
,
style
:
Literal
[
"colwise"
,
"rowwise"
],
quant_config
=
None
)
->
Union
[
ColumnParallelLinear
,
RowParallelLinear
]:
linear
:
nn
.
Linear
,
style
:
Literal
[
"colwise"
,
"rowwise"
],
quant_config
:
QuantizationConfig
)
->
Union
[
ColumnParallelLinear
,
RowParallelLinear
]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
`quant_config` is not yet supported.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
...
...
@@ -105,7 +109,7 @@ def replace_linear_class(
)
class
TransformersModel
(
nn
.
Module
,
SupportsQuant
,
SupportsLoRA
):
class
TransformersModel
(
nn
.
Module
,
SupportsQuant
,
SupportsLoRA
,
SupportsPP
):
embedding_padding_modules
=
[
"lm_head"
]
embedding_modules
=
[
"embed_tokens"
]
# TODO transformers will have a util to get it
...
...
@@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
super
().
__init__
()
logger
.
info
(
"Using Transformers backend."
)
config
=
vllm_config
.
model_config
.
hf_config
cache_config
=
vllm_config
.
cache_config
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
config
:
PretrainedConfig
=
vllm_config
.
model_config
.
hf_config
cache_config
:
CacheConfig
=
vllm_config
.
cache_config
device_config
:
DeviceConfig
=
vllm_config
.
device_config
model_config
:
ModelConfig
=
vllm_config
.
model_config
parallel_config
:
ParallelConfig
=
vllm_config
.
parallel_config
quant_config
:
QuantizationConfig
=
vllm_config
.
quant_config
self
.
config
=
config
self
.
cache_config
=
cache_config
self
.
device_config
=
device_config
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
quant_config
=
quant_config
self
.
vocab_size
=
model_config
.
get_vocab_size
()
self
.
unpadded_vocab_size
=
model_config
.
get_vocab_size
()
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
self
.
config
,
attn_implementation
=
"vllm"
,
torch_dtype
=
vllm_config
.
model_config
.
dtype
,
trust_remote_code
=
vllm_config
.
model_config
.
trust_remote_code
,
)
self
.
pp_group
=
get_pp_group
()
self
.
pp_size
=
self
.
pp_group
.
world_size
self
.
pp_rank
=
self
.
pp_group
.
rank_in_group
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
# Use meta device to delay allocating GPU tensors
with
torch
.
device
(
"meta"
):
self
.
model
:
PreTrainedModel
=
AutoModel
.
from_config
(
config
,
attn_implementation
=
"vllm"
,
torch_dtype
=
model_config
.
dtype
,
trust_remote_code
=
model_config
.
trust_remote_code
,
)
prefix
=
self
.
model
.
base_model_prefix
# MLP modifications
self
.
apply_base_model_tp_plan
(
self
.
model
)
self
.
pipeline_parallel
()
self
.
tensor_parallel
()
# Input embeddings
if
not
isinstance
(
self
.
model
.
get_input_embeddings
(),
PPMissingLayer
):
self
.
model
.
set_input_embeddings
(
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
org_num_embeddings
=
config
.
vocab_size
,
quant_config
=
quant_config
,
))
# Attention layers
self
.
attention_instances
=
self
.
create_attention_instances
()
# Output embeddings
if
not
isinstance
(
getattr
(
self
,
"lm_head"
,
None
),
PPMissingLayer
):
self
.
unpadded_vocab_size
=
config
.
vocab_size
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
),
)
if
config
.
tie_word_embeddings
:
self
.
lm_head
=
self
.
lm_head
.
tie_weights
(
self
.
model
.
get_input_embeddings
())
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
# Initialize buffers (e.g. rotary embedding inverse frequency)
self
.
init_buffers
(
self
.
model
)
# Move remaining meta tensors to device (should happen last)
self
.
meta_to_empty
(
self
.
model
)
self
.
sampler
=
get_sampler
()
self
.
make_empty_intermediate_tensors
=
(
make_empty_intermediate_tensors_factory
([
"hidden_states"
],
config
.
hidden_size
))
def
pipeline_parallel
(
self
):
"""
Apply the model's pipeline parallelization plan.
"""
if
self
.
pp_size
<=
1
:
return
# Attention modifications (assumes 1 attention op per hidden layer)
num_heads
=
model_config
.
get_num_attention_heads
(
parallel_config
)
head_size
=
model_config
.
get_head_size
()
num_kv_heads
=
model_config
.
get_num_kv_heads
(
parallel_config
)
self
.
attention_instances
=
[
if
not
self
.
model
.
supports_pp_plan
:
raise
ValueError
(
f
"
{
type
(
self
.
model
)
}
does not support pipeline parallel yet!"
)
module_lists
=
[]
module_list_idx
=
None
pp_plan
=
list
(
self
.
model
.
_pp_plan
.
keys
())
for
i
,
name
in
enumerate
(
pp_plan
):
if
isinstance
(
getattr
(
self
.
model
,
name
),
nn
.
ModuleList
):
module_lists
.
append
(
name
)
module_list_idx
=
i
if
len
(
module_lists
)
>
1
:
raise
ValueError
(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!"
)
if
module_list_idx
is
None
:
raise
ValueError
(
f
"Could not find `ModuleList` in
{
type
(
self
.
model
)
}
"
)
# Layers before module list
for
name
in
pp_plan
[:
module_list_idx
]:
if
self
.
pp_group
.
is_first_rank
or
(
self
.
config
.
tie_word_embeddings
and
self
.
pp_group
.
is_last_rank
):
continue
setattr
(
self
.
model
,
name
,
PPMissingLayer
())
# Module list
start_layer
,
end_layer
=
get_pp_indices
(
self
.
config
.
num_hidden_layers
,
self
.
pp_rank
,
self
.
pp_size
)
layers_name
=
pp_plan
[
module_list_idx
]
layers
=
getattr
(
self
.
model
,
layers_name
)
for
i
in
range
(
len
(
layers
)):
if
start_layer
<=
i
and
i
<
end_layer
:
continue
layers
[
i
]
=
PPMissingLayer
(
return_tuple
=
True
)
# Layers after module list
for
name
in
pp_plan
[
module_list_idx
+
1
:]:
# Modules that should be on last rank
if
not
self
.
pp_group
.
is_last_rank
:
setattr
(
self
.
model
,
name
,
PPMissingLayer
())
if
not
self
.
pp_group
.
is_last_rank
:
self
.
lm_head
=
PPMissingLayer
()
def tensor_parallel(self):
    """
    Apply the model's tensor parallelization plan.
    Currently only supports linear layers.

    Every `nn.Linear` whose qualified name matches a pattern of the plan is
    replaced (in place, on its parent module) by the vLLM linear class
    returned by `replace_linear_class` for the pattern's style.

    Raises:
        ValueError: If `self.tp_size > 1` but the config does not define a
            `base_model_tp_plan`.
    """
    if self.tp_size > 1 and self.config.base_model_tp_plan is None:
        raise ValueError(
            f"{type(self.model)} does not support tensor parallel yet!")

    tp_plan = self.model._tp_plan
    # The check above looks at the config, but the plan that is iterated
    # comes from the model. Without tensor parallelism the model may expose
    # no plan at all (None or empty); there is nothing to replace then, and
    # calling `.items()` on None would crash.
    if not tp_plan:
        return

    def _tensor_parallel(module: nn.Module, prefix: str = ""):
        for child_name, child_module in module.named_children():
            qual_name = maybe_prefix(prefix, child_name)
            if isinstance(child_module, nn.Linear):
                for pattern, style in tp_plan.items():
                    if re.match(pattern, qual_name):
                        new_module = replace_linear_class(
                            child_module, style, self.quant_config)
                        setattr(module, child_name, new_module)
                        log_replacement(qual_name, child_module,
                                        new_module)
            # Always descend into the original child. The inner loop never
            # breaks, so this is the explicit form of the previous
            # `for ... else` construct.
            _tensor_parallel(child_module, prefix=qual_name)

    _tensor_parallel(self.model)
def
create_attention_instances
(
self
)
->
dict
[
int
,
Attention
]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
num_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
)
head_size
=
self
.
model_config
.
get_head_size
()
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
)
start
,
end
=
get_pp_indices
(
self
.
config
.
num_hidden_layers
,
self
.
pp_rank
,
self
.
pp_size
)
return
{
i
:
Attention
(
num_heads
=
num_heads
,
head_size
=
head_size
,
...
...
@@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
# Transformers, it's updated in vllm_flash_attention_forward
scale
=
head_size
**-
0.5
,
num_kv_heads
=
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
self
.
cache_config
,
quant_config
=
self
.
quant_config
,
prefix
=
f
"
{
i
}
.attn"
)
for
i
in
range
(
config
.
num_hidden_layers
)
]
# Model modifications
self
.
replace_vocab_embed_class
(
self
.
model
)
# ForCausalLM modifications
self
.
lm_head
=
ParallelLMHead
(
self
.
vocab_size
,
config
.
hidden_size
,
quant_config
=
self
.
quant_config
,
prefix
=
maybe_prefix
(
prefix
,
"lm_head"
))
if
config
.
tie_word_embeddings
:
self
.
lm_head
.
weight
=
self
.
model
.
get_input_embeddings
().
weight
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
self
.
vocab_size
,
logit_scale
)
self
.
sampler
=
get_sampler
()
prefix
=
f
"
{
i
}
.attn"
)
for
i
in
range
(
start
,
end
)
}
def
apply_base_model_tp_plan
(
self
,
module
:
nn
.
Module
,
prefix
:
str
=
""
):
def
init_buffers
(
self
,
module
:
nn
.
Module
):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
If a `buffer` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
This means that:
- `type(module)` is a class from `transformers`
- This class is constructed using a `PretrainedConfig`
"""
if
(
self
.
config
.
base_model_tp_plan
is
None
and
get_tensor_model_parallel_world_size
()
>
1
):
raise
ValueError
(
"Trying to run tensor parallelization but the model does not "
"support it yet!"
)
for
child_name
,
child_module
in
module
.
named_children
():
qual_name
=
maybe_prefix
(
prefix
,
child_name
)
for
pattern
,
style
in
self
.
config
.
base_model_tp_plan
.
items
():
if
re
.
match
(
pattern
,
qual_name
)
and
isinstance
(
child_module
,
nn
.
Linear
):
new_module
=
replace_linear_class
(
child_module
,
style
,
self
.
quant_config
)
setattr
(
module
,
child_name
,
new_module
)
log_replacement
(
qual_name
,
child_module
,
new_module
)
else
:
self
.
apply_base_model_tp_plan
(
child_module
,
prefix
=
qual_name
)
def
replace_vocab_embed_class
(
self
,
module
:
nn
.
Module
):
# Use native set input embeddings
new_module
=
VocabParallelEmbedding
(
self
.
vocab_size
,
self
.
config
.
hidden_size
,
org_num_embeddings
=
self
.
vocab_size
,
quant_config
=
None
,
)
log_replacement
(
"input embedding"
,
self
.
model
.
get_input_embeddings
(),
new_module
)
module
.
set_input_embeddings
(
new_module
)
for
name
,
buffer
in
module
.
named_buffers
(
recurse
=
False
):
if
buffer
.
device
==
torch
.
device
(
"meta"
):
new_buffer
=
getattr
(
type
(
module
)(
self
.
config
),
name
)
setattr
(
module
,
name
,
new_buffer
)
for
child
in
module
.
children
():
self
.
init_buffers
(
child
)
def meta_to_empty(self, module: nn.Module):
    """
    Move sub-trees of `module` that live entirely on the meta device to
    `self.device_config.device` as uninitialised (empty) tensors.

    A module is converted with `to_empty` only when it owns at least one
    buffer or parameter and all of them (including those of its
    descendants) are on the meta device. Otherwise the search continues in
    its children.
    """
    meta_device = torch.device("meta")
    tensors = [*module.buffers(), *module.parameters()]
    if tensors and all(tensor.device == meta_device for tensor in tensors):
        # `to_empty` is itself recursive, so the whole sub-tree is handled
        # here and there is no need to descend any further.
        module.to_empty(device=self.device_config.device)
        return
    for submodule in module.children():
        self.meta_to_empty(submodule)
def
forward
(
self
,
input_ids
:
torch
.
Tensor
,
input_ids
:
Optional
[
torch
.
Tensor
]
,
positions
:
torch
.
Tensor
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
model_output
=
self
.
model
(
input_ids
[
None
,
...],
if
not
get_pp_group
().
is_first_rank
:
assert
intermediate_tensors
is
not
None
input_ids
=
None
inputs_embeds
=
intermediate_tensors
[
"hidden_states"
]
if
input_ids
is
not
None
:
input_ids
=
input_ids
[
None
,
...]
if
inputs_embeds
is
not
None
:
inputs_embeds
=
inputs_embeds
[
None
,
...]
hidden_states
=
self
.
model
(
input_ids
=
input_ids
,
inputs_embeds
=
inputs_embeds
,
use_cache
=
False
,
position_ids
=
positions
[
None
,
...],
intermediate_tensors
=
intermediate_tensors
,
attention_instances
=
self
.
attention_instances
,
return_dict
=
False
)[
0
][
0
,
...]
# we remove batch dimension for now
return
model_output
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
return
hidden_states
def
compute_logits
(
self
,
...
...
@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
params_dict
=
dict
(
self
.
named_parameters
())
loaded_params
=
set
[
str
]()
for
name
,
loaded_weight
in
weights
:
if
name
not
in
params_dict
:
name
=
f
"
{
self
.
model
.
base_model_prefix
}
.
{
name
}
"
# Necessary for some models which use remote code
if
not
name
.
startswith
(
prefix
:
=
self
.
model
.
base_model_prefix
):
name
=
maybe_prefix
(
prefix
,
name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
in
params_dict
:
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
...
...
vllm/model_executor/models/utils.py
View file @
31f6b24f
...
...
@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity):
def __init__(self, *args, **kwargs):
    """
    Create the placeholder layer.

    Positional arguments are accepted but not used. The only keyword that
    is read is `return_tuple` (default `False`), which controls whether
    `forward` wraps its result in a one-element tuple.
    """
    super().__init__()
    wrap_in_tuple = kwargs.get("return_tuple", False)
    self.return_tuple = wrap_in_tuple
def forward(self, *args, **kwargs):
    """
    Return the first arg from args or the first value from kwargs.

    Positional arguments take precedence over keyword arguments.
    Wraps the input in a tuple if `self.return_tuple` is True.
    """
    if args:
        first = args[0]
    else:
        first = next(iter(kwargs.values()))

    if self.return_tuple:
        return (first, )
    return first
_CPU_OFFLOAD_BYTES
=
0
...
...
@@ -650,4 +660,4 @@ def cast_overflow_tensors(
if
tensors
.
isinf
().
any
()
or
tensors
.
isnan
().
any
():
clamp_value
=
torch
.
finfo
(
tensors
.
dtype
).
max
-
offset
tensors
=
torch
.
clamp
(
tensors
,
min
=-
clamp_value
,
max
=
clamp_value
)
return
tensors
\ No newline at end of file
return
tensors
vllm/platforms/cpu.py
View file @
31f6b24f
...
...
@@ -92,7 +92,7 @@ class CpuPlatform(Platform):
if
kv_cache_space
==
0
:
cache_config
.
cpu_kvcache_space_bytes
=
4
*
GiB_bytes
# type: ignore
logger
.
warning
(
"Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
"Environment variable VLLM_CPU_KVCACHE_SPACE (G
i
B) "
"for CPU backend is not set, using 4 by default."
)
else
:
cache_config
.
cpu_kvcache_space_bytes
=
kv_cache_space
*
GiB_bytes
# type: ignore # noqa
...
...
vllm/platforms/cuda.py
View file @
31f6b24f
...
...
@@ -14,7 +14,6 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import
vllm._C
# noqa
import
vllm.envs
as
envs
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.utils
import
import_pynvml
...
...
@@ -258,7 +257,7 @@ class CudaPlatformBase(Platform):
try
:
import
vllm.vllm_flash_attn
# noqa: F401
from
vllm.attention.backends.flash_attn
import
(
# noqa: F401
FlashAttentionBackend
)
FlashAttentionBackend
,
flash_attn_supports_fp8
)
supported_sizes
=
\
FlashAttentionBackend
.
get_supported_head_sizes
()
...
...
@@ -269,10 +268,9 @@ class CudaPlatformBase(Platform):
target_backend
=
_Backend
.
XFORMERS
fp8_kv_cache
=
(
kv_cache_dtype
is
not
None
and
kv_cache_dtype
.
startswith
(
"fp8"
))
if
(
fp8_kv_cache
and
get_
flash_attn_
version
()
!=
3
):
if
(
fp8_kv_cache
and
not
flash_attn_
supports_fp8
()
):
logger
.
info
(
"Cannot use FlashAttention-2 backend for FP8 KV cache."
)
"Cannot use FlashAttention backend for FP8 KV cache."
)
logger
.
warning
(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
...
...
vllm/sampling_params.py
View file @
31f6b24f
...
...
@@ -369,8 +369,9 @@ class SamplingParams(
self
.
top_k
=
-
1
self
.
min_p
=
0.0
self
.
_verify_greedy_sampling
()
# eos_token_id is added to this by the engine
self
.
_all_stop_token_ids
=
set
(
self
.
stop_token_ids
)
self
.
_all_stop_token_ids
.
update
(
self
.
stop_token_ids
)
def
_verify_args
(
self
)
->
None
:
if
not
isinstance
(
self
.
n
,
int
):
...
...
vllm/utils.py
View file @
31f6b24f
...
...
@@ -37,7 +37,7 @@ from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
from
dataclasses
import
dataclass
,
field
from
functools
import
cache
,
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Generic
,
Literal
,
NamedTuple
,
Optional
,
TypeVar
,
Union
)
Optional
,
Type
,
TypeVar
,
Union
)
from
uuid
import
uuid4
import
cloudpickle
...
...
@@ -1544,9 +1544,9 @@ class LazyDict(Mapping[str, T], Generic[T]):
return
len
(
self
.
_factory
)
class
ClassRegistry
(
UserDict
[
t
ype
[
T
],
_V
]):
class
ClassRegistry
(
UserDict
[
T
ype
[
T
],
_V
]):
def
__getitem__
(
self
,
key
:
t
ype
[
T
])
->
_V
:
def
__getitem__
(
self
,
key
:
T
ype
[
T
])
->
_V
:
for
cls
in
key
.
mro
():
if
cls
in
self
.
data
:
return
self
.
data
[
cls
]
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
31f6b24f
...
...
@@ -11,10 +11,11 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.vllm_flash_attn.fa_utils
import
(
flash_attn_supports_fp8
,
get_flash_attn_version
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -182,9 +183,6 @@ class FlashAttentionImpl(AttentionImpl):
else
:
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"FlashAttention V1 with FP8 KV cache not yet supported"
)
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap
=
0
...
...
@@ -206,6 +204,10 @@ class FlashAttentionImpl(AttentionImpl):
"are not implemented for "
"FlashAttentionImpl"
)
self
.
vllm_flash_attn_version
=
get_flash_attn_version
()
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
)
\
and
not
flash_attn_supports_fp8
():
raise
NotImplementedError
(
"FlashAttention does not support fp8 kv-cache on this device."
)
def
forward
(
self
,
...
...
vllm/v1/attention/backends/mla/common.py
View file @
31f6b24f
...
...
@@ -196,7 +196,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata
,
MLAAttentionImpl
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.fa_utils
import
get_flash_attn_version
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
...
...
@@ -204,6 +203,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
,
round_down
from
vllm.vllm_flash_attn.fa_utils
import
get_flash_attn_version
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
...
vllm/v1/core/sched/scheduler.py
View file @
31f6b24f
...
...
@@ -627,8 +627,7 @@ class Scheduler(SchedulerInterface):
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
# Transmit partial if chunked prefill & prompt logprobs is enabled
if
new_token_ids
or
prompt_logprobs_tensors
is
not
None
:
if
new_token_ids
:
# Add EngineCoreOutput for this Request.
outputs
.
append
(
EngineCoreOutput
(
...
...
@@ -639,6 +638,9 @@ class Scheduler(SchedulerInterface):
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
()))
else
:
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
self
.
scheduled_req_ids
.
remove
(
request
.
request_id
)
if
not
stopped
:
...
...
vllm/v1/engine/async_llm.py
View file @
31f6b24f
...
...
@@ -21,14 +21,15 @@ from vllm.lora.request import LoRARequest
from
vllm.outputs
import
RequestOutput
from
vllm.pooling_params
import
PoolingParams
from
vllm.prompt_adapter.request
import
PromptAdapterRequest
from
vllm.sampling_params
import
RequestOutputKind
,
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer_group
import
init_tokenizer_from_configs
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
Device
,
cdiv
,
kill_process_tree
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine.core_client
import
EngineCoreClient
from
vllm.v1.engine.output_processor
import
OutputProcessor
from
vllm.v1.engine.output_processor
import
(
OutputProcessor
,
RequestOutputCollector
)
from
vllm.v1.engine.parallel_sampling
import
ParentRequest
from
vllm.v1.engine.processor
import
Processor
from
vllm.v1.executor.abstract
import
Executor
...
...
@@ -176,11 +177,14 @@ class AsyncLLM(EngineClient):
trace_headers
:
Optional
[
Mapping
[
str
,
str
]]
=
None
,
prompt_adapter_request
:
Optional
[
PromptAdapterRequest
]
=
None
,
priority
:
int
=
0
,
)
->
asyncio
.
Queue
[
RequestOutput
]
:
)
->
RequestOutput
Collector
:
"""Add new request to the AsyncLLM."""
# Create a new output queue for the request.
queue
:
asyncio
.
Queue
[
RequestOutput
]
=
asyncio
.
Queue
()
assert
isinstance
(
params
,
SamplingParams
),
\
"Pooling is not supported in V1"
# Create a new output collector for the request.
queue
=
RequestOutputCollector
(
output_kind
=
params
.
output_kind
)
# Convert Input --> Request.
request
=
self
.
processor
.
process_inputs
(
request_id
,
prompt
,
params
,
...
...
@@ -189,17 +193,15 @@ class AsyncLLM(EngineClient):
prompt_adapter_request
,
priority
)
n
=
params
.
n
if
isinstance
(
params
,
SamplingParams
)
else
1
if
n
==
1
:
if
params
.
n
==
1
:
await
self
.
_add_request
(
request
,
None
,
0
,
queue
)
return
queue
# Fan out child requests (for n>1).
parent_request
=
ParentRequest
(
request_id
,
params
)
for
idx
in
range
(
n
):
for
idx
in
range
(
params
.
n
):
request_id
,
params
=
parent_request
.
get_child_info
(
idx
)
child_request
=
request
if
idx
==
n
-
1
else
copy
(
request
)
child_request
=
request
if
idx
==
params
.
n
-
1
else
copy
(
request
)
child_request
.
request_id
=
request_id
child_request
.
sampling_params
=
params
await
self
.
_add_request
(
child_request
,
parent_request
,
idx
,
queue
)
...
...
@@ -207,7 +209,7 @@ class AsyncLLM(EngineClient):
async
def
_add_request
(
self
,
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
],
index
:
int
,
queue
:
asyncio
.
Queue
[
RequestOutput
]
):
queue
:
RequestOutput
Collector
):
# Add the request to OutputProcessor (this process).
self
.
output_processor
.
add_request
(
request
,
parent_req
,
index
,
queue
)
...
...
@@ -272,15 +274,7 @@ class AsyncLLM(EngineClient):
while
not
finished
:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out
=
q
.
get_nowait
()
if
not
q
.
empty
()
else
await
q
.
get
()
# Coalesce any additional queued outputs
while
not
q
.
empty
():
next_out
=
q
.
get_nowait
()
if
sampling_params
.
output_kind
==
RequestOutputKind
.
DELTA
:
out
.
add
(
next_out
)
else
:
out
=
next_out
out
=
q
.
get_nowait
()
or
await
q
.
get
()
# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
...
...
vllm/v1/engine/logprobs.py
View file @
31f6b24f
...
...
@@ -115,7 +115,6 @@ class LogprobsProcessor:
num_prompt_tokens
,
num_logprobs
=
logprobs
.
shape
# Pythonize the torch tensors.
# TODO(rob): experiment with doing this in EngineCore?
prompt_token_ranks
=
ranks
.
tolist
()
prompt_logprobs
=
logprobs
.
tolist
()
token_ids
=
token_ids
.
tolist
()
...
...
vllm/v1/engine/output_processor.py
View file @
31f6b24f
...
...
@@ -17,6 +17,46 @@ from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
RequestStateStats
)
class RequestOutputCollector:
    """
    Collects streamed RequestOutputs per individual request,
    for hand-off to the consuming asyncio generate task.

    When streaming deltas, RequestOutputs are merged if the
    producer gets ahead of the consumer.
    """

    def __init__(self, output_kind: RequestOutputKind):
        # Only DELTA outputs are merged; any other kind keeps just the
        # most recent output.
        self.aggregate = output_kind == RequestOutputKind.DELTA
        # The single pending output, or None when nothing is waiting.
        self.output: Optional[RequestOutput] = None
        # Set while `self.output` holds an output that was not consumed.
        self.ready = asyncio.Event()

    def put(self, output: RequestOutput) -> None:
        """Store `output`, merging with or replacing a pending one."""
        if self.output is None:
            self.output = output
            self.ready.set()
            return
        if self.aggregate:
            # Coalesce the outputs in delta case.
            self.output.add(output)
            return
        # Just replace latest in non-delta case.
        self.output = output

    async def get(self) -> RequestOutput:
        """Wait until an output is pending, then take and return it."""
        while self.output is None:
            await self.ready.wait()
        pending, self.output = self.output, None
        self.ready.clear()
        return pending

    def get_nowait(self) -> Optional[RequestOutput]:
        """Take and return the pending output, or None if there is none."""
        pending = self.output
        if pending is None:
            return None
        self.output = None
        self.ready.clear()
        return pending
@
dataclass
class
OutputProcessorOutput
:
...
...
@@ -39,7 +79,7 @@ class RequestState:
detokenizer
:
IncrementalDetokenizer
,
max_tokens_param
:
Optional
[
int
],
arrival_time
:
float
,
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]
],
queue
:
Optional
[
RequestOutput
Collector
],
log_stats
:
bool
,
):
self
.
request_id
=
request_id
...
...
@@ -66,7 +106,7 @@ class RequestState:
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
],
request_index
:
int
,
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]
],
queue
:
Optional
[
RequestOutput
Collector
],
log_stats
:
bool
,
)
->
"RequestState"
:
if
not
request
.
sampling_params
.
detokenize
:
...
...
@@ -105,9 +145,7 @@ class RequestState:
finished
=
finish_reason
is
not
None
final_only
=
self
.
output_kind
==
RequestOutputKind
.
FINAL_ONLY
# In follow up, we will switch to invariant where EngineCore
# does not stream partial prefills.
if
not
finished
and
(
self
.
is_prefilling
or
final_only
):
if
not
finished
and
final_only
:
# Only the final output is required in FINAL_ONLY mode.
return
None
...
...
@@ -219,7 +257,7 @@ class OutputProcessor:
request
:
EngineCoreRequest
,
parent_req
:
Optional
[
ParentRequest
]
=
None
,
request_index
:
int
=
0
,
queue
:
Optional
[
asyncio
.
Queue
[
RequestOutput
]
]
=
None
,
queue
:
Optional
[
RequestOutput
Collector
]
=
None
,
)
->
None
:
request_id
=
request
.
request_id
if
request_id
in
self
.
request_states
:
...
...
@@ -285,19 +323,7 @@ class OutputProcessor:
finish_reason
=
engine_core_output
.
finish_reason
stop_reason
=
engine_core_output
.
stop_reason
# TODO(andy): prompt logprobs + chunked prefill can
# result in engine core returning an output for a
# partial prefill (in order to send back partial
# prompt logprobs.) This breaks the invariant that
# process_outputs is only operating on engine core
# outputs associated with non-partial completions.
# Currently this is handled by having `is_prefilling`
# check for new decoded tokens, indicating that
# the completion is not partial.
#
# Follow up will aggregate partial prompt logprobs
# in the EngineCore.
req_state
.
is_prefilling
=
not
new_token_ids
req_state
.
is_prefilling
=
False
# 2) Detokenize the token ids into text and perform stop checks.
stop_string
=
req_state
.
detokenizer
.
update
(
...
...
@@ -306,8 +332,7 @@ class OutputProcessor:
finish_reason
=
FinishReason
.
STOP
stop_reason
=
stop_string
# 3) Compute sample and prompt logprobs for request,
# if required.
# 3) Compute sample and prompt logprobs for request, if required.
req_state
.
logprobs_processor
.
update_from_output
(
engine_core_output
)
# 4) Create and handle RequestOutput objects.
...
...
@@ -315,7 +340,7 @@ class OutputProcessor:
new_token_ids
,
finish_reason
,
stop_reason
):
if
req_state
.
queue
is
not
None
:
# AsyncLLM: put into queue for handling by generate().
req_state
.
queue
.
put
_nowait
(
request_output
)
req_state
.
queue
.
put
(
request_output
)
else
:
# LLMEngine: return list of RequestOutputs.
request_outputs
.
append
(
request_output
)
...
...
vllm/v1/engine/processor.py
View file @
31f6b24f
...
...
@@ -4,7 +4,6 @@ import time
from
collections.abc
import
Mapping
from
typing
import
Optional
,
Union
import
vllm.platforms
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
PromptType
,
SingletonInputsAdapter
)
...
...
@@ -20,7 +19,10 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from
vllm.sampling_params
import
SamplingParams
from
vllm.transformers_utils.tokenizer_group
import
BaseTokenizerGroup
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.structured_output.utils
import
validate_structured_output_request
from
vllm.v1.structured_output.backend_guidance
import
(
validate_guidance_grammar
)
from
vllm.v1.structured_output.utils
import
(
validate_structured_output_request_xgrammar
)
class
Processor
:
...
...
@@ -120,7 +122,9 @@ class Processor:
if
not
params
.
guided_decoding
or
not
self
.
decoding_config
:
return
supported_backends
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
]
supported_backends
=
[
"xgrammar"
,
"xgrammar:disable-any-whitespace"
,
"guidance"
,
"auto"
]
engine_level_backend
=
self
.
decoding_config
.
guided_decoding_backend
if
engine_level_backend
not
in
supported_backends
:
raise
ValueError
(
f
"Only
{
supported_backends
}
structured output is "
...
...
@@ -134,10 +138,31 @@ class Processor:
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
if
vllm
.
platforms
.
current_platform
.
is_tpu
():
raise
ValueError
(
"Structured output is not supported on TPU."
)
validate_structured_output_request
(
params
)
# Request content validation
if
engine_level_backend
==
"xgrammar"
:
# xgrammar with no fallback
validate_structured_output_request_xgrammar
(
params
)
params
.
guided_decoding
.
backend
=
"xgrammar"
elif
engine_level_backend
==
"auto"
:
# "auto" is an opt-in to opinionated behavior where we try to
# choose a backend based on request contents. This is not the
# default as it is less predictable and subject to change
# between releases as feature support changes.
try
:
validate_structured_output_request_xgrammar
(
params
)
params
.
guided_decoding
.
backend
=
"xgrammar"
except
ValueError
:
# The request includes some jsonschema feature(s) that
# are not supported in xgrammar. Fall back to guidance.
params
.
guided_decoding
.
backend
=
"guidance"
if
params
.
guided_decoding
.
backend
==
"guidance"
:
# TODO ideally we would have the LLTokenizer here as Lark syntax
# allows <|special_token|> and similar, see
# https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
# Without tokenizer these are disallowed in grammars.
validate_guidance_grammar
(
params
,
tokenizer
=
None
)
def
process_inputs
(
self
,
...
...
vllm/v1/metrics/stats.py
View file @
31f6b24f
...
...
@@ -100,15 +100,8 @@ class IterationStats:
num_new_generation_tokens
=
len
(
output
.
new_token_ids
)
self
.
num_generation_tokens
+=
num_new_generation_tokens
if
is_prefilling
and
num_new_generation_tokens
>
0
:
# TODO(andy): we used to assert that num_new_generation_tokens
# > 0 with an invariant that EngineCore does not stream outputs
# for partially completed prefills (scheduler.update_from_output
# makes EngineCoreOutput iff num_computed_tokens == num_tokens).
# When prompt logprobs are enabled, we currently stream out the
# partially completed prompt.
# This will be reverted in a follow up PR and we should re-enable
# this assertion / invariant.
if
is_prefilling
:
assert
num_new_generation_tokens
>
0
self
.
num_prompt_tokens
+=
prompt_len
first_token_latency
=
self
.
_time_since
(
req_stats
.
arrival_time
)
...
...
@@ -123,16 +116,12 @@ class IterationStats:
# Process the batch-level "new tokens" engine core event
if
is_prefilling
:
# TODO: re-enable no-output-for-partial-prefills invariant as above
if
num_new_generation_tokens
>
0
:
req_stats
.
first_token_ts
=
engine_core_timestamp
req_stats
.
first_token_ts
=
engine_core_timestamp
else
:
tpot
=
engine_core_timestamp
-
req_stats
.
last_token_ts
self
.
time_per_output_tokens_iter
.
append
(
tpot
)
# TODO: re-enable no-output-for-partial-prefills invariant as above
if
num_new_generation_tokens
>
0
:
req_stats
.
last_token_ts
=
engine_core_timestamp
req_stats
.
last_token_ts
=
engine_core_timestamp
def
update_from_events
(
self
,
req_id
:
str
,
events
:
list
[
"EngineCoreEvent"
],
is_prefilling
:
bool
,
req_stats
:
RequestStateStats
,
...
...
vllm/v1/outputs.py
View file @
31f6b24f
...
...
@@ -39,6 +39,25 @@ class LogprobsTensors(NamedTuple):
self
.
selected_token_ranks
.
tolist
(),
)
@staticmethod
def empty_cpu(num_positions: int,
              num_tokens_per_position: int) -> "LogprobsTensors":
    """Allocate uninitialized LogprobsTensors on the CPU.

    Args:
        num_positions: Size of the first dimension of every tensor.
        num_tokens_per_position: Size of the second dimension of the
            token-id and logprob tensors.

    Returns:
        LogprobsTensors whose `logprob_token_ids` (int32) and `logprobs`
        (float32) have shape (num_positions, num_tokens_per_position) and
        whose `selected_token_ranks` (int32) has shape (num_positions,).
        The contents are uninitialized (torch.empty).
    """
    shape = (num_positions, num_tokens_per_position)
    token_ids = torch.empty(shape, dtype=torch.int32, device="cpu")
    return LogprobsTensors(
        logprob_token_ids=token_ids,
        # Same shape and device as the token ids, but float32.
        logprobs=torch.empty_like(token_ids, dtype=torch.float32),
        selected_token_ranks=torch.empty(num_positions,
                                         dtype=torch.int32,
                                         device="cpu"),
    )
@
dataclass
class
SamplerOutput
:
...
...
vllm/v1/sample/ops/utils.py
deleted
100644 → 0
View file @
89d1dd57
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Union
import
torch
def compiled_softmax(
    logits: torch.Tensor,
    temperature: Union[float, torch.Tensor] = 1.0,
) -> torch.Tensor:
    """Faster softmax kernel generated by torch.compile.

    Args:
        logits: [n, vocab_size]
        temperature: [n] or float
    """
    # NOTE(woosuk): Avoid recompilation by marking the first dim as dynamic.
    # A float temperature has no dimension to mark, so only tensors qualify.
    dynamic_inputs = [logits]
    if isinstance(temperature, torch.Tensor):
        dynamic_inputs.append(temperature)
    for tensor in dynamic_inputs:
        torch._dynamo.mark_dynamic(tensor, index=0)
    return _softmax(logits, temperature)
@torch.compile
def _softmax(
    logits: torch.Tensor,
    temperature: Union[float, torch.Tensor],
) -> torch.Tensor:
    """Divide logits by temperature and softmax over the last dim.

    The result is always float32, whatever the dtype of `logits`.
    """
    scaled = logits / temperature
    return scaled.softmax(dim=-1, dtype=torch.float32)
vllm/v1/sample/rejection_sampler.py
View file @
31f6b24f
...
...
@@ -8,7 +8,7 @@ import triton.language as tl
from
vllm.logger
import
init_logger
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.sample.ops.
utils
import
compiled_softmax
from
vllm.v1.sample.ops.
topk_topp_sampler
import
apply_top_k_top_p
from
vllm.v1.spec_decode.metadata
import
SpecDecodeMetadata
logger
=
init_logger
(
__name__
)
...
...
@@ -67,6 +67,7 @@ class RejectionSampler(nn.Module):
Shape is [num_tokens, vocab_size]. Here, probabilities from
different requests are flattened into a single tensor because
this is the shape of the output logits.
NOTE: `target_logits` can be updated in place to save memory.
bonus_token_ids_tensor (torch.Tensor):
A tensor containing bonus tokens. Shape is [batch_size, 1].
Bonus tokens are added to the end of the sequence if all
...
...
@@ -83,6 +84,8 @@ class RejectionSampler(nn.Module):
'''
assert
metadata
.
max_spec_len
<=
MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
# `compute_probs` function.
target_probs
=
compute_probs
(
target_logits
,
metadata
.
cu_num_draft_tokens
,
...
...
@@ -245,25 +248,80 @@ def compute_probs(
return
logits
num_tokens
=
logits
.
shape
[
0
]
batch_size
=
cu_num_draft_tokens
.
shape
[
0
]
expanded_temperature
=
torch
.
empty
(
(
num_tokens
,
1
),
dtype
=
torch
.
float32
,
device
=
logits
.
device
,
)
expand_kernel
[(
batch_size
,
)](
expanded_temperature
,
temperature
=
expand_batch_to_tokens
(
sampling_metadata
.
temperature
,
cu_num_draft_tokens
,
GREEDY_TEMPERATURE
,
# replace_from
1
,
# replace_to
MAX_NUM_TOKENS
=
MAX_SPEC_LEN
,
num_warps
=
1
,
num_tokens
,
replace_from
=
GREEDY_TEMPERATURE
,
replace_to
=
1
,
)
output_prob
=
compiled_softmax
(
logits
,
expanded_temperature
)
# NOTE(woosuk): Update `logits` in place to avoid allocating a new tensor.
logits
.
div_
(
temperature
.
unsqueeze
(
-
1
))
# Get expanded top_k and top_p tensors.
top_k
=
None
if
sampling_metadata
.
top_k
is
not
None
:
top_k
=
expand_batch_to_tokens
(
sampling_metadata
.
top_k
,
cu_num_draft_tokens
,
num_tokens
,
)
top_p
=
None
if
sampling_metadata
.
top_p
is
not
None
:
top_p
=
expand_batch_to_tokens
(
sampling_metadata
.
top_p
,
cu_num_draft_tokens
,
num_tokens
,
)
# NOTE(woosuk): `apply_top_k_top_p` uses sorting to calculate the mask,
# which is slow for large vocab sizes. This may cause performance issues.
logits
=
apply_top_k_top_p
(
logits
,
top_k
,
top_p
)
output_prob
=
logits
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float32
)
return
output_prob
def expand_batch_to_tokens(
    x: torch.Tensor,  # [batch_size]
    cu_num_tokens: torch.Tensor,  # [batch_size]
    num_tokens: int,
    replace_from: int = 0,
    replace_to: int = 0,
) -> torch.Tensor:
    """Repeat each per-request value of `x` once per token of that request.

    For example, if x = [a, b, c] and cu_num_tokens = [2, 5, 6], then
    num_tokens = 6, and expanded_x = [a, a, b, b, b, c].

    Args:
        x: [batch_size] tensor to expand.
        cu_num_tokens: [batch_size] tensor containing the cumulative number
            of tokens per batch. Each element represents the total number
            of tokens up to and including that batch.
        num_tokens: Total number of tokens.
        replace_from: Value to be replaced if it is found in x.
        replace_to: Value to replace with when replace_from is found.

    Returns:
        expanded_x: [num_tokens] tensor.
    """
    num_reqs = x.shape[0]
    assert cu_num_tokens.shape[0] == num_reqs
    # Output buffer with x's dtype and device; filled by the kernel below.
    per_token = x.new_empty(num_tokens)
    # Launch with a grid of num_reqs programs.
    launch = expand_kernel[(num_reqs, )]
    launch(
        per_token,
        x,
        cu_num_tokens,
        replace_from,
        replace_to,
        MAX_NUM_TOKENS=MAX_SPEC_LEN,  # To avoid recompilation.
        num_warps=1,
    )
    return per_token
def
generate_uniform_probs
(
num_tokens
:
int
,
num_draft_tokens
:
list
[
int
],
...
...
vllm/v1/sample/sampler.py
View file @
31f6b24f
...
...
@@ -137,7 +137,7 @@ class Sampler(nn.Module):
Gather logprobs for topk and sampled/prompt token.
Args:
log
it
s: (num tokens) x (vocab) tensor
log
prob
s: (num tokens) x (vocab) tensor
num_logprobs: minimum number of logprobs to
retain per token
token_ids: prompt tokens (if prompt logprobs)
...
...
vllm/v1/spec_decode/utils.py
View file @
31f6b24f
...
...
@@ -3,10 +3,7 @@ from vllm.v1.worker.gpu_input_batch import InputBatch
def
is_spec_decode_supported
(
req_id
:
str
,
input_batch
:
InputBatch
)
->
bool
:
if
req_id
in
input_batch
.
top_k_reqs
or
req_id
in
input_batch
.
top_p_reqs
:
# Spec decode doesn't support top_p/top_k sampling.
return
False
elif
req_id
in
input_batch
.
min_p_reqs
:
if
req_id
in
input_batch
.
min_p_reqs
:
# Spec decode doesn't support min_p sampling.
return
False
elif
(
req_id
in
input_batch
.
frequency_penalties_reqs
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment