Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cfabf125
Commit
cfabf125
authored
Aug 27, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
parents
dbd0bda6
645fcfd9
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
438 additions
and
1632 deletions
+438
-1632
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+6
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+0
-1544
vllm/two_batch_overlap/two_batch_overlap.py
vllm/two_batch_overlap/two_batch_overlap.py
+2
-1
vllm/two_batch_overlap/v1/model_input_split_v1.py
vllm/two_batch_overlap/v1/model_input_split_v1.py
+38
-15
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
+4
-8
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
vllm/zero_overhead/v1/eagle.py
vllm/zero_overhead/v1/eagle.py
+317
-0
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+70
-62
No files found.
vllm/compilation/decorators.py
View file @
cfabf125
...
@@ -9,9 +9,10 @@ import torch
...
@@ -9,9 +9,10 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
torch._dynamo.symbolic_convert
import
InliningInstructionTranslator
from
vllm
import
envs
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.forward_context
import
get_profilling
from
vllm.forward_context
import
get_forward_context
,
get_profilling
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -170,6 +171,10 @@ def _support_torch_compile(
...
@@ -170,6 +171,10 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
skip_cuda_graphs
=
get_forward_context
().
skip_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
skip_cuda_graphs
:
return
self
.
forward
(
*
args
,
**
kwargs
)
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
():
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
...
...
vllm/model_executor/model_loader/loader.py
deleted
100644 → 0
View file @
dbd0bda6
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import
collections
import
copy
import
dataclasses
import
fnmatch
import
glob
import
inspect
import
itertools
import
math
import
os
import
time
import
warnings
from
abc
import
ABC
,
abstractmethod
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
)
import
gguf
import
huggingface_hub
import
numpy
as
np
import
torch
from
huggingface_hub
import
HfApi
from
torch
import
nn
from
transformers
import
AutoModelForCausalLM
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.attention
import
Attention
from
vllm.config
import
(
LoadConfig
,
LoadFormat
,
ModelConfig
,
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
)
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.logger
import
init_logger
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
MergedColumnParallelLinear
,
QKVCrossParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
RowParallelLinear
)
# yapf: enable
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizeMethodBase
)
from
vllm.model_executor.model_loader.tensorizer
import
(
TensorizerConfig
,
is_vllm_tensorized
,
load_with_tensorizer
,
serialize_vllm_model
,
tensorizer_weights_iterator
)
from
vllm.model_executor.model_loader.utils
import
(
ParamMapping
,
configure_quant_config
,
get_model_architecture
,
set_default_torch_dtype
)
from
vllm.model_executor.model_loader.weight_utils
import
(
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
fastsafetensors_weights_iterator
,
filter_duplicate_safetensors_files
,
filter_files_not_needed_for_inference
,
get_gguf_extra_tensor_names
,
get_lock
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
np_cache_weights_iterator
,
pt_weights_iterator
,
runai_safetensors_weights_iterator
,
safetensors_weights_iterator
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.transformers_utils.s3_utils
import
glob
as
s3_glob
from
vllm.transformers_utils.utils
import
is_s3
from
vllm.utils
import
is_pin_memory_available
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily host a module's CPU parameters on ``target_device``.

    On entry, every parameter of ``module`` whose device type is ``"cpu"``
    is moved to ``target_device``. On exit, those same parameters are copied
    back into newly allocated CPU tensors, which are pinned when
    ``is_pin_memory_available()`` returns true. Parameters that were not on
    the CPU at entry, and parameters that first appear while the context is
    active, are left alone. If ``target_device`` is itself a CPU device the
    context is a no-op.

    Args:
        module: Module whose parameters may need to be relocated.
        target_device: Device the parameters should be on inside the
            ``with`` body.

    Yields:
        The same ``module`` object that was passed in.
    """
    if target_device.type == "cpu":
        # Nothing to relocate when the destination is already the CPU.
        yield module
        return

    # Name -> device the parameter was on before we moved it.
    moved_from: Dict[str, torch.device] = {}
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Anything not on the CPU is left where it is.
            continue
        moved_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        # Put back only what we moved; parameters created inside the
        # context are not in `moved_from` and are therefore skipped.
        pin_memory = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            if param_name not in moved_from:
                continue
            home_device: torch.device = moved_from[param_name]
            if home_device.type != "cpu":
                param.data = param.data.to(home_device)
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            staged = torch.empty_strided(
                size=param.data.size(),
                stride=param.data.stride(),
                dtype=param.data.dtype,
                layout=param.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            staged.copy_(param.data)
            param.data = staged
# Module-level logger used by the loader helpers and classes in this file.
logger = init_logger(__name__)
def _initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: Optional[type[nn.Module]] = None,
) -> nn.Module:
    """Instantiate the model class described by ``vllm_config``.

    Args:
        vllm_config: Aggregate configuration; its ``model_config`` selects
            the architecture when ``model_class`` is not supplied.
        prefix: Passed through to the model constructor when it accepts a
            ``prefix`` argument.
        model_class: Explicit class to build. When ``None`` it is resolved
            with ``get_model_architecture(model_config)``.

    Returns:
        The constructed model, created while ``vllm_config`` is installed
        as the current config.

    A class whose ``__init__`` accepts both ``vllm_config`` and ``prefix``
    is built directly. Any other class is treated as old-style: a
    ``DeprecationWarning`` is issued and the constructor receives whichever
    of the legacy keyword arguments its signature names.
    """
    model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    # The keys of `Signature.parameters` are the parameter names, in order.
    accepted = list(inspect.signature(model_class.__init__).parameters)

    if "vllm_config" in accepted and "prefix" in accepted:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )

    # try to be compatible with old-style model class: pass only the
    # arguments the constructor actually declares.
    legacy_kwargs = {}
    if "prefix" in accepted:
        legacy_kwargs["prefix"] = prefix
    if "config" in accepted:
        legacy_kwargs["config"] = model_config.hf_config
    for config_name in ("cache_config", "quant_config", "lora_config",
                        "scheduler_config", "parallel_config"):
        if config_name in accepted:
            legacy_kwargs[config_name] = getattr(vllm_config, config_name)

    with set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(**legacy_kwargs)
def _process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                   target_device: torch.device) -> None:
    """Run the post-load weight hooks of every submodule of ``model``.

    Two passes are made over ``model.named_modules()``:

    1. ``QKVCrossParallelLinear`` modules have their own hook called
       directly; every other module whose ``quant_method`` attribute is a
       ``QuantizeMethodBase`` has that method's hook called inside
       ``device_loading_context``.
    2. ``Attention`` modules exposing ``process_weights_after_loading``
       are called with ``model_config.dtype``.

    Args:
        model: Model whose weights have already been loaded.
        model_config: Supplies the dtype handed to attention modules.
        target_device: Device that quantization hooks expect parameters on.
    """
    for _, submodule in model.named_modules():
        if isinstance(submodule, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            submodule.process_weights_after_loading()
            continue

        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue
        # When quant methods need to process weights after loading
        # (for repacking, quantizing, etc), they expect parameters
        # to be on the global target device. This scope is for the
        # case where cpu offloading is used, where we will move the
        # parameters onto device for processing and back off after.
        with device_loading_context(submodule, target_device):
            method.process_weights_after_loading(submodule)

    # Currently only used by MLA.
    # NOTE: This intentionally happens after other modules so we can easily
    # decompress the weights for MLA.
    for _, submodule in model.named_modules():
        if not isinstance(submodule, Attention):
            continue
        if hasattr(submodule, "process_weights_after_loading"):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            submodule.process_weights_after_loading(model_config.dtype)
class BaseModelLoader(ABC):
    """Abstract interface shared by every model loader.

    Subclasses must provide ``download_model`` and ``load_model``. The
    constructor only records the ``LoadConfig`` it is given.
    """

    def __init__(self, load_config: LoadConfig):
        # Kept as-is; subclasses read it for download and format settings.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_model(self, *, vllm_config: VllmConfig) -> nn.Module:
        """Load a model with the given configurations."""
        raise NotImplementedError
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk.

    Weight files are located (and downloaded when the model path is not a
    local directory), turned into an iterator of ``(name, tensor)`` pairs
    and handed to the model's ``load_weights`` method.
    """

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    # `time.perf_counter()` readings bracketing weight loading. 0.0 doubles
    # as the "not started yet" marker for the first counter.
    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is
                truthy; this loader takes no extra configuration.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope. When ``model`` already exists on disk it
        is returned unchanged instead of being downloaded."""
        if VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            if not os.path.exists(model):
                # Use file lock to prevent multiple processes from
                # downloading the same model weights at the same time.
                with get_lock(model, self.load_config.download_dir):
                    model_path = snapshot_download(
                        model_id=model,
                        cache_dir=self.load_config.download_dir,
                        local_files_only=huggingface_hub.constants.
                        HF_HUB_OFFLINE,
                        revision=revision,
                        ignore_file_pattern=self.load_config.ignore_patterns,
                    )
            else:
                model_path = model
            return model_path
        return None

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> Tuple[str, List[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: Model ID or local directory.
            revision: Optional revision forwarded to the download helpers.
            fall_back_to_pt: When true, ``"*.pt"`` is appended to the
                patterns chosen for the load format.
            allow_patterns_overrides: When not ``None``, replaces the
                computed pattern list entirely.

        Returns:
            ``(folder, weight_files, use_safetensors)``.

        Raises:
            ValueError: For a load format this loader does not handle.
            RuntimeError: If no weight file remains after filtering.
        """
        model_name_or_path = (self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif (load_format == LoadFormat.SAFETENSORS
              or load_format == LoadFormat.FASTSAFETENSORS):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in order; the first one that matches any file
        # wins and later patterns are not consulted.
        hf_weights_files: List[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Each yielded name is ``source.prefix`` followed by the name produced
        by the underlying file iterator. The first call also records
        ``counter_before_loading_weights``.
        """
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                weights_iterator = safetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
        else:
            weights_iterator = pt_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )

        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm

            def _xla_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    xm.mark_step()

            weights_iterator = _xla_weights_iterator(weights_iterator)

        elif current_platform.is_hpu():
            import habana_frameworks.torch.core as htcore

            # Same wrapping as the TPU branch, calling `htcore.mark_step`
            # after each yielded item.
            def _hpu_weights_iterator(iterator: Generator):
                for weights in iterator:
                    yield weights
                    htcore.mark_step()

            weights_iterator = _hpu_weights_iterator(weights_iterator)

        # Only the first source starts the timer; later calls see a
        # non-zero value and leave it alone.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield weights from the primary source, then any secondary ones.

        The primary source is built from ``model_config.model`` and
        ``model_config.revision``. Optional attributes on ``model``
        (``fall_back_to_pt_during_load``, ``allow_patterns_overrides``,
        ``secondary_weights``) adjust or extend what is loaded.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (and download if needed) the weights, discarding the result."""
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model, load its weights and return it in eval mode.

        Raises:
            ValueError: For a non-quantized model whose ``load_weights``
                returned a set of loaded names, if any parameter name is
                missing from that set.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self.get_all_weights(model_config, model))
            self.counter_after_loading_weights = time.perf_counter()
            logger.info(
                "Loading weights took %.2f seconds",
                self.counter_after_loading_weights -
                self.counter_before_loading_weights)
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}")

            _process_weights_after_loading(model, model_config, target_device)

        return model.eval()
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values.

    No checkpoint is read: the model is constructed, filled by
    ``initialize_dummy_weights`` and then post-processed like a normally
    loaded model.
    """

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``; extra loader configuration is rejected.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is
                truthy.
        """
        super().__init__(load_config)
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(f"Model loader extra config is not supported for "
                         f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing; dummy weights need no download."""
        pass  # Nothing to download

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Build the model with dummy weights and return it in eval mode."""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

            _process_weights_after_loading(model, model_config, target_device)
        return model.eval()
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config`` and derive ``self.tensorizer_config``.

        ``load_config.model_loader_extra_config`` is used directly when it
        is already a ``TensorizerConfig``; otherwise it is unpacked as
        keyword arguments into a new ``TensorizerConfig``.
        """
        super().__init__(load_config)
        extra = load_config.model_loader_extra_config
        if not isinstance(extra, TensorizerConfig):
            extra = TensorizerConfig(**extra)
        self.tensorizer_config = extra

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Check the tensorizer config against the model and parallel configs."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)
        config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self, ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the tensorizer weight iterator for the stored config."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args())

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/other/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        model_config = vllm_config.model_config
        device = vllm_config.device_config.device
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device):
                model = _initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def _load_model_serialized(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/other/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        model_config = vllm_config.model_config
        device = vllm_config.device_config.device

        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device):
                model_class = get_model_architecture(model_config)[0]

                # Work on a shallow copy so the stored config is not
                # modified by the per-load fields set below.
                per_load_config = copy.copy(self.tensorizer_config)
                per_load_config.model_class = model_class
                per_load_config.hf_config = model_config.hf_config
                per_load_config.dtype = model_config.dtype

                model = load_with_tensorizer(per_load_config,
                                             vllm_config=vllm_config)
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config and open (then close) the tensorizer stream."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Load the model through whichever tensorizer path applies.

        With tensor parallelism above 1, the stored ``tensorizer_uri`` is
        first ``%``-formatted with this worker's tensor-parallel rank.
        """
        model_config = vllm_config.model_config
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            uri_template = self.tensorizer_config.tensorizer_uri
            self.tensorizer_config.tensorizer_uri = (
                uri_template % get_tensor_model_parallel_rank())

        load = (self._load_model_serialized
                if is_vllm_tensorized(self.tensorizer_config) else
                self._load_model_serialized_cpu)
        return load(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig,
    ) -> None:
        """Serialize ``model`` with ``serialize_vllm_model``."""
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )
class
ShardedStateLoader
(
BaseModelLoader
):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/offline_inference/save_sharded_state.py` for creating a sharded
checkpoint.
"""
DEFAULT_PATTERN
=
"model-rank-{rank}-part-{part}.safetensors"
def
__init__
(
self
,
load_config
:
LoadConfig
,
runai_model_streamer
:
bool
=
False
):
super
().
__init__
(
load_config
)
self
.
runai_model_streamer
=
runai_model_streamer
extra_config
=
({}
if
load_config
.
model_loader_extra_config
is
None
else
load_config
.
model_loader_extra_config
.
copy
())
self
.
pattern
=
extra_config
.
pop
(
"pattern"
,
self
.
DEFAULT_PATTERN
)
if
extra_config
:
raise
ValueError
(
f
"Unexpected extra config keys for load format "
f
"
{
load_config
.
load_format
}
: "
f
"
{
load_config
.
model_loader_extra_config
.
keys
()
}
"
)
@
staticmethod
def
_filter_subtensors
(
tensors
:
Dict
[
str
,
torch
.
Tensor
],
)
->
Dict
[
str
,
torch
.
Tensor
]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups
:
Dict
[
Any
,
List
[
Tuple
[
str
,
torch
.
Tensor
]]]
=
(
collections
.
defaultdict
(
list
))
for
key
,
tensor
in
tensors
.
items
():
if
tensor
.
numel
():
ptr
=
tensor
.
untyped_storage
().
data_ptr
()
same_storage_groups
[
tensor
.
device
,
ptr
].
append
((
key
,
tensor
))
def
get_end_ptr
(
tensor
:
torch
.
Tensor
)
->
int
:
return
tensor
.
view
(
-
1
)[
-
1
].
data_ptr
()
+
tensor
.
element_size
()
result
:
Dict
[
str
,
torch
.
Tensor
]
=
{}
for
group
in
same_storage_groups
.
values
():
for
k
,
t
in
group
:
a
,
b
=
t
.
data_ptr
(),
get_end_ptr
(
t
)
for
k2
,
t2
in
group
:
if
not
t2
.
is_contiguous
():
continue
a2
,
b2
=
t2
.
data_ptr
(),
get_end_ptr
(
t2
)
if
a
<
a2
or
b2
<
b
:
continue
if
a2
<
a
or
b
<
b2
or
not
t
.
is_contiguous
():
break
# t2 covers strictly more memory than t.
if
k2
<
k
:
# Same tensors, keep the one with the smaller key.
break
else
:
result
[
k
]
=
t
return
result
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
revision
:
Optional
[
str
]):
if
is_s3
(
model_name_or_path
)
or
os
.
path
.
isdir
(
model_name_or_path
):
return
model_name_or_path
else
:
allow_patterns
=
[
"*.safetensors"
]
return
download_weights_from_hf
(
model_name_or_path
,
self
.
load_config
.
download_dir
,
allow_patterns
,
revision
,
ignore_patterns
=
self
.
load_config
.
ignore_patterns
,
)
def
download_model
(
self
,
model_config
:
ModelConfig
)
->
None
:
self
.
_prepare_weights
(
model_config
.
model
,
model_config
.
revision
)
def
load_model
(
self
,
vllm_config
:
VllmConfig
)
->
nn
.
Module
:
device_config
=
vllm_config
.
device_config
model_config
=
vllm_config
.
model_config
target_device
=
torch
.
device
(
device_config
.
device
)
from
vllm.distributed
import
get_tensor_model_parallel_rank
model_weights
=
model_config
.
model
if
hasattr
(
model_config
,
"model_weights"
):
model_weights
=
model_config
.
model_weights
local_model_path
=
model_weights
with
set_default_torch_dtype
(
model_config
.
dtype
):
with
target_device
:
model
=
_initialize_model
(
vllm_config
=
vllm_config
)
_process_weights_after_loading
(
model
,
model_config
,
target_device
)
rank
=
get_tensor_model_parallel_rank
()
pattern
=
os
.
path
.
join
(
local_model_path
,
self
.
pattern
.
format
(
rank
=
rank
,
part
=
"*"
),
)
filepaths
=
[]
if
is_s3
(
local_model_path
):
file_pattern
=
f
"*
{
self
.
pattern
.
format
(
rank
=
rank
,
part
=
' * '
)
}
"
filepaths
=
s3_glob
(
path
=
local_model_path
,
allow_pattern
=
[
file_pattern
])
else
:
filepaths
=
glob
.
glob
(
pattern
)
if
not
filepaths
:
# TODO: support un-sharded checkpoints too
raise
ValueError
(
f
"Could not find checkpoint files '
{
pattern
}
', only "
f
"pre-sharded checkpoints are currently supported!"
)
state_dict
=
self
.
_filter_subtensors
(
model
.
state_dict
())
for
key
,
tensor
in
self
.
iterate_over_files
(
filepaths
):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data
=
state_dict
[
key
].
data
param_shape
=
state_dict
[
key
].
shape
for
dim
,
size
in
enumerate
(
tensor
.
shape
):
if
size
<
param_shape
[
dim
]:
param_data
=
param_data
.
narrow
(
dim
,
0
,
size
)
if
tensor
.
shape
!=
param_shape
:
logger
.
warning
(
"loading tensor of shape %s into "
"parameter '%s' of shape %s"
,
tensor
.
shape
,
key
,
param_shape
,
)
param_data
.
copy_
(
tensor
)
state_dict
.
pop
(
key
)
if
state_dict
:
raise
ValueError
(
f
"Missing keys
{
tuple
(
state_dict
)
}
in loaded state!"
)
return
model
.
eval
()
def
iterate_over_files
(
self
,
paths
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
if
self
.
runai_model_streamer
:
yield
from
runai_safetensors_weights_iterator
(
paths
,
True
)
else
:
from
safetensors.torch
import
safe_open
for
path
in
paths
:
with
safe_open
(
path
,
framework
=
"pt"
)
as
f
:
for
key
in
f
.
keys
():
# noqa: SIM118
tensor
=
f
.
get_tensor
(
key
)
yield
key
,
tensor
@staticmethod
def save_model(
    model: torch.nn.Module,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Write this rank's state dict to *path* as safetensors shard files.

    Args:
        model: Module whose ``state_dict()`` is saved (after
            ``ShardedStateLoader._filter_subtensors``).
        path: Directory the shard files are written into.
        pattern: File-name template formatted with ``rank`` and ``part``.
            Defaults to ``ShardedStateLoader.DEFAULT_PATTERN``.
        max_size: Optional byte budget per file. A new part is started
            before a tensor that would push the current part over the
            budget. A single tensor larger than the budget still gets a
            part of its own.
    """
    from safetensors.torch import save_file

    from vllm.distributed import get_tensor_model_parallel_rank

    if pattern is None:
        pattern = ShardedStateLoader.DEFAULT_PATTERN
    rank = get_tensor_model_parallel_rank()

    def _write_part(part: Dict[str, torch.Tensor], index: int) -> None:
        # One file per (rank, part) pair, named from the template.
        filename = pattern.format(rank=rank, part=index)
        save_file(part, os.path.join(path, filename))

    part_idx = 0
    total_size = 0
    state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
    state_dict_part: Dict[str, torch.Tensor] = {}
    for key, tensor in state_dict.items():
        param_size = tensor.nelement() * tensor.element_size()
        # Only roll over when the current part already holds something;
        # otherwise an oversized tensor would emit an empty shard file.
        if (max_size is not None and state_dict_part
                and total_size + param_size > max_size):
            _write_part(state_dict_part, part_idx)
            part_idx += 1
            total_size = 0
            state_dict_part = {}
        state_dict_part[key] = tensor
        total_size += param_size
    if state_dict_part:
        _write_part(state_dict_part, part_idx)
class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitAndBytes quantization."""

    possible_config_file_names = ["adapter_config.json"]

    # Minimum supported bitsandbytes release, as a numeric tuple so that
    # comparisons are not lexicographic.
    _MIN_BNB_VERSION = (0, 45, 3)

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # Save the module names without sharding.
        self.unsharded_weights_modules: List[str] = []
        # Save the module names that are sharded by column.
        self.column_sharded_weights_modules: List[str] = []
        # Store all module names (from transformers) that support
        # BNB quantization.
        self.target_modules: List[str] = []
        # mapping weight names from transformers to vllm.
        self.weight_mapper: Callable = lambda name: name

    @staticmethod
    def _parse_version(version: str) -> Tuple[int, ...]:
        """Return the leading numeric components of a dotted version string.

        Parsing stops at the first component without leading digits, so
        ``"0.45.3.dev0"`` yields ``(0, 45, 3)`` and ``"0.46.0rc1"`` yields
        ``(0, 46, 0)``.
        """
        release: List[int] = []
        for piece in version.split("."):
            digits = "".join(itertools.takewhile(str.isdigit, piece))
            if not digits:
                break
            release.append(int(digits))
        return tuple(release)

    def _get_weight_files(
        self,
        model_name_or_path: str,
        allowed_patterns: List[str],
        revision: Optional[str] = None,
    ) -> Tuple[str, List[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Patterns are tried in order and the first one that matches wins.

        Returns:
            ``(folder, weight_files, matched_pattern)``.

        Raises:
            RuntimeError: If no pattern matches any file.
        """
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(
                    os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return model_name_or_path, weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return hf_folder, glob.glob(
                        os.path.join(hf_folder, pattern)), pattern

        raise RuntimeError(
            f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> Tuple[List[str], bool]:
        """Prepare weight files for the model.

        Returns:
            ``(weight_files, use_safetensors)``.

        Raises:
            RuntimeError: If no weight files remain after filtering.
        """
        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision)

        use_safetensors = matched_pattern == "*.safetensors"
        is_local = os.path.isdir(model_name_or_path)
        index_file = SAFE_WEIGHTS_INDEX_NAME
        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_weights_files, use_safetensors

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        """Yield ``(original_name, mapped_name, tensor)`` for each weight."""
        if use_safetensors:
            iterator = safetensors_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        else:
            iterator = pt_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        for org_name, param in iterator:
            # mapping weight names from transformers to vllm while preserving
            # original names.
            mapped_name = self.weight_mapper(org_name)
            yield org_name, mapped_name, param

    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        pre_quant: bool,
        load_8bit: bool,
    ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str,
                                                                     Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary.

        The returned dictionary is filled in by the generator as it is
        consumed.

        Raises:
            ImportError: If bitsandbytes is missing or older than 0.45.3.
        """
        # only load the bitsandbytes module when needed
        try:
            import bitsandbytes

            # Compare numerically: a plain string comparison would reject
            # e.g. "0.100.0" and accept "0.5.0".
            if (self._parse_version(bitsandbytes.__version__)
                    < self._MIN_BNB_VERSION):
                raise ImportError("bitsandbytes version is wrong. Please "
                                  "install bitsandbytes>=0.45.3.")
        except ImportError as err:
            raise ImportError("Please install bitsandbytes>=0.45.3 via "
                              "`pip install bitsandbytes>=0.45.3` to use "
                              "bitsandbytes quantizer.") from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision)

        quant_state_dict: Dict[str, Any] = {}

        if pre_quant:
            if load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors,
                    quant_state_dict), quant_state_dict

        return self._unquantized_generator(hf_weights_files, use_safetensors,
                                           quant_state_dict), quant_state_dict

    def _is_8bit_weight_name(self, weight_name: str):
        """Return True if *weight_name* ends with an 8-bit metadata suffix."""
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(suffix)
                   for suffix in quantized_suffix)

    def _is_4bit_weight_name(self, weight_name: str):
        """Return True if the last dotted component of *weight_name* contains
        one of the 4-bit quantization-state markers."""
        quantized_suffix = {
            "absmax",
            "quant_map",
            "nested_absmax",
            "nested_quant_map",
            "bitsandbytes",
        }
        suffix = weight_name.split(".")[-1]
        return any(q_suffix in suffix for q_suffix in quantized_suffix)

    def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield ``(original_name, tensor)`` for a pre-quantized 8-bit
        checkpoint, recording ``.scb`` tensors in *quant_state_dict*."""
        # First pass: collect the ".scb" tensors keyed by the matching
        # ".weight" name.
        for (
                _,
                mapped_weight_name,
                weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if not mapped_weight_name.lower().endswith(".scb"):
                continue

            weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
            quant_state_dict[weight_key] = weight_tensor

        # Second pass: yield every non-metadata weight, tagging those that
        # have a recorded quantization state.
        for (
                org_weight_name,
                mapped_weight_name,
                weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if self._is_8bit_weight_name(mapped_weight_name):
                continue

            if mapped_weight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
            yield org_weight_name, weight_tensor

    def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
                                  quant_state_dict) -> Generator:
        """Yield ``(original_name, tensor)`` for a pre-quantized 4-bit
        checkpoint, storing a ``QuantState`` per quantized weight in
        *quant_state_dict*."""
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights
        weight_iterator = self._hf_weight_iter(hf_weights_files,
                                               use_safetensors)
        temp_state_dict = {}
        for (
                _,
                mapped_weight_name,
                weight_tensor,
        ) in weight_iterator:
            if not self._is_4bit_weight_name(mapped_weight_name):
                continue
            # bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* in CPU
            if "quant_state.bitsandbytes" in mapped_weight_name:
                temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[mapped_weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight
        def _parse_quant_state(param_name: str,
                               temp_state_dict: Dict) -> QuantState:
            quant_state = {}
            for k in temp_state_dict:
                if param_name + "." in k:
                    quant_state[k] = temp_state_dict[k]

            return QuantState.from_dict(quant_state,
                                        device=current_platform.device_type)

        # Second iterate over all prequant and normal weights
        # pre quantized weights would have a quant_state
        for (
                org_weight_name,
                mapped_weight_name,
                weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if self._is_4bit_weight_name(mapped_weight_name):
                continue

            if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
                    in temp_state_dict) or (
                        f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
                        in temp_state_dict):
                quant_state = _parse_quant_state(mapped_weight_name,
                                                 temp_state_dict)
                quant_state_dict[mapped_weight_name] = quant_state
            yield org_weight_name, weight_tensor

    def _unquantized_generator(self, hf_weights_files, use_safetensors,
                               quant_state_dict) -> Generator:
        """Yield ``(original_name, tensor)`` for an unquantized checkpoint,
        slicing target weights for this tensor-parallel rank and quantizing
        them to NF4 on the fly.

        Non-target weights are yielded unchanged.
        """
        from bitsandbytes.functional import quantize_4bit

        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        for (
                org_weight_name,
                mapped_weight_name,
                weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if any(target_module in mapped_weight_name
                   for target_module in self.target_modules
                   ) and mapped_weight_name.endswith(".weight"):
                # Without sharding
                if any(
                        mapped_weight_name.startswith(module)
                        for module in self.unsharded_weights_modules):
                    weight_sub_tensor = weight_tensor
                # Shard by column
                elif any(
                        mapped_weight_name.startswith(module)
                        for module in self.column_sharded_weights_modules):
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[...,
                                                      start_index:end_index]
                # Weights have fused on disk. In this case, we assume that the
                # weight and module use same name.
                elif any(
                        mapped_weight_name.startswith(module)
                        for module in self.maybe_fused_weights_modules):
                    # special case for fused weights
                    # get the size of each shard weight tensor
                    total_shard_sizes = next(
                        (sizes for module, sizes in
                         self.maybe_fused_weights_modules.items()
                         if mapped_weight_name.startswith(module)))
                    total_size = weight_tensor.size(0)
                    assert total_size == sum(total_shard_sizes)
                    # get the start/end index of each shard weight tensor
                    total_start_index = list(
                        itertools.accumulate([0] + total_shard_sizes))[:-1]
                    shard_weights_index = [(
                        idx + size // tp_size * tp_rank,
                        idx + size // tp_size * (tp_rank + 1),
                    ) for idx, size in zip(total_start_index,
                                           total_shard_sizes)]
                    # slice and reorder the weight tensor
                    weight_shards = [
                        weight_tensor[start_index:end_index, ...]
                        for start_index, end_index in shard_weights_index
                    ]
                    weight_sub_tensor = torch.cat(weight_shards, dim=0)
                # Shard by row
                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[start_index:end_index,
                                                      ...]

                # bitsandbytes requires data in GPU
                if weight_sub_tensor.is_cuda:
                    loaded_weight = weight_sub_tensor
                else:
                    loaded_weight = weight_sub_tensor.cuda()

                # remove the following after the issue is fixed:
                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
                if not loaded_weight.is_contiguous():
                    loaded_weight = loaded_weight.contiguous()

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4",
                    )

                quant_state_dict[mapped_weight_name] = quant_state
            else:
                processed_weight = weight_tensor
            yield org_weight_name, processed_weight

    def _get_bnb_target_modules(self, model: nn.Module) -> None:
        """Populate ``self.target_modules`` from the ``LinearBase`` modules
        of *model*.

        Raises:
            AssertionError: If no target module is found.
        """
        for name, module in model.named_modules():
            if isinstance(module, (LinearBase, )):
                if modules_info := self.modules_mapping.get_sub_modules(name):
                    # Map vllm's names to transformers's names.
                    rep_name, sub_modules = modules_info
                    for sub_name in sub_modules:
                        self.target_modules.append(
                            name.replace(rep_name, sub_name))
                # Add original module name even if the module has stacked map,
                # in case model has a mixture of disk-merged and disk-splitted
                # weights with same last name.
                self.target_modules.append(name)

        # Both literals are parenthesised together so the model name is
        # actually part of the assertion message.
        assert self.target_modules, (
            "vllm currently does not support BNB quantization for"
            f" {type(model).__name__}")

    def _load_weights(self, model_config: ModelConfig,
                      model: nn.Module) -> None:
        """Load (and, if needed, quantize) weights into *model* and attach
        the bitsandbytes quantization state to its parameters.

        Raises:
            AttributeError: If *model* lacks ``load_weights`` or
                ``packed_modules_mapping``.
            ValueError: For an unsupported quantization method, a
                pre-quantized checkpoint combined with tensor parallelism,
                weights missing from the checkpoint, or a quantized
                parameter without ``pack_factor``.
        """
        if not hasattr(model, "load_weights"):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}.")

        if not hasattr(model, "packed_modules_mapping"):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet. No 'packed_modules_mapping' found.")

        self.modules_mapping = ParamMapping(
            copy.deepcopy(model.packed_modules_mapping))

        # For some models like Molmo, we need to use hf_to_vllm_mapper
        # to ensure correct loading of weights.
        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

        # Modules whose weights might have fused on disk
        # we need their output_sizes to make shard in flight correctly with TP
        self.maybe_fused_weights_modules: Dict[str, List[int]] = {}
        self._get_bnb_target_modules(model)

        for name, module in model.named_modules():
            # Some modules like `ReplicatedLinear` should not have their weights
            # sharded. The reason for implementing it this way is to avoid new
            # static variable in the model implementation.
            if isinstance(module, (ReplicatedLinear, )):
                self.unsharded_weights_modules.append(name)
            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
            # fused weights on disk. We need to use the output sizes of these
            # modules to shard the weights correctly.
            elif isinstance(module,
                            (QKVParallelLinear, MergedColumnParallelLinear)):
                self.maybe_fused_weights_modules[name] = module.output_sizes
            # In TP, these weights are partitioned along the column
            # dimension (dim=-1)
            elif isinstance(module, (RowParallelLinear, )):
                self.column_sharded_weights_modules.append(name)

        self.model_type = type(model).__name__

        logger.info("Loading weights with BitsAndBytes quantization. "
                    "May take a while ...")

        quant_config = getattr(model_config.hf_config, "quantization_config",
                               None)

        pre_quant = False
        if quant_config is not None:
            quant_method = quant_config.get("quant_method")
            if quant_method == "bitsandbytes":
                pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} "
                    "quantization")

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with tensor parallelism is not "
                "supported. Please try with pipeline parallelism.")

        load_8bit = False
        if pre_quant:
            load_8bit = quant_config.get("load_in_8bit", False)

        qweight_iterator, quant_state_dict = (
            self._get_quantized_weights_iterator(model_config.model,
                                                 model_config.revision,
                                                 pre_quant, load_8bit))

        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(qweight_iterator)
        # Some models may have weights loading tracker unimplemented.
        if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")

        torch.cuda.empty_cache()

        param_dict = dict(model.named_parameters())
        stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {}
        # TODO: Change this lazy import to normal import
        # after the checks are updated to run on a new version
        from vllm.model_executor.models.utils import is_pp_missing_parameter

        for quant_param_name in quant_state_dict:
            if is_pp_missing_parameter(quant_param_name, model):
                continue

            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                    weight_name,
                    index,
            ) in self.modules_mapping.inverse_packed_mapping.items():
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
                shard_pos = quant_param_name.find(shard_name)
                can_correct_rename = (shard_pos > 0) and (
                    quant_param_name[shard_pos - 1] == ".")
                # If the quant_param_name is packed, it won't occur in the
                # param_dict before renaming.
                new_quant_param_name = quant_param_name.replace(
                    shard_name, weight_name)
                need_rename = (quant_param_name not in param_dict) \
                              and (new_quant_param_name in param_dict)
                if can_correct_rename and need_rename:
                    shard_index = index
                    quant_param_name = new_quant_param_name
                    break

            # Models like Clip/Siglip may skip some layers in initialization,
            # causing unused quant_param_name in state_dict.
            if quant_param_name not in param_dict:
                continue

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = (
                quant_state_dict[non_stacked_param_name])

        # save quant_states and offsets as the attributes of the parameters
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                set_weight_attrs(param, {"bnb_quant_state": quant_states})

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(
                        f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = (math.prod(quant_state.shape) //
                                         pack_ratio)

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                # Make torch infer_schema happy
                offsets = torch.tensor(offsets).cpu()
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

                if load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)})

    def download_model(self, model_config: ModelConfig) -> None:
        """Locate (downloading if necessary) the weight files for the model."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialize the model, load its weights and return it in eval mode."""
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        # NOTE(review): nesting reconstructed from a flattened source;
        # confirm `_load_weights` belongs inside the device context.
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(vllm_config=vllm_config)

                self._load_weights(model_config, model)

        return model.eval()
class GGUFModelLoader(BaseModelLoader):
    """
    Loader for models stored as GGUF files, i.e. models quantized with GGUF
    and saved in the GGUF format. Both full and sharded models are handled.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra = load_config.model_loader_extra_config
        if extra:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return *model_name_or_path* if it is an existing file, else raise
        ``ValueError``."""
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build a dictionary from GGUF tensor names to HF parameter names.

        GGUF names tensors converted from an HF checkpoint as
        `blk.N.BB.weight` / `blk.N.BB.bias`, with N the block number of a
        layer and BB the attention/mlp component. See "Standardized tensor
        names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
        hf_config = model_config.hf_config
        gguf_model_type = hf_config.model_type
        mapping = {}

        # hack: ggufs have a different name than transformers
        if gguf_model_type == "cohere":
            gguf_model_type = "command-r"
        elif gguf_model_type in ("deepseek_v3", "deepseek_v2"):
            gguf_model_type = "deepseek2"
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for layer in range(hf_config.num_hidden_layers):
                mapping.update({
                    f"blk.{layer}.exp_probs_b.bias":
                    f"model.layers.{layer}.mlp.gate.e_score_correction_bias",
                    f"blk.{layer}.ffn_down_exps.weight":
                    f"model.layers.{layer}.mlp.experts.0.down_proj.weight",
                    f"blk.{layer}.ffn_gate_exps.weight":
                    f"model.layers.{layer}.mlp.experts.0.gate_proj.weight",
                    f"blk.{layer}.ffn_up_exps.weight":
                    f"model.layers.{layer}.mlp.experts.0.up_proj.weight",
                })

        # First architecture whose registered name equals the model type.
        arch = next((arch_key
                     for arch_key, arch_name in gguf.MODEL_ARCH_NAMES.items()
                     if arch_name == gguf_model_type), None)
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {gguf_model_type}")

        tensor_names = gguf.get_tensor_name_map(arch,
                                                hf_config.num_hidden_layers)
        # Instantiate on the "meta" device purely to enumerate parameter names.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(
                hf_config, trust_remote_code=model_config.trust_remote_code)

        for hf_name in dummy_model.state_dict():
            stem, suffix = hf_name.rsplit(".", 1)
            gguf_stem = tensor_names.get_name(stem)
            mapping[f"{gguf_stem}.{suffix}"] = hf_name
        return mapping

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the GGUF weight iterator for the given file and name map."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Validate that the configured model path is a file."""
        self._prepare_weights(model_config.model)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialize the model and load its weights from the GGUF file."""
        model_config = vllm_config.model_config
        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        extra_tensors = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in extra_tensors:
            model_config.hf_config.update({"tie_word_embeddings": True})

        target_device = torch.device(vllm_config.device_config.device)
        # NOTE(review): nesting reconstructed from a flattened source;
        # confirm weight loading sits outside the device context.
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            weights = self._get_weights_iterator(gguf_path, name_map)
            model.load_weights(weights)

            _process_weights_after_loading(model, model_config, target_device)
        return model
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Loader that streams safetensors files, either from the local
    file system or from an S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return

        # Forward integer-valued options to the streamer via its env vars.
        for option, env_name in (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        ):
            if option in extra_config and isinstance(
                    extra_config.get(option), int):
                os.environ[env_name] = str(extra_config.get(option))

        # NOTE(review): nesting reconstructed from a flattened source;
        # confirm this fallback belongs under the extra-config guard.
        streamer_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT')
        aws_endpoint = os.getenv('AWS_ENDPOINT_URL')
        if streamer_endpoint is None and aws_endpoint is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> List[str]:
        """Collect the safetensors files of the model.

        Anything that is neither a local directory nor an S3 path is
        downloaded first. Raises ``RuntimeError`` when no file is found."""
        is_s3_path = is_s3(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        needs_download = not (is_local or is_s3_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        if needs_download:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [safetensors_pattern],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        if is_s3_path:
            hf_weights_files = s3_glob(path=hf_folder,
                                       allow_pattern=[safetensors_pattern])
        else:
            hf_weights_files = glob.glob(
                os.path.join(hf_folder, safetensors_pattern))

        if needs_download:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir,
                revision)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")

        return hf_weights_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return a streaming iterator over the model's safetensors weights."""
        weight_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            weight_files,
            self.load_config.use_tqdm_on_load,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Fetch the model files if they are not already available."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_model(self, vllm_config: VllmConfig) -> nn.Module:
        """Initialize the model, stream its weights in and return it in
        eval mode."""
        model_config = vllm_config.model_config
        target_device = torch.device(vllm_config.device_config.device)
        # NOTE(review): nesting reconstructed from a flattened source;
        # confirm weight loading sits outside the device context.
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)

            # Prefer an explicit `model_weights` attribute when present.
            weights_source = getattr(model_config, "model_weights",
                                     model_config.model)
            model.load_weights(
                self._get_weights_iterator(weights_source,
                                           model_config.revision))

            _process_weights_after_loading(model, model_config, target_device)
        return model.eval()
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Return the model loader matching ``load_config.load_format``.

    A class given as the load format is instantiated directly. Otherwise
    the format is compared against the known ``LoadFormat`` members in
    order, and ``DefaultModelLoader`` is used when none matches.
    """
    if isinstance(load_config.load_format, type):
        return load_config.load_format(load_config)

    # Ordered (format, factory) pairs. Factories are lambdas so nothing is
    # constructed until a format actually matches.
    factories = (
        (LoadFormat.DUMMY, lambda: DummyModelLoader(load_config)),
        (LoadFormat.TENSORIZER, lambda: TensorizerLoader(load_config)),
        (LoadFormat.SHARDED_STATE, lambda: ShardedStateLoader(load_config)),
        (LoadFormat.BITSANDBYTES,
         lambda: BitsAndBytesModelLoader(load_config)),
        (LoadFormat.GGUF, lambda: GGUFModelLoader(load_config)),
        (LoadFormat.RUNAI_STREAMER,
         lambda: RunaiModelStreamerLoader(load_config)),
        (LoadFormat.RUNAI_STREAMER_SHARDED,
         lambda: ShardedStateLoader(load_config, runai_model_streamer=True)),
    )
    for known_format, make_loader in factories:
        if load_config.load_format == known_format:
            return make_loader()

    return DefaultModelLoader(load_config)
vllm/two_batch_overlap/two_batch_overlap.py
View file @
cfabf125
...
@@ -58,7 +58,8 @@ class TwoBatchOverlap():
...
@@ -58,7 +58,8 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
self
.
left_thread
.
join
()
...
...
vllm/two_batch_overlap/v1/model_input_split_v1.py
View file @
cfabf125
...
@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
...
@@ -9,6 +9,7 @@ from vllm.forward_context import set_forward_context
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.two_batch_overlap.v1.two_batch_overlap_v1
import
tbo_model_executable_v1
from
vllm.utils
import
async_tensor_h2d
from
vllm.utils
import
async_tensor_h2d
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadataBuilder
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.core.sched.output
import
CachedRequestData
,
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
...
@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
...
@@ -224,28 +225,45 @@ def prepare_tbo_atten_metadata(
# Prepare for cascade attention if enabled & beneficial.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
common_prefix_len
=
0
metadata_builder
=
runner
.
attn_metadata_builders
[
kv_cache_group_id
]
if
runner
.
cascade_attn_enabled
:
if
runner
.
cascade_attn_enabled
:
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
common_prefix_len
=
runner
.
_compute_cascade_attn_prefix_len
(
num_scheduled_tokens
,
num_scheduled_tokens
,
scheduler_output
.
scheduler_output
.
num_common_prefix_blocks
[
kv_cache_group_id
],
num_common_prefix_blocks
[
kv_cache_group_id
],
kv_cache_group_spec
.
kv_cache_spec
,
kv_cache_group_spec
.
kv_cache_spec
,
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
,
metadata_builder
,
)
)
if
req_offset
>
0
:
if
req_offset
>
0
:
origin_block_table
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
origin_block_table
=
metadata_builder
.
block_table
.
block_table
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
metadata_builder
.
block_table
.
block_table
=
origin_block_table
[
req_offset
:,
:]
origin_slot_mapping
=
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
origin_slot_mapping
=
metadata_builder
.
block_table
.
slot_mapping
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
block_table
.
slot_mapping
=
\
metadata_builder
.
block_table
.
slot_mapping
=
\
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
origin_slot_mapping
[
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
:]
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
_num_decodes_record
=
metadata_builder
.
_num_decodes
_num_prefills_record
=
metadata_builder
.
_num_prefills
_num_decode_tokens_record
=
metadata_builder
.
_num_decode_tokens
_num_prefill_tokens_record
=
metadata_builder
.
_num_prefill_tokens
metadata_builder
.
_num_decodes
=
0
metadata_builder
.
_num_prefills
=
num_reqs
metadata_builder
.
_num_decode_tokens
=
0
metadata_builder
.
_num_prefill_tokens
=
total_num_scheduled_tokens
attn_metadata_i
=
(
attn_metadata_i
=
(
runner
.
attn_
metadata_builder
s
[
kv_cache_group_id
]
.
build
(
metadata_builder
.
build
(
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
common_attn_metadata
=
common_attn_metadata
))
# maybe FlashAttentionMetadata
if
req_offset
>
0
:
if
req_offset
>
0
:
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
block_table
=
origin_block_table
metadata_builder
.
block_table
.
block_table
=
origin_block_table
runner
.
attn_metadata_builders
[
kv_cache_group_id
].
block_table
.
slot_mapping
=
origin_slot_mapping
metadata_builder
.
block_table
.
slot_mapping
=
origin_slot_mapping
if
isinstance
(
metadata_builder
,
MLACommonMetadataBuilder
):
# now support prefill only
metadata_builder
.
_num_decodes
=
_num_decodes_record
metadata_builder
.
_num_prefills
=
_num_prefills_record
metadata_builder
.
_num_decode_tokens
=
_num_decode_tokens_record
metadata_builder
.
_num_prefill_tokens
=
_num_prefill_tokens_record
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
attn_metadata
[
layer_name
]
=
attn_metadata_i
...
@@ -288,12 +306,16 @@ def tbo_split_and_execute_model(
...
@@ -288,12 +306,16 @@ def tbo_split_and_execute_model(
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
)
->
Union
[
ModelRunnerOutput
,
IntermediateTensors
]:
use_tbo
=
False
use_tbo
=
False
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
isinstance
(
runner
.
attn_metadata_builders
[
0
],
MLACommonMetadataBuilder
)
and
\
runner
.
attn_metadata_builders
[
0
].
_num_decodes
>
0
:
#is mla decode
use_tbo
=
False
else
:
if
len
(
scheduler_output
.
num_scheduled_tokens
)
>
1
:
split_scheduler_output
(
runner
,
scheduler_output
)
if
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
and
\
input_split
.
scheduler_output_right
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
use_tbo
=
True
if
use_tbo
:
if
use_tbo
:
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_left
=
input_split
.
scheduler_output_left
.
total_num_scheduled_tokens
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
num_input_tokens_right
=
num_input_tokens
-
num_input_tokens_left
...
@@ -319,7 +341,8 @@ def tbo_split_and_execute_model(
...
@@ -319,7 +341,8 @@ def tbo_split_and_execute_model(
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
runner
.
vllm_config
,
runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
num_tokens_across_dp
):
num_tokens_across_dp
=
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
runner
.
maybe_setup_kv_connector
(
scheduler_output
)
model_output
=
runner
.
model
(
model_output
=
runner
.
model
(
...
...
vllm/two_batch_overlap/v1/two_batch_overlap_v1.py
View file @
cfabf125
...
@@ -50,7 +50,8 @@ class TwoBatchOverlap():
...
@@ -50,7 +50,8 @@ class TwoBatchOverlap():
self
.
left_thread
.
start
()
self
.
left_thread
.
start
()
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
=
threading
.
Thread
(
target
=
self
.
thread_two_batch_overlap
,
args
=
(
self
.
model_input_right_queue
,))
self
.
right_thread
.
start
()
self
.
right_thread
.
start
()
logger
.
info
(
'tbo:two batch overlap start'
)
if
get_tp_group
().
rank
==
0
:
logger
.
info
(
'tbo:two batch overlap start'
)
def
finish_thread
(
self
):
def
finish_thread
(
self
):
self
.
left_thread
.
join
()
self
.
left_thread
.
join
()
...
@@ -71,7 +72,6 @@ class TwoBatchOverlap():
...
@@ -71,7 +72,6 @@ class TwoBatchOverlap():
init_tbo_forward_context
(
False
,
self
.
right_tid
)
init_tbo_forward_context
(
False
,
self
.
right_tid
)
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
with
torch
.
cuda
.
stream
(
tbo_step_stream
):
queue
.
get
()
queue
.
get
()
profile
.
ProfRangePush
(
'start'
)
self
.
tbo_thread_synchronize
(
tid
)
self
.
tbo_thread_synchronize
(
tid
)
if
is_left_thread
:
if
is_left_thread
:
attn_metadata
=
self
.
attn_metadata_left
attn_metadata
=
self
.
attn_metadata_left
...
@@ -90,7 +90,8 @@ class TwoBatchOverlap():
...
@@ -90,7 +90,8 @@ class TwoBatchOverlap():
with
set_forward_context
(
attn_metadata
,
with
set_forward_context
(
attn_metadata
,
self
.
model_runner
.
vllm_config
,
self
.
model_runner
.
vllm_config
,
num_tokens
=
num_input_tokens
,
num_tokens
=
num_input_tokens
,
num_tokens_across_dp
=
self
.
num_tokens_across_dp
):
num_tokens_across_dp
=
self
.
num_tokens_across_dp
,
skip_cuda_graphs
=
True
):
model_output
=
self
.
model_runner
.
model
(
model_output
=
self
.
model_runner
.
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
@@ -102,22 +103,17 @@ class TwoBatchOverlap():
...
@@ -102,22 +103,17 @@ class TwoBatchOverlap():
self
.
states_left_queue
.
put
(
model_output
)
self
.
states_left_queue
.
put
(
model_output
)
else
:
else
:
self
.
states_right_queue
.
put
(
model_output
)
self
.
states_right_queue
.
put
(
model_output
)
profile
.
ProfRangePop
()
def
tbo_thread_synchronize
(
self
,
tid
):
def
tbo_thread_synchronize
(
self
,
tid
):
if
tid
==
self
.
left_tid
:
if
tid
==
self
.
left_tid
:
if
not
self
.
left_first
:
if
not
self
.
left_first
:
self
.
sem_right
.
release
()
self
.
sem_right
.
release
()
self
.
left_first
=
False
self
.
left_first
=
False
profile
.
ProfRangePop
()
self
.
sem_left
.
acquire
()
self
.
sem_left
.
acquire
()
profile
.
ProfRangePush
(
'left'
)
return
self
.
event_left_c2t
,
self
.
event_left_t2c
return
self
.
event_left_c2t
,
self
.
event_left_t2c
else
:
else
:
self
.
sem_left
.
release
()
self
.
sem_left
.
release
()
profile
.
ProfRangePop
()
self
.
sem_right
.
acquire
()
self
.
sem_right
.
acquire
()
profile
.
ProfRangePush
(
'right'
)
return
self
.
event_right_c2t
,
self
.
event_right_t2c
return
self
.
event_right_c2t
,
self
.
event_right_t2c
def
set_model_input
(
self
,
def
set_model_input
(
self
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
cfabf125
...
@@ -1373,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1373,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
if
envs
.
VLLM_ENABLE_TBO
and
(
not
self
.
use_cuda_graph
or
skip_cuda_graphs
)
:
model_output
,
finished_sending
,
finished_recving
=
\
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
...
vllm/zero_overhead/v1/eagle.py
0 → 100644
View file @
cfabf125
import
torch
from
vllm.forward_context
import
set_forward_context
from
vllm.model_executor.models.llama_eagle3
import
Eagle3LlamaForCausalLM
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.mla.common
import
MLACommonMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.sample.metadata
import
SamplingMetadata
from
vllm.v1.spec_decode.eagle
import
PADDING_SLOT_ID
,
EagleProposer
class
V1ZeroEagleProposer
(
EagleProposer
):
def
__init__
(
self
,
vllm_config
,
device
,
runner
=
None
):
super
().
__init__
(
vllm_config
,
device
,
runner
)
self
.
spec_scheduler_max_num_tokens
=
0
def
propose
(
self
,
# [num_tokens]
target_token_ids
:
torch
.
Tensor
,
# [num_tokens]
target_positions
:
torch
.
Tensor
,
# [num_tokens, hidden_size]
target_hidden_states
:
torch
.
Tensor
,
# [num_tokens]
target_slot_mapping
:
torch
.
Tensor
,
# [batch_size]
next_token_ids
:
torch
.
Tensor
,
# [batch_size + 1] starting with 0
cu_num_tokens
:
torch
.
Tensor
,
# [batch_size, max_num_blocks_per_req]
block_table
:
torch
.
Tensor
,
# [batch_size]
sampling_metadata
:
SamplingMetadata
,
decoding
:
bool
=
False
,
)
->
torch
.
Tensor
:
num_tokens
=
target_token_ids
.
shape
[
0
]
batch_size
=
next_token_ids
.
shape
[
0
]
last_token_indices
=
cu_num_tokens
[
1
:]
-
1
if
self
.
method
==
"eagle3"
:
assert
isinstance
(
self
.
model
,
Eagle3LlamaForCausalLM
)
target_hidden_states
=
self
.
model
.
combine_hidden_states
(
target_hidden_states
)
assert
target_hidden_states
.
shape
[
-
1
]
==
self
.
hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self
.
input_ids
[:
num_tokens
-
1
]
=
target_token_ids
[
1
:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self
.
input_ids
[
last_token_indices
]
=
next_token_ids
# FA requires seq_len to have dtype int32.
seq_lens
=
(
target_positions
[
last_token_indices
]
+
1
).
int
()
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len
=
seq_lens
.
max
().
item
()
max_num_tokens
=
(
cu_num_tokens
[
1
:]
-
cu_num_tokens
[:
-
1
]).
max
().
item
()
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_num_tokens
,
query_start_loc
=
cu_num_tokens
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
slot_mapping
=
target_slot_mapping
,
# TODO(woosuk): Support cascade attention.
use_cascade
=
False
,
common_prefix_len
=
0
,
cu_prefix_query_lens
=
None
,
prefix_kv_lens
=
None
,
suffix_kv_lens
=
None
,
)
elif
self
.
method
==
"deepseek_mtp"
:
max_query_len
=
self
.
spec_scheduler_max_num_tokens
common_attn_metadata
=
CommonAttentionMetadata
(
query_start_loc
=
cu_num_tokens
,
seq_lens
=
seq_lens
,
num_reqs
=
batch_size
,
num_actual_tokens
=
num_tokens
,
max_query_len
=
max_query_len
,
slot_mapping
=
target_slot_mapping
,
spec_layer_decoding
=
decoding
)
assert
self
.
runner
is
not
None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata
=
self
.
runner
.
attn_metadata_builders
[
0
].
build
(
common_prefix_len
=
0
,
common_attn_metadata
=
common_attn_metadata
)
else
:
raise
ValueError
(
f
"Unsupported method:
{
self
.
method
}
"
)
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata
=
{}
for
layer_name
in
self
.
attn_layer_names
:
per_layer_attn_metadata
[
layer_name
]
=
attn_metadata
if
self
.
use_cuda_graph
and
\
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
num_input_tokens
=
self
.
vllm_config
.
pad_for_cudagraph
(
num_tokens
)
else
:
num_input_tokens
=
num_tokens
# copy inputs to buffer for cudagraph
self
.
positions
[:
num_tokens
]
=
target_positions
self
.
hidden_states
[:
num_tokens
]
=
target_hidden_states
if
(
decoding
and
self
.
use_full_cuda_graph
and
num_tokens
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
num_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
if
attn_metadata
.
decode
is
not
None
:
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
num_input_tokens
,
skip_cuda_graphs
=
not
decoding
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
num_input_tokens
],
self
.
positions
[:
num_input_tokens
],
self
.
hidden_states
[:
num_input_tokens
],
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
sample_hidden_states
=
last_hidden_states
[
last_token_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
# Early exit if there is only one draft token to be generated.
if
self
.
num_speculative_tokens
==
1
:
# [batch_size, 1]
return
draft_token_ids
.
view
(
-
1
,
1
)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list
=
[
draft_token_ids
]
positions
=
target_positions
[
last_token_indices
]
if
self
.
method
==
"deepseek_mtp"
:
hidden_states
=
last_hidden_states
[
last_token_indices
]
else
:
hidden_states
=
hidden_states
[
last_token_indices
]
if
self
.
use_cuda_graph
and
\
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]:
input_batch_size
=
self
.
vllm_config
.
pad_for_cudagraph
(
batch_size
)
else
:
input_batch_size
=
batch_size
attn_metadata
.
num_actual_tokens
=
batch_size
attn_metadata
.
max_query_len
=
1
attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
num_decodes
=
batch_size
attn_metadata
.
num_decode_tokens
=
batch_size
attn_metadata
.
num_prefills
=
0
block_table
=
self
.
runner
.
attn_metadata_builders
[
0
].
block_table
.
get_device_tensor
()[:
batch_size
,
...]
attn_metadata
.
decode
=
self
.
runner
.
attn_metadata_builders
[
0
].
_build_decode
(
block_table_tensor
=
block_table
,
seq_lens
=
seq_lens
,
)
for
i
in
range
(
self
.
num_speculative_tokens
-
1
):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids
=
draft_token_ids_list
[
-
1
].
int
()
positions
+=
1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len
=
positions
>=
self
.
max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions
=
torch
.
where
(
exceeds_max_model_len
,
0
,
positions
)
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
attn_metadata
.
decode
.
seq_lens
+=
1
else
:
attn_metadata
.
seq_lens
+=
1
# Increment the sequence lengths.
attn_metadata
.
max_seq_len
+=
1
# Consider max model length.
attn_metadata
.
max_seq_len
=
min
(
attn_metadata
.
max_seq_len
,
self
.
max_model_len
)
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
attn_metadata
.
seq_lens
.
masked_fill_
(
exceeds_max_model_len
,
1
)
# Compute the slot mapping.
block_numbers
=
clamped_positions
//
self
.
block_size
block_ids
=
block_table
.
gather
(
dim
=
1
,
index
=
block_numbers
.
view
(
-
1
,
1
))
block_ids
=
block_ids
.
view
(
-
1
)
attn_metadata
.
slot_mapping
=
(
block_ids
*
self
.
block_size
+
clamped_positions
%
self
.
block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
attn_metadata
.
slot_mapping
.
masked_fill_
(
exceeds_max_model_len
,
PADDING_SLOT_ID
)
# copy inputs to buffer for cudagraph
self
.
input_ids
[:
batch_size
]
=
input_ids
self
.
positions
[:
batch_size
]
=
clamped_positions
self
.
hidden_states
[:
batch_size
]
=
hidden_states
if
(
self
.
use_full_cuda_graph
and
batch_size
<=
self
.
cudagraph_batch_sizes
[
-
1
]):
assert
self
.
attn_metadata_cudagraph
if
self
.
method
in
[
"eagle"
,
"eagle3"
]:
self
.
attn_metadata_cudagraph
.
seq_lens
[:
batch_size
]
=
(
attn_metadata
.
seq_lens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
batch_size
]
=
(
attn_metadata
.
slot_mapping
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
block_table
[:
batch_size
]
=
(
attn_metadata
.
block_table
)
elif
self
.
method
==
"deepseek_mtp"
:
self
.
attn_metadata_cudagraph
.
num_actual_tokens
=
(
attn_metadata
.
num_actual_tokens
)
self
.
attn_metadata_cudagraph
.
slot_mapping
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
slot_mapping
)
self
.
attn_metadata_cudagraph
.
num_decodes
=
(
attn_metadata
.
num_decodes
)
self
.
attn_metadata_cudagraph
.
num_decode_tokens
=
(
attn_metadata
.
num_decode_tokens
)
self
.
attn_metadata_cudagraph
.
num_prefills
=
(
attn_metadata
.
num_prefills
)
self
.
attn_metadata_cudagraph
.
decode
.
seq_lens
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
seq_lens
)
if
i
==
0
:
self
.
attn_metadata_cudagraph
.
query_start_loc
[:
batch_size
+
1
]
=
(
attn_metadata
.
query_start_loc
)
self
.
attn_metadata_cudagraph
.
decode
.
block_table
[:
attn_metadata
.
num_decode_tokens
]
=
(
attn_metadata
.
decode
.
block_table
)
# Run the model.
with
set_forward_context
(
per_layer_attn_metadata
,
self
.
vllm_config
,
num_tokens
=
input_batch_size
):
ret_hidden_states
=
self
.
model
(
self
.
input_ids
[:
input_batch_size
],
self
.
positions
[:
input_batch_size
],
self
.
hidden_states
[:
input_batch_size
],
)
if
self
.
method
==
"deepseek_mtp"
:
last_hidden_states
=
ret_hidden_states
hidden_states
=
last_hidden_states
[:
batch_size
]
else
:
last_hidden_states
,
hidden_states
=
ret_hidden_states
hidden_states
=
hidden_states
[:
batch_size
]
logits
=
self
.
model
.
compute_logits
(
last_hidden_states
[:
batch_size
],
None
)
# TODO(wenlong): get more than one token for tree attention
draft_token_ids
=
logits
.
argmax
(
dim
=-
1
)
draft_token_ids_list
.
append
(
draft_token_ids
)
# [batch_size, num_speculative_tokens]
draft_token_ids
=
torch
.
stack
(
draft_token_ids_list
,
dim
=
1
)
return
draft_token_ids
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
cfabf125
...
@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...
@@ -18,6 +18,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.spec_decode.ngram_proposer
import
NgramProposer
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.zero_overhead.v1.eagle
import
V1ZeroEagleProposer
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.zero_overhead.v1.outputs
import
ZeroV1ModelRunnerOutput
from
vllm.profiler.prof
import
profile
from
vllm.profiler.prof
import
profile
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
...
@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -31,10 +32,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
self
.
last_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampler_host_tokens
=
None
self
.
token_ids_cpu_fix_recod
e
=
[]
self
.
token_ids_cpu_fix_reco
r
d
=
[]
self
.
last_draft_token_ids
=
None
self
.
last_draft_token_ids
=
None
self
.
last_draft_host_tokens
=
None
self
.
last_draft_host_tokens
=
None
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
last_draft_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_sampler_event
=
torch
.
cuda
.
Event
(
enable_timing
=
False
)
self
.
spec_scheduler_max_num_tokens
=
0
if
hasattr
(
self
,
'drafter'
)
and
isinstance
(
self
.
drafter
,
EagleProposer
):
self
.
drafter
=
V1ZeroEagleProposer
(
self
.
vllm_config
,
self
.
device
,
self
)
def
_prepare_inputs
(
def
_prepare_inputs
(
self
,
self
,
...
@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -62,6 +68,7 @@ class V1ZeroModelRunner(GPUModelRunner):
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
tokens
=
[
scheduler_output
.
num_scheduled_tokens
[
i
]
for
i
in
req_ids
]
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
num_scheduled_tokens
=
np
.
array
(
tokens
,
dtype
=
np
.
int32
)
max_num_scheduled_tokens
=
max
(
tokens
)
max_num_scheduled_tokens
=
max
(
tokens
)
self
.
spec_scheduler_max_num_tokens
=
max_num_scheduled_tokens
# Get request indices.
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
...
@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -281,7 +288,8 @@ class V1ZeroModelRunner(GPUModelRunner):
def
propose_draft_token_ids
(
def
propose_draft_token_ids
(
self
,
self
,
scheduler_output
:
"SchedulerOutput"
,
scheduler_output
:
"SchedulerOutput"
,
sampled_token_ids
:
list
[
list
[
int
]],
num_accepted_tokens_tensor
:
torch
.
Tensor
,
sampled_token_ids
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
sample_hidden_states
:
torch
.
Tensor
,
...
@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -317,26 +325,8 @@ class V1ZeroModelRunner(GPUModelRunner):
elif
self
.
speculative_config
.
use_eagle
():
elif
self
.
speculative_config
.
use_eagle
():
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
# TODO(woosuk): Refactor the loop.
# TODO(woosuk): Refactor the loop.
if
self
.
last_sampled_token_ids
is
not
None
:
row_indices
=
torch
.
arange
(
sampled_token_ids
.
size
(
0
),
device
=
sampled_token_ids
.
device
)
next_token_ids
=
self
.
last_sampled_token_ids
.
flatten
()
next_token_ids
=
sampled_token_ids
[
row_indices
,
num_accepted_tokens_tensor
].
flatten
()
else
:
next_token_ids
:
list
[
int
]
=
[]
for
i
,
token_ids
in
enumerate
(
sampled_token_ids
):
if
token_ids
:
# Common case.
next_token_id
=
token_ids
[
-
1
]
else
:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id
=
self
.
input_batch
.
req_ids
[
i
]
req_state
=
self
.
requests
[
req_id
]
seq_len
=
(
req_state
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
req_id
])
next_token_id
=
req_state
.
get_token_id
(
seq_len
)
next_token_ids
.
append
(
next_token_id
)
next_token_ids
=
torch
.
tensor
(
next_token_ids
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
# At this moment, we assume all eagle layers belong to the same KV
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
# cache group, thus using the same attention metadata.
eagle_attn_metadata
=
attn_metadata
[
eagle_attn_metadata
=
attn_metadata
[
...
@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -348,6 +338,7 @@ class V1ZeroModelRunner(GPUModelRunner):
else
:
else
:
block_table
=
None
block_table
=
None
spec_scheduler_max_num_tokens
=
self
.
spec_scheduler_max_num_tokens
if
spec_decode_metadata
is
None
:
if
spec_decode_metadata
is
None
:
# input_ids can be None for multimodal models.
# input_ids can be None for multimodal models.
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
target_token_ids
=
self
.
input_ids
[:
num_scheduled_tokens
]
...
@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -363,16 +354,11 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
cu_num_tokens
=
eagle_attn_metadata
.
query_start_loc
else
:
else
:
# TODO(woosuk): Refactor this.
# TODO(woosuk): Refactor this.
num_accepted_tokens
=
[
len
(
s
)
-
1
for
s
in
sampled_token_ids
]
num_accepted_tokens_tensor
=
async_tensor_h2d
(
num_accepted_tokens
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
cu_num_tokens
,
token_indices
=
self
.
drafter
.
prepare_inputs
(
eagle_attn_metadata
.
query_start_loc
,
eagle_attn_metadata
.
query_start_loc
,
num_accepted_tokens_tensor
,
num_accepted_tokens_tensor
,
)
)
spec_scheduler_max_num_tokens
=
1
target_token_ids
=
self
.
input_ids
[
token_indices
]
target_token_ids
=
self
.
input_ids
[
token_indices
]
# TODO(woosuk): Support M-RoPE.
# TODO(woosuk): Support M-RoPE.
target_positions
=
self
.
positions
[
token_indices
]
target_positions
=
self
.
positions
[
token_indices
]
...
@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -383,6 +369,7 @@ class V1ZeroModelRunner(GPUModelRunner):
target_hidden_states
=
hidden_states
[
token_indices
]
target_hidden_states
=
hidden_states
[
token_indices
]
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
target_slot_mapping
=
eagle_attn_metadata
.
slot_mapping
[
token_indices
]
token_indices
]
self
.
drafter
.
spec_scheduler_max_num_tokens
=
spec_scheduler_max_num_tokens
draft_token_ids
=
self
.
drafter
.
propose
(
draft_token_ids
=
self
.
drafter
.
propose
(
target_token_ids
=
target_token_ids
,
target_token_ids
=
target_token_ids
,
target_positions
=
target_positions
,
target_positions
=
target_positions
,
...
@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -392,7 +379,7 @@ class V1ZeroModelRunner(GPUModelRunner):
cu_num_tokens
=
cu_num_tokens
,
cu_num_tokens
=
cu_num_tokens
,
block_table
=
block_table
,
block_table
=
block_table
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
decoding
=
spec_decode_metadata
is
not
None
decoding
=
spec_decode_metadata
is
not
None
,
)
)
spec_token_ids
=
np
.
ones
(
draft_token_ids
.
shape
,
dtype
=
int
).
tolist
()
spec_token_ids
=
np
.
ones
(
draft_token_ids
.
shape
,
dtype
=
int
).
tolist
()
self
.
last_draft_token_ids
=
draft_token_ids
self
.
last_draft_token_ids
=
draft_token_ids
...
@@ -486,7 +473,7 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -486,7 +473,7 @@ class V1ZeroModelRunner(GPUModelRunner):
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
if
envs
.
VLLM_ENABLE_TBO
and
(
not
self
.
use_cuda_graph
or
skip_cuda_graphs
)
:
model_output
,
finished_sending
,
finished_recving
=
\
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -622,22 +609,49 @@ class V1ZeroModelRunner(GPUModelRunner):
scheduler_output
,
scheduler_output
,
)
)
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
fix_req_ids
=
None
fix_req_ids
=
None
fix_sampled_token_ids
=
None
fix_sampled_token_ids
=
None
fix_draft_token_ids
=
None
fix_draft_token_ids
=
None
fix_draft_req_ids
=
self
.
last_sampled_req_ids
fix_draft_req_ids
=
self
.
last_sampled_req_ids
is_output_valid
=
False
is_output_valid
=
False
# Get the valid generated tokens.
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
sampled_token_ids_cpu
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
spec_sampler_event
.
record
()
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
mask
=
(
sampled_token_ids
==
-
1
)
mask_int
=
mask
.
int
()
first_neg_one_indices
=
torch
.
argmax
(
mask_int
,
dim
=
1
)
num_accepted_tokens_tensor
=
torch
.
where
(
torch
.
any
(
mask
,
dim
=
1
),
first_neg_one_indices
,
sampled_token_ids
.
size
(
1
))
-
1
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
num_accepted_tokens_tensor
,
sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
if
self
.
speculative_config
:
if
self
.
speculative_config
:
self
.
spec_sampler_event
.
synchronize
()
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
valid_sampled_token_ids
=
sampled_token_ids
_cpu
.
tolist
()
else
:
else
:
# Includes spec decode tokens.
# Includes spec decode tokens.
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
valid_sampled_token_ids
=
self
.
rejection_sampler
.
parse_output
(
sampled_token_ids
,
sampled_token_ids
_cpu
,
self
.
input_batch
.
vocab_size
,
self
.
input_batch
.
vocab_size
,
)
)
self
.
last_sampler_host_tokens
=
None
self
.
last_sampler_host_tokens
=
None
...
@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -649,13 +663,21 @@ class V1ZeroModelRunner(GPUModelRunner):
if
self
.
last_sampler_host_tokens
!=
None
:
if
self
.
last_sampler_host_tokens
!=
None
:
self
.
last_sampler_event
.
synchronize
()
self
.
last_sampler_event
.
synchronize
()
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
fix_sampled_token_ids
=
self
.
last_sampler_host_tokens
.
tolist
()
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_recode
:
for
req_idx
,
start_idx
,
end_idx
in
self
.
token_ids_cpu_fix_record
:
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
if
start_idx
==
-
1
:
continue
req_id
=
fix_req_ids
[
req_idx
]
if
req_id
in
self
.
input_batch
.
req_ids
:
new_req_idx
=
self
.
input_batch
.
req_ids
.
index
(
req_id
)
self
.
input_batch
.
token_ids_cpu
[
new_req_idx
,
start_idx
:
end_idx
]
=
fix_sampled_token_ids
[
req_idx
]
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
for
req_idx
,
req_id
in
enumerate
(
fix_req_ids
):
if
req_id
in
self
.
requests
:
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
token_idx
=
self
.
last_sampled_token_lens
[
req_idx
]
req_state
.
output_token_ids
[
token_idx
]
=
fix_sampled_token_ids
[
req_idx
][
0
]
if
token_idx
==
-
1
:
continue
fix_len
=
len
(
fix_sampled_token_ids
[
req_idx
])
req_state
.
output_token_ids
[
token_idx
:
token_idx
+
fix_len
]
=
fix_sampled_token_ids
[
req_idx
]
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_host_tokens
=
sampled_token_ids
.
to
(
'cpu'
,
non_blocking
=
True
)
self
.
last_sampler_event
.
record
()
self
.
last_sampler_event
.
record
()
self
.
last_sampled_token_ids
=
sampled_token_ids
self
.
last_sampled_token_ids
=
sampled_token_ids
...
@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -670,11 +692,16 @@ class V1ZeroModelRunner(GPUModelRunner):
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# NOTE(woosuk): As an exception, when using PP, the scheduler sends
# the sampled tokens back, because there's no direct communication
# the sampled tokens back, because there's no direct communication
# between the first-stage worker and the last-stage worker.
# between the first-stage worker and the last-stage worker.
self
.
token_ids_cpu_fix_recod
e
.
clear
()
self
.
token_ids_cpu_fix_reco
r
d
.
clear
()
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_req_ids
=
[]
self
.
last_sampled_token_lens
=
[]
self
.
last_sampled_token_lens
=
[]
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
for
req_idx
,
sampled_ids
in
enumerate
(
valid_sampled_token_ids
):
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
-
1
if
not
sampled_ids
:
if
not
sampled_ids
:
self
.
last_sampled_token_lens
.
append
(
-
1
)
self
.
token_ids_cpu_fix_record
.
append
([
req_idx
,
-
1
,
-
1
])
continue
continue
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
start_idx
=
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
...
@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner):
...
@@ -686,34 +713,15 @@ class V1ZeroModelRunner(GPUModelRunner):
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
self
.
input_batch
.
token_ids_cpu
[
req_idx
,
start_idx
:
end_idx
]
=
sampled_ids
start_idx
:
end_idx
]
=
sampled_ids
self
.
token_ids_cpu_fix_recod
e
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
token_ids_cpu_fix_reco
r
d
.
append
([
req_idx
,
start_idx
,
end_idx
])
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens_no_spec
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
self
.
input_batch
.
num_tokens
[
req_idx
]
=
end_idx
req_id
=
self
.
input_batch
.
req_ids
[
req_idx
]
if
req_id
in
self
.
requests
:
if
req_id
in
self
.
requests
:
req_state
=
self
.
requests
[
req_id
]
req_state
=
self
.
requests
[
req_id
]
self
.
last_sampled_req_ids
.
append
(
req_id
)
cache_output_len
=
len
(
req_state
.
output_token_ids
)
self
.
last_sampled_token_lens
.
append
(
len
(
req_state
.
output_token_ids
))
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
req_state
.
output_token_ids
.
extend
(
sampled_ids
)
self
.
last_sampled_token_lens
.
append
(
cache_output_len
)
if
not
self
.
speculative_config
:
# Speculative decoding is not enabled.
spec_token_ids
=
None
fix_draft_req_ids
=
None
else
:
if
self
.
last_draft_host_tokens
is
not
None
:
self
.
last_draft_event
.
synchronize
()
fix_draft_token_ids
=
self
.
last_draft_host_tokens
.
tolist
()
spec_token_ids
=
self
.
propose_draft_token_ids
(
scheduler_output
,
valid_sampled_token_ids
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
)
# Clear KVConnector state after all KVs are generated.
# Clear KVConnector state after all KVs are generated.
if
has_kv_transfer_group
():
if
has_kv_transfer_group
():
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment