Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
76572db3
Commit
76572db3
authored
Aug 19, 2025
by
zhuwenwen
Browse files
Merge branch 'v0.9.2-dev' of
http://10.16.6.30/dcutoolkit/deeplearing/vllm
into v0.9.2-dev
parents
864c718a
f3e13c54
Changes
15
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1337 additions
and
319 deletions
+1337
-319
vllm/attention/backends/flashmla.py
vllm/attention/backends/flashmla.py
+7
-2
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-1
vllm/attention/ops/flashmla.py
vllm/attention/ops/flashmla.py
+18
-0
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+2
-1
vllm/envs.py
vllm/envs.py
+2
-2
vllm/forward_context.py
vllm/forward_context.py
+13
-0
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+14
-11
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+344
-14
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+17
-5
vllm/v1/attention/backends/mla/flashmla.py
vllm/v1/attention/backends/mla/flashmla.py
+7
-2
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+26
-29
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+8
-2
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+153
-110
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+721
-139
vllm/zero_overhead/v1/outputs.py
vllm/zero_overhead/v1/outputs.py
+4
-1
No files found.
vllm/attention/backends/flashmla.py
View file @
76572db3
...
@@ -211,8 +211,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -211,8 +211,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
if
self
.
kv_cache_dtype
!=
"fp8"
:
"FlashMLA with FP8 KV cache not yet supported"
)
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
...
@@ -220,6 +221,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -220,6 +221,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
attn_metadata
:
FlashMLAMetadata
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
...
@@ -239,6 +242,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -239,6 +242,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
decode_meta
.
decode_num_splits
,
num_splits
=
decode_meta
.
decode_num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
)
)
return
self
.
_v_up_proj
(
o
)
return
self
.
_v_up_proj
(
o
)
vllm/attention/backends/mla/common.py
View file @
76572db3
...
@@ -1397,6 +1397,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
...
@@ -1397,6 +1397,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
output
[
num_prefill_tokens
:]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output
return
output
\ No newline at end of file
vllm/attention/ops/flashmla.py
View file @
76572db3
...
@@ -75,6 +75,8 @@ def flash_mla_with_kvcache(
...
@@ -75,6 +75,8 @@ def flash_mla_with_kvcache(
num_splits
:
torch
.
Tensor
,
num_splits
:
torch
.
Tensor
,
softmax_scale
:
Optional
[
float
]
=
None
,
softmax_scale
:
Optional
[
float
]
=
None
,
causal
:
bool
=
False
,
causal
:
bool
=
False
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
Arguments:
Arguments:
...
@@ -97,6 +99,22 @@ def flash_mla_with_kvcache(
...
@@ -97,6 +99,22 @@ def flash_mla_with_kvcache(
if
softmax_scale
is
None
:
if
softmax_scale
is
None
:
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
softmax_scale
=
q
.
shape
[
-
1
]
**
(
-
0.5
)
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
if
kv_cache_dtype
==
"fp8"
:
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
q
,
k_cache
,
None
,
head_dim_v
,
cache_seqlens
,
block_table
,
softmax_scale
,
causal
,
tile_scheduler_metadata
,
num_splits
,
k_scale
,
"fp8_e4m3"
,
)
return
out
,
softmax_lse
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
out
,
softmax_lse
=
flash_mla_cuda
.
fwd_kvcache_mla
(
q
,
q
,
k_cache
,
k_cache
,
...
...
vllm/compilation/decorators.py
View file @
76572db3
...
@@ -11,6 +11,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
...
@@ -11,6 +11,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.forward_context
import
get_profilling
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
...
@@ -169,7 +170,7 @@ def _support_torch_compile(
...
@@ -169,7 +170,7 @@ def _support_torch_compile(
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
():
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
()
or
get_profilling
()
:
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
# the first compilation needs to have dynamic shapes marked
# the first compilation needs to have dynamic shapes marked
...
...
vllm/envs.py
View file @
76572db3
...
@@ -1087,7 +1087,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1087,7 +1087,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# vLLM will use global cache for moe
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13"
:
"VLLM_USE_GLOBAL_CACHE13"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"
Tru
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_GLOBAL_CACHE13"
,
"
Fals
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
}
}
...
@@ -1162,4 +1162,4 @@ def compute_hash() -> str:
...
@@ -1162,4 +1162,4 @@ def compute_hash() -> str:
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
hash_str
=
hashlib
.
md5
(
str
(
factors
).
encode
(),
usedforsecurity
=
False
).
hexdigest
()
usedforsecurity
=
False
).
hexdigest
()
return
hash_str
return
hash_str
\ No newline at end of file
vllm/forward_context.py
View file @
76572db3
...
@@ -196,3 +196,16 @@ def set_forward_context(
...
@@ -196,3 +196,16 @@ def set_forward_context(
_forward_context
=
prev_context
_forward_context
=
prev_context
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
:
set_tbo_forward_context
(
_forward_context
)
set_tbo_forward_context
(
_forward_context
)
# Module-level flag reported by get_profilling().
# NOTE(review): presumably True only while a profiling/warm-up run is active;
# confirm against the callers of set_profilling().
_profiling: bool = False


class _ProfilingScope:
    """Handle returned by ``set_profilling``.

    Entering the scope is a no-op (the flag has already been updated);
    leaving it puts back the value the flag had before the call.
    """

    def __init__(self, previous: bool) -> None:
        self._previous = previous

    def __enter__(self) -> "_ProfilingScope":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        global _profiling
        _profiling = self._previous
        # Never swallow an exception raised inside the ``with`` body.
        return False


def set_profilling(profiling):
    """Set the module-level profiling flag.

    The flag is updated immediately, so a bare ``set_profilling(True)``
    statement keeps working. The previous implementation was decorated
    with ``@contextmanager`` but never yielded, so using it in a ``with``
    statement raised ``TypeError``; the returned object now supports
    ``with`` and restores the prior value on exit.

    Args:
        profiling: New value for the flag.

    Returns:
        A context manager that restores the previous flag value on exit.
    """
    global _profiling
    scope = _ProfilingScope(_profiling)
    _profiling = profiling
    return scope


def get_profilling() -> bool:
    """Return the current value of the module-level profiling flag."""
    return _profiling
\ No newline at end of file
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
76572db3
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -18,7 +18,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_make_workspace_new
,
maybe_warn_marlin_atomic_add
)
marlin_make_workspace_new
,
maybe_warn_marlin_atomic_add
)
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.scalar_type
import
ScalarType
,
scalar_types
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
from
vllm.model_executor.layers.fused_moe.fused_moe
import
get_moe_cache
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
if
has_zp
:
if
has_zp
:
return
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
return
scalar_types
.
uint4
if
num_bits
==
4
else
scalar_types
.
uint8
...
@@ -104,8 +104,8 @@ def fused_marlin_moe(
...
@@ -104,8 +104,8 @@ def fused_marlin_moe(
topk
=
topk_ids
.
shape
[
1
]
# 8
topk
=
topk_ids
.
shape
[
1
]
# 8
# Temporarily fixed at 16384
# Temporarily fixed at 16384
CHUNK_SIZE
=
16384
#
CHUNK_SIZE = 16384
CHUNK_SIZE
=
envs
.
VLLM_FUSED_MOE_CHUNK_SIZE
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
M
=
min
(
num_tokens
,
CHUNK_SIZE
)
if
workspace
is
None
:
if
workspace
is
None
:
...
@@ -120,18 +120,21 @@ def fused_marlin_moe(
...
@@ -120,18 +120,21 @@ def fused_marlin_moe(
if
global_num_experts
==
-
1
:
if
global_num_experts
==
-
1
:
global_num_experts
=
E
global_num_experts
=
E
intermediate_cache2
=
torch
.
empty
(
intermediate_cache2
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
],
N
),
(
M
*
topk
,
N
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache13
=
torch
.
empty
(
(
M
*
topk_ids
.
shape
[
1
]
*
max
(
2
*
N
,
K
),
),
device
=
hidden_states
.
device
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
dtype
=
hidden_states
.
dtype
,
)
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk_ids
.
shape
[
1
]
*
2
*
N
]
if
envs
.
VLLM_USE_GLOBAL_CACHE13
:
intermediate_cache13
=
get_moe_cache
(
topk
,
N
,
K
,
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
else
:
intermediate_cache13
=
torch
.
empty
(
(
M
*
topk
*
max
(
2
*
N
,
K
),
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
intermediate_cache1
=
intermediate_cache13
[:
M
*
topk
*
2
*
N
]
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache1
=
intermediate_cache1
.
view
(
-
1
,
2
*
N
)
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk
_ids
.
shape
[
1
]
*
K
]
intermediate_cache3
=
intermediate_cache13
[:
M
*
topk
*
K
]
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
intermediate_cache3
=
intermediate_cache3
.
view
(
-
1
,
K
)
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
use_atomic_add
=
hidden_states
.
dtype
==
torch
.
half
or
\
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
76572db3
...
@@ -58,6 +58,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
...
@@ -58,6 +58,11 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
enorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
enorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
hnorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
hnorm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
self
.
eh_proj
=
nn
.
Linear
(
config
.
hidden_size
*
2
,
...
@@ -75,6 +80,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
...
@@ -75,6 +80,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_index
:
int
=
0
,
spec_step_index
:
int
=
0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
assert
inputs_embeds
is
not
None
assert
inputs_embeds
is
not
None
# masking inputs at position 0, as not needed by MTP
# masking inputs at position 0, as not needed by MTP
inputs_embeds
[
positions
==
0
]
=
0
inputs_embeds
[
positions
==
0
]
=
0
...
@@ -111,10 +118,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -111,10 +118,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
for
idx
in
range
(
self
.
mtp_start_layer_idx
,
self
.
mtp_start_layer_idx
+
self
.
num_mtp_layers
)
self
.
mtp_start_layer_idx
+
self
.
num_mtp_layers
)
})
})
self
.
embed_tokens
=
VocabParallelEmbedding
(
config
.
vocab_size
,
config
.
hidden_size
,
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
def
forward
(
def
forward
(
...
@@ -125,8 +129,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -125,8 +129,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
spec_step_idx
:
int
=
0
,
spec_step_idx
:
int
=
0
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
(
spec_step_idx
%
self
.
num_mtp_layers
)
current_step_idx
=
(
spec_step_idx
%
self
.
num_mtp_layers
)
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
input_ids
,
...
@@ -308,25 +310,353 @@ class DeepSeekMTP(nn.Module, SupportsPP):
...
@@ -308,25 +310,353 @@ class DeepSeekMTP(nn.Module, SupportsPP):
"""
"""
Rewrite the weight name to match the format of the original model.
Rewrite the weight name to match the format of the original model.
Add .mtp_block for modules in transformer layer block for spec layer
Add .mtp_block for modules in transformer layer block for spec layer
and rename shared layer weights to be top level.
"""
"""
spec_layer_weight_names
=
[
spec_layer_weight_names
=
[
"embed_tokens"
,
"enorm"
,
"hnorm"
,
"eh_proj"
,
"shared_head"
"embed_tokens"
,
"enorm"
,
"hnorm"
,
"eh_proj"
,
"shared_head"
]
]
shared_weight_names
=
[
"embed_tokens"
]
spec_layer_weight
=
False
spec_layer_weight
=
False
shared_weight
=
False
for
weight_name
in
spec_layer_weight_names
:
for
weight_name
in
spec_layer_weight_names
:
if
weight_name
in
name
:
if
weight_name
in
name
:
spec_layer_weight
=
True
spec_layer_weight
=
True
if
weight_name
in
shared_weight_names
:
shared_weight
=
True
break
break
if
not
spec_layer_weight
:
if
not
spec_layer_weight
:
# treat rest weights as weights for transformer layer block
# treat rest weights as weights for transformer layer block
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
f
"model.layers.
{
spec_layer
}
.mtp_block."
)
f
"model.layers.
{
spec_layer
}
.mtp_block."
)
elif
shared_weight
:
# treat shared weights as top level weights
name
=
name
.
replace
(
f
"model.layers.
{
spec_layer
}
."
,
"model."
)
return
name
return
name
# # SPDX-License-Identifier: Apache-2.0
# # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# import os
# import re
# from collections.abc import Iterable
# from typing import Iterable, Optional
# import torch
# import torch.nn as nn
# from transformers import PretrainedConfig
# from vllm.config import CacheConfig, ModelConfig, VllmConfig
# from vllm.model_executor.layers.fused_moe import FusedMoE
# from vllm.model_executor.layers.layernorm import RMSNorm
# from vllm.model_executor.layers.logits_processor import LogitsProcessor
# from vllm.model_executor.layers.quantization import QuantizationConfig
# from vllm.model_executor.layers.vocab_parallel_embedding import (
# ParallelLMHead, VocabParallelEmbedding)
# from vllm.model_executor.model_loader.weight_utils import default_weight_loader
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors
# from vllm.compilation.decorators import support_torch_compile
# from .deepseek_v2 import (DeepseekV2DecoderLayer,
# get_spec_layer_idx_from_weight_name)
# from .interfaces import SupportsPP
# from .utils import maybe_prefix
# from vllm import _custom_ops as ops
# from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
# class SharedHead(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.head = ParallelLMHead(config.vocab_size,
# config.hidden_size,
# quant_config=quant_config)
# def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# return self.norm(hidden_states)
# class DeepSeekMultiTokenPredictorLayer(nn.Module):
# def __init__(
# self,
# config: PretrainedConfig,
# prefix: str,
# model_config: ModelConfig,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# ) -> None:
# super().__init__()
# self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# self.eh_proj = nn.Linear(config.hidden_size * 2,
# config.hidden_size,
# bias=False)
# self.shared_head = SharedHead(config=config, quant_config=quant_config)
# self.mtp_block = DeepseekV2DecoderLayer(config, prefix, model_config,
# cache_config, quant_config)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_index: int = 0,
# ) -> torch.Tensor:
# assert inputs_embeds is not None
# # masking inputs at position 0, as not needed by MTP
# inputs_embeds[positions == 0] = 0
# inputs_embeds = self.enorm(inputs_embeds)
# previous_hidden_states = self.hnorm(previous_hidden_states)
# hidden_states = self.eh_proj(
# torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
# hidden_states, residual = self.mtp_block(positions=positions,
# hidden_states=hidden_states,
# residual=None)
# hidden_states = residual + hidden_states
# return hidden_states
# class DeepSeekMultiTokenPredictor(nn.Module):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# config = vllm_config.model_config.hf_config
# self.mtp_start_layer_idx = config.num_hidden_layers
# self.num_mtp_layers = config.num_nextn_predict_layers
# # to map the exact layer index from weights
# self.layers = torch.nn.ModuleDict({
# str(idx):
# DeepSeekMultiTokenPredictorLayer(
# config,
# f"{prefix}.layers.{idx}",
# model_config=vllm_config.model_config,
# cache_config=vllm_config.cache_config,
# quant_config=vllm_config.quant_config,
# )
# for idx in range(self.mtp_start_layer_idx,
# self.mtp_start_layer_idx + self.num_mtp_layers)
# })
# self.embed_tokens = VocabParallelEmbedding(
# config.vocab_size,
# config.hidden_size,
# )
# self.logits_processor = LogitsProcessor(config.vocab_size)
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# if inputs_embeds is None:
# inputs_embeds = self.embed_tokens(input_ids)
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
# input_ids,
# positions,
# previous_hidden_states,
# inputs_embeds,
# current_step_idx,
# )
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# current_step_idx = (spec_step_idx % self.num_mtp_layers)
# mtp_layer = self.layers[str(self.mtp_start_layer_idx +
# current_step_idx)]
# logits = self.logits_processor(mtp_layer.shared_head.head,
# mtp_layer.shared_head(hidden_states),
# sampling_metadata)
# return logits
# @support_torch_compile
# class DeepSeekMTP(nn.Module, SupportsPP):
# def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# super().__init__()
# self.config = vllm_config.model_config.hf_config
# quant_config = vllm_config.quant_config
# self.quant_method = None
# if quant_config is not None:
# self.quant_method = quant_config.get_name()
# os.environ['LLAMA_NN'] = '0'
# os.environ['LM_NN'] = '0'
# # The AWQ layer of MTP uses BlockInt8W8A8.
# if self.quant_method == "moe_wna16" or self.quant_method == "awq_marlin":
# vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
# self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
# prefix=maybe_prefix(
# prefix, "model"))
# self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# def forward(
# self,
# input_ids: torch.Tensor,
# positions: torch.Tensor,
# previous_hidden_states: torch.Tensor,
# intermediate_tensors: Optional[IntermediateTensors] = None,
# inputs_embeds: Optional[torch.Tensor] = None,
# spec_step_idx: int = 0,
# ) -> torch.Tensor:
# hidden_states = self.model(input_ids, positions,
# previous_hidden_states, inputs_embeds,
# spec_step_idx)
# return hidden_states
# def compute_logits(
# self,
# hidden_states: torch.Tensor,
# sampling_metadata: SamplingMetadata,
# spec_step_idx: int = 0,
# ) -> Optional[torch.Tensor]:
# return self.model.compute_logits(hidden_states, sampling_metadata,
# spec_step_idx)
# def load_weights(self, weights: Iterable[tuple[str,
# torch.Tensor]]) -> set[str]:
# stacked_params_mapping = [
# ("gate_up_proj", "gate_proj", 0),
# ("gate_up_proj", "up_proj", 1),
# ]
# expert_params_mapping = FusedMoE.make_expert_params_mapping(
# ckpt_gate_proj_name="gate_proj",
# ckpt_down_proj_name="down_proj",
# ckpt_up_proj_name="up_proj",
# num_experts=self.config.n_routed_experts)
# params_dict = dict(self.named_parameters())
# loaded_params: set[str] = set()
# for name, loaded_weight in weights:
# if "rotary_emb.inv_freq" in name:
# continue
# spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
# if spec_layer is None:
# continue
# name = self._rewrite_spec_layer_name(spec_layer, name)
# for (param_name, weight_name, shard_id) in stacked_params_mapping:
# # Skip non-stacked layers and experts (experts handled below).
# if weight_name not in name:
# continue
# # We have mlp.experts[0].gate_proj in the checkpoint.
# # Since we handle the experts below in expert_params_mapping,
# # we need to skip here BEFORE we update the name, otherwise
# # name will be updated to mlp.experts[0].gate_up_proj, which
# # will then be updated below in expert_params_mapping
# # for mlp.experts[0].gate_gate_up_proj, which breaks load.
# if (("mlp.experts." in name) and name not in params_dict):
# continue
# name = name.replace(weight_name, param_name)
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param, loaded_weight, shard_id)
# break
# else:
# for mapping in expert_params_mapping:
# param_name, weight_name, expert_id, shard_id = mapping
# if weight_name not in name:
# continue
# name = name.replace(weight_name, param_name)
# param = params_dict[name]
# weight_loader = param.weight_loader
# weight_loader(param,
# loaded_weight,
# name,
# shard_id=shard_id,
# expert_id=expert_id)
# break
# else:
# # Skip loading extra bias for GPTQ models.
# if name.endswith(".bias") and name not in params_dict:
# continue
# # According to DeepSeek-V3 Technical Report, MTP modules
# # shares embedding layer. We only load the first weights.
# if (spec_layer != self.model.mtp_start_layer_idx
# and ".layers" not in name):
# continue
# param = params_dict[name]
# weight_loader = getattr(param, "weight_loader",
# default_weight_loader)
# weight_loader(param, loaded_weight)
# loaded_params.add(name)
# if self.use_llama_nn and self.quant_method is None:
# lay_key_words = [
# "self_attn.eh_proj.weight",
# "self_attn.q_proj.weight",
# "self_attn.q_a_proj.weight",
# "self_attn.q_b_proj.weight",
# "self_attn.kv_a_proj_with_mqa.weight",
# "self_attn.kv_b_proj.weight",
# "self_attn.o_proj.weight",
# "mlp.gate_up_proj.weight",
# "mlp.down_proj.weight",
# "mlp.gate.weight",
# "shared_experts.gate_up_proj.weight",
# "shared_experts.down_proj.weight",
# "shared_head.head.weight",
# ]
# combined_words = "|".join(lay_key_words)
# for layername in loaded_params:
# weight = params_dict[layername]
# matches = re.findall(combined_words, layername)
# if matches:
# _weight = torch.zeros_like(weight.data)
# ori_shape =_weight.shape
# ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
# weight.data.copy_(_weight)
# weight.data=weight.data.reshape(ori_shape[1],-1)
# return loaded_params
# def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
# """
# Rewrite the weight name to match the format of the original model.
# Add .mtp_block for modules in transformer layer block for spec layer
# and rename shared layer weights to be top level.
# """
# spec_layer_weight_names = [
# "embed_tokens", "enorm", "hnorm", "eh_proj", "shared_head"
# ]
# shared_weight_names = ["embed_tokens"]
# spec_layer_weight = False
# shared_weight = False
# for weight_name in spec_layer_weight_names:
# if weight_name in name:
# spec_layer_weight = True
# if weight_name in shared_weight_names:
# shared_weight = True
# break
# if not spec_layer_weight:
# # treat rest weights as weights for transformer layer block
# name = name.replace(f"model.layers.{spec_layer}.",
# f"model.layers.{spec_layer}.mtp_block.")
# elif shared_weight:
# # treat shared weights as top level weights
# name = name.replace(f"model.layers.{spec_layer}.", "model.")
# return name
vllm/v1/attention/backends/mla/common.py
View file @
76572db3
...
@@ -647,10 +647,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
...
@@ -647,10 +647,22 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
repeats
=
torch
.
from_numpy
(
query_lens
).
pin_memory
().
to
(
repeats
=
torch
.
from_numpy
(
query_lens
).
pin_memory
().
to
(
block_table_tensor
.
device
,
non_blocking
=
True
).
contiguous
()
block_table_tensor
.
device
,
non_blocking
=
True
).
contiguous
()
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
self
.
_num_decodes
,
...],
if
envs
.
VLLM_ZERO_OVERHEAD
:
repeats
,
dim
=
0
).
contiguous
()
decode_block_table_tensor
=
torch
.
empty
((
self
.
_num_decode_tokens
,
block_table_tensor
.
shape
[
1
]),
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
device
=
block_table_tensor
.
device
)
arange_np
=
np
.
arange
(
self
.
_num_decodes
)
indices_np
=
np
.
repeat
(
arange_np
,
query_lens
)
indices
=
torch
.
from_numpy
(
indices_np
).
pin_memory
().
to
(
block_table_tensor
.
device
,
non_blocking
=
True
)
decode_block_table_tensor
=
block_table_tensor
[
indices
].
contiguous
()
decode_seq_lens
=
seq_lens
[
indices
].
contiguous
()
else
:
decode_block_table_tensor
=
torch
.
repeat_interleave
(
block_table_tensor
[:
self
.
_num_decodes
,
...],
repeats
,
dim
=
0
).
contiguous
()
decode_seq_lens
=
torch
.
repeat_interleave
(
seq_lens
[:
self
.
_num_decodes
],
repeats
,
dim
=
0
).
contiguous
()
seq_lens_minus
=
torch
.
from_numpy
(
rarange
).
to
(
torch
.
int32
).
pin_memory
().
to
(
seq_lens_minus
=
torch
.
from_numpy
(
rarange
).
to
(
torch
.
int32
).
pin_memory
().
to
(
seq_lens
.
device
,
non_blocking
=
True
).
contiguous
()
seq_lens
.
device
,
non_blocking
=
True
).
contiguous
()
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
decode_seq_lens
=
decode_seq_lens
-
seq_lens_minus
...
@@ -1086,6 +1098,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
...
@@ -1086,6 +1098,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
decode_ql_nope
=
decode_ql_nope
.
transpose
(
0
,
1
)
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
output
[:
num_decode_tokens
]
=
self
.
_forward_decode
(
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
)
decode_ql_nope
,
decode_q_pe
,
kv_cache
,
attn_metadata
,
layer
.
_k_scale
,
self
.
kv_cache_dtype
)
return
output_padded
return
output_padded
\ No newline at end of file
vllm/v1/attention/backends/mla/flashmla.py
View file @
76572db3
...
@@ -148,8 +148,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -148,8 +148,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl"
)
"FlashMLAImpl"
)
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
if
self
.
kv_cache_dtype
!=
"fp8"
:
"FlashMLA V1 with FP8 KV cache not yet supported"
)
raise
NotImplementedError
(
"FlashMLA with other KV cache not yet supported"
)
def
_forward_decode
(
def
_forward_decode
(
self
,
self
,
...
@@ -157,6 +158,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -157,6 +158,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q_pe
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
FlashMLAMetadata
,
attn_metadata
:
FlashMLAMetadata
,
k_scale
=
None
,
kv_cache_dtype
=
"auto"
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
assert
attn_metadata
.
decode
is
not
None
assert
attn_metadata
.
decode
is
not
None
...
@@ -175,6 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
...
@@ -175,6 +178,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits
=
attn_metadata
.
decode
.
num_splits
,
num_splits
=
attn_metadata
.
decode
.
num_splits
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
causal
=
True
,
k_scale
=
k_scale
,
kv_cache_dtype
=
kv_cache_dtype
,
)
)
return
self
.
_v_up_proj
(
o
)
return
self
.
_v_up_proj
(
o
)
vllm/v1/worker/gpu_model_runner.py
View file @
76572db3
...
@@ -29,7 +29,7 @@ from vllm.distributed.parallel_state import (
...
@@ -29,7 +29,7 @@ from vllm.distributed.parallel_state import (
get_pp_group
,
get_tp_group
,
graph_capture
,
is_global_first_rank
,
get_pp_group
,
get_tp_group
,
graph_capture
,
is_global_first_rank
,
prepare_communication_buffer_for_model
)
prepare_communication_buffer_for_model
)
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
from
vllm.forward_context
import
(
DPMetadata
,
get_forward_context
,
set_forward_context
)
set_forward_context
,
set_profilling
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.mamba.mamba_mixer2
import
MambaMixer2
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
...
@@ -69,7 +69,6 @@ from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.two_batch_overlap.v1.model_input_split_v1
import
tbo_split_and_execute_model
from
vllm.zero_overhead.v1.gpu_model_runner
import
execute_model_sampled
,
zero_prepare_inputs
from
..sample.logits_processor
import
LogitsProcessorManager
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
...
@@ -955,15 +954,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -955,15 +954,25 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# [0, 1, 2, 5, 6, 9]
# [0, 1, 2, 5, 6, 9]
target_logits_indices
+=
arange
target_logits_indices
+=
arange
# TODO: Optimize the CPU -> GPU copy.
if
envs
.
VLLM_ZERO_OVERHEAD
:
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
pin_memory
().
to
(
self
.
device
,
non_blocking
=
True
)
self
.
device
,
non_blocking
=
True
)
else
:
# TODO: Optimize the CPU -> GPU copy.
cu_num_draft_tokens
=
torch
.
from_numpy
(
cu_num_draft_tokens
).
to
(
self
.
device
,
non_blocking
=
True
)
logits_indices
=
torch
.
from_numpy
(
logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
target_logits_indices
=
torch
.
from_numpy
(
target_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
bonus_logits_indices
=
torch
.
from_numpy
(
bonus_logits_indices
).
to
(
self
.
device
,
non_blocking
=
True
)
# Compute the draft token ids.
# Compute the draft token ids.
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
# draft_token_indices: [ 1, 2, 3, 105, 106, 208]
...
@@ -1364,8 +1373,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1364,8 +1373,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# compiled with full CUDA graphs, we have to skip them entirely.
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ZERO_OVERHEAD
:
zero_prepare_inputs
(
self
,
scheduler_output
,
input_ids
)
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
if
envs
.
VLLM_ENABLE_TBO
and
not
self
.
use_cuda_graph
:
model_output
,
finished_sending
,
finished_recving
=
\
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
...
@@ -1507,21 +1514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1507,21 +1514,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
sampled_token_ids
=
sampler_output
.
sampled_token_ids
sampled_token_ids
=
sampler_output
.
sampled_token_ids
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
max_gen_len
=
sampled_token_ids
.
shape
[
-
1
]
if
envs
.
VLLM_ZERO_OVERHEAD
:
return
execute_model_sampled
(
self
,
max_gen_len
,
sampled_token_ids
,
discard_sampled_tokens_req_indices
,
scheduler_output
,
sampling_metadata
,
hidden_states
,
sample_hidden_states
,
aux_hidden_states
,
spec_decode_metadata
,
attn_metadata
,
logprobs_lists
,
prompt_logprobs_dict
,
finished_sending
,
finished_recving
,
num_nans_in_logits
)
if
max_gen_len
==
1
:
if
max_gen_len
==
1
:
# No spec decode tokens.
# No spec decode tokens.
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
valid_sampled_token_ids
=
sampled_token_ids
.
tolist
()
...
@@ -2095,7 +2087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2095,7 +2087,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else
:
else
:
hidden_states
=
outputs
hidden_states
=
outputs
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
():
if
self
.
speculative_config
and
self
.
speculative_config
.
use_eagle
()
and
not
is_profile
:
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
assert
isinstance
(
self
.
drafter
,
EagleProposer
)
self
.
drafter
.
dummy_run
(
num_tokens
,
attn_metadata
)
self
.
drafter
.
dummy_run
(
num_tokens
,
attn_metadata
)
...
@@ -2230,6 +2222,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2230,6 +2222,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
return
pooler_output
return
pooler_output
def
profile_run
(
self
)
->
None
:
def
profile_run
(
self
)
->
None
:
# set profiling flag to avoid torch compile
set_profilling
(
True
)
self
.
_sync_device
()
# Profile with multimodal encoder & encoder cache.
# Profile with multimodal encoder & encoder cache.
# TODO: handle encoder-decoder models once we support them.
# TODO: handle encoder-decoder models once we support them.
if
(
self
.
is_multimodal_model
and
self
.
max_num_encoder_input_tokens
>
0
if
(
self
.
is_multimodal_model
and
self
.
max_num_encoder_input_tokens
>
0
...
@@ -2313,6 +2309,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2313,6 +2309,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
del
hidden_states
,
output
del
hidden_states
,
output
self
.
encoder_cache
.
clear
()
self
.
encoder_cache
.
clear
()
gc
.
collect
()
gc
.
collect
()
set_profilling
(
False
)
def
capture_model
(
self
)
->
None
:
def
capture_model
(
self
)
->
None
:
if
not
self
.
use_cuda_graph
:
if
not
self
.
use_cuda_graph
:
...
...
vllm/v1/worker/gpu_worker.py
View file @
76572db3
...
@@ -29,6 +29,7 @@ from vllm.v1.utils import report_usage_stats
...
@@ -29,6 +29,7 @@ from vllm.v1.utils import report_usage_stats
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.v1.worker.worker_base
import
WorkerBase
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.utils
import
zero_overhead_stream
from
vllm.zero_overhead.v1.gpu_model_runner
import
V1ZeroModelRunner
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -163,8 +164,13 @@ class Worker(WorkerBase):
...
@@ -163,8 +164,13 @@ class Worker(WorkerBase):
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
# Construct the model runner
# Construct the model runner
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
if
envs
.
VLLM_ZERO_OVERHEAD
:
self
.
vllm_config
,
self
.
device
)
logger
.
info
(
'use zero overhead model_runner'
)
self
.
model_runner
:
GPUModelRunner
=
V1ZeroModelRunner
(
self
.
vllm_config
,
self
.
device
)
else
:
self
.
model_runner
:
GPUModelRunner
=
GPUModelRunner
(
self
.
vllm_config
,
self
.
device
)
if
self
.
rank
==
0
:
if
self
.
rank
==
0
:
# If usage stat is enabled, collect relevant info.
# If usage stat is enabled, collect relevant info.
...
...
vllm/zero_overhead/v1/core.py
View file @
76572db3
...
@@ -14,11 +14,15 @@ requsets_valid_token_len = {}
...
@@ -14,11 +14,15 @@ requsets_valid_token_len = {}
def
check_stop
(
request
:
Request
,
def
check_stop
(
request
:
Request
,
max_model_len
:
int
,
max_model_len
:
int
,
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
bool
:
pooler_output
:
Optional
[
torch
.
Tensor
]
=
None
,
if
request
.
request_id
not
in
requsets_valid_token_len
:
use_valid_token_len
:
bool
=
False
)
->
bool
:
requsets_valid_token_len
[
request
.
request_id
]
=
0
if
use_valid_token_len
:
return
False
if
request
.
request_id
not
in
requsets_valid_token_len
:
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
requsets_valid_token_len
[
request
.
request_id
]
=
0
return
False
valid_output_len
=
requsets_valid_token_len
[
request
.
request_id
]
else
:
valid_output_len
=
request
.
num_output_tokens
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
valid_num_tokens
=
request
.
num_prompt_tokens
+
valid_output_len
if
(
valid_num_tokens
>=
max_model_len
if
(
valid_num_tokens
>=
max_model_len
or
valid_output_len
>=
request
.
max_tokens
):
or
valid_output_len
>=
request
.
max_tokens
):
...
@@ -62,110 +66,121 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -62,110 +66,121 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
spec_decoding_stats
:
Optional
[
SpecDecodingStats
]
=
None
# fix last model out in zero overhead
# fix last model out in zero overhead
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
if
model_runner_output
.
fix_req_ids
is
not
None
:
if
req_id
not
in
scheduler
.
requests
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_req_ids
):
continue
if
req_id
not
in
scheduler
.
requests
:
request
=
scheduler
.
requests
[
req_id
]
continue
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
request
=
scheduler
.
requests
[
req_id
]
if
req_id
not
in
requsets_valid_token_len
:
generated_token_ids
=
model_runner_output
.
fix_sampled_token_ids
[
req_idx
]
requsets_valid_token_len
[
req_id
]
=
0
if
req_id
not
in
requsets_valid_token_len
:
valid_output_len
=
requsets_valid_token_len
[
req_id
]
requsets_valid_token_len
[
req_id
]
=
0
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
valid_output_len
=
requsets_valid_token_len
[
req_id
]
if
isinstance
(
generated_token_ids
,
int
):
fix_offset
=
valid_output_len
-
request
.
num_output_tokens
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
if
isinstance
(
generated_token_ids
,
int
):
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
request
.
_output_token_ids
[
fix_offset
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
1
request
.
_all_token_ids
[
fix_offset
]
=
generated_token_ids
else
:
requsets_valid_token_len
[
req_id
]
+=
1
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
if
valid_output_end
==
0
:
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
valid_output_end
=
valid_output_len
+
len
(
generated_token_ids
)
-
request
.
num_output_tokens
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
if
valid_output_end
==
0
:
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
request
.
_output_token_ids
[
fix_offset
:
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
]
=
generated_token_ids
else
:
request
.
_output_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
request
.
_all_token_ids
[
fix_offset
:
valid_output_end
]
=
generated_token_ids
requsets_valid_token_len
[
req_id
]
+=
len
(
generated_token_ids
)
stopped
=
False
new_logprobs
=
None
new_token_ids
=
generated_token_ids
kv_transfer_params
=
None
stopped
=
False
# Check for stop and update request state.
new_logprobs
=
None
# This must be called before we make the EngineCoreOutput.
new_token_ids
=
generated_token_ids
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
kv_transfer_params
=
None
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
True
)
if
stopped
:
# Check for stop and update request state.
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# This must be called before we make the EngineCoreOutput.
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
for
num_new
,
output_token_id
in
enumerate
(
new_token_ids
,
1
):
break
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
)
if
stopped
:
pooler_output
=
None
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
if
pooler_outputs
:
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
pooler_output
=
pooler_outputs
[
req_index
]
break
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
,
True
)
pooler_output
=
None
if
stopped
:
if
pooler_outputs
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
pooler_output
=
pooler_outputs
[
req_index
]
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
# Extract sample logprobs if needed.
pooler_output
)
if
request
.
sampling_params
is
not
None
\
if
stopped
:
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
# Extract sample logprobs if needed.
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
if
request
.
sampling_params
is
not
None
\
and
request
.
sampling_params
.
logprobs
is
not
None
and
logprobs
:
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
# NOTE: once we support N tokens per step (spec decode),
request
):
# the outer lists can be of length > 1.
# NOTE: structured_output_request
new_logprobs
=
logprobs
.
slice
(
req_index
,
req_index
+
1
)
# should not be None if use_structured_output, we have
# check above, so safe to ignore type warning
if
new_token_ids
and
scheduler
.
structured_output_manager
.
should_advance
(
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
request
):
req_id
,
new_token_ids
)
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# spec_token_ids comes from the model runner output
# check above, so safe to ignore type warning
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
request
.
structured_output_request
.
grammar
.
accept_tokens
(
# type: ignore[union-attr]
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
req_id
,
new_token_ids
)
# Get prompt logprobs for this request.
# spec_token_ids comes from the model runner output
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
num_nans_in_logits
is
not
None
and
req_id
in
num_nans_in_logits
:
if
new_token_ids
or
pooler_output
is
not
None
\
request
.
num_nans_in_logits
=
num_nans_in_logits
[
req_id
]
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
# Add newly generated spec token ids to the request.
if
spec_token_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
spec_token_ids
[
req_index
])
else
:
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
# Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
else
:
# Invariant: EngineCore returns no partial prefill outputs.
assert
not
prompt_logprobs_tensors
# fix last model out in zero overhead
if
model_runner_output
.
fix_draft_req_ids
is
not
None
:
for
req_idx
,
req_id
in
enumerate
(
model_runner_output
.
fix_draft_req_ids
):
if
req_id
not
in
scheduler
.
requests
:
continue
request
=
scheduler
.
requests
[
req_id
]
# Add newly generated spec token ids to the request.
if
model_runner_output
.
fix_draft_tokens_ids
is
not
None
:
if
scheduler
.
structured_output_manager
.
should_advance
(
request
):
metadata
=
request
.
structured_output_request
# Needs to happen after new_token_ids are accepted.
request
.
spec_token_ids
=
metadata
.
grammar
.
validate_tokens
(
# type: ignore[union-attr]
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
])
else
:
request
.
spec_token_ids
=
model_runner_output
.
fix_draft_tokens_ids
[
req_idx
]
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
# loop can be a performance bottleneck. We should do our best to avoid
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
# expensive operations inside the loop.
for
request
in
scheduler
.
running
:
for
request
in
scheduler
.
running
:
if
request
.
is_finished
():
if
req_id
in
requsets_valid_token_len
:
requsets_valid_token_len
.
pop
(
req_id
)
continue
req_id
=
request
.
request_id
req_id
=
request
.
request_id
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens_scheduled
==
0
:
if
num_tokens_scheduled
==
0
:
...
@@ -212,19 +227,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -212,19 +227,24 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# Check for stop and update request state.
# Check for stop and update request state.
# This must be called before we make the EngineCoreOutput.
# This must be called before we make the EngineCoreOutput.
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
)
# if stopped:
if
model_runner_output
.
is_output_valid
:
# kv_transfer_params = scheduler._free_request(request)
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
# del new_token_ids[num_new:] # Trim new tokens if needed.
False
)
# break
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
del
new_token_ids
[
num_new
:]
# Trim new tokens if needed.
break
pooler_output
=
None
pooler_output
=
None
if
pooler_outputs
:
if
pooler_outputs
:
pooler_output
=
pooler_outputs
[
req_index
]
if
model_runner_output
.
is_output_valid
:
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
pooler_output
=
pooler_outputs
[
req_index
]
pooler_output
)
stopped
=
check_stop
(
request
,
scheduler
.
max_model_len
,
# if stopped:
pooler_output
,
# kv_transfer_params = scheduler._free_request(request)
False
)
if
stopped
:
kv_transfer_params
=
scheduler
.
_free_request
(
request
)
# Extract sample logprobs if needed.
# Extract sample logprobs if needed.
if
request
.
sampling_params
is
not
None
\
if
request
.
sampling_params
is
not
None
\
...
@@ -255,7 +275,30 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
...
@@ -255,7 +275,30 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
else
:
else
:
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
request
.
spec_token_ids
=
spec_token_ids
[
req_index
]
if
not
stopped
:
if
model_runner_output
.
is_output_valid
:
# # Get prompt logprobs for this request.
prompt_logprobs_tensors
=
prompt_logprobs_dict
.
get
(
req_id
)
if
new_token_ids
or
pooler_output
is
not
None
\
or
kv_transfer_params
:
# Add EngineCoreOutput for this Request.
outputs
[
request
.
client_index
].
append
(
EngineCoreOutput
(
request_id
=
req_id
,
new_token_ids
=
new_token_ids
,
finish_reason
=
request
.
get_finished_reason
(),
new_logprobs
=
new_logprobs
,
new_prompt_logprobs_tensors
=
prompt_logprobs_tensors
,
pooling_output
=
pooler_output
,
stop_reason
=
request
.
stop_reason
,
events
=
request
.
take_events
(),
kv_transfer_params
=
kv_transfer_params
,
num_cached_tokens
=
request
.
num_cached_tokens
,
))
if
stopped
:
if
req_id
in
requsets_valid_token_len
:
requsets_valid_token_len
.
pop
(
req_id
)
else
:
new_running
.
append
(
request
)
new_running
.
append
(
request
)
scheduler
.
running
=
new_running
scheduler
.
running
=
new_running
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
76572db3
This diff is collapsed.
Click to expand it.
vllm/zero_overhead/v1/outputs.py
View file @
76572db3
...
@@ -8,4 +8,7 @@ from vllm.v1.outputs import ModelRunnerOutput
...
@@ -8,4 +8,7 @@ from vllm.v1.outputs import ModelRunnerOutput
class
ZeroV1ModelRunnerOutput
(
ModelRunnerOutput
):
class
ZeroV1ModelRunnerOutput
(
ModelRunnerOutput
):
# [num_reqs]
# [num_reqs]
fix_req_ids
:
list
[
str
]
=
None
fix_req_ids
:
list
[
str
]
=
None
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
fix_sampled_token_ids
:
list
[
list
[
int
]]
=
None
\ No newline at end of file
fix_draft_req_ids
:
list
[
list
[
int
]]
=
None
fix_draft_tokens_ids
:
list
[
list
[
int
]]
=
None
is_output_valid
:
bool
=
True
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment