Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b58514dd
Commit
b58514dd
authored
Apr 18, 2026
by
王敏
Browse files
[perf]1.优化pcp代码 2.优化ep低延迟模式调度,消除空泡
parent
c462f3a0
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
149 additions
and
187 deletions
+149
-187
vllm/config/parallel.py
vllm/config/parallel.py
+7
-0
vllm/config/vllm.py
vllm/config/vllm.py
+7
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-14
vllm/envs.py
vllm/envs.py
+7
-9
vllm/forward_context.py
vllm/forward_context.py
+12
-6
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+7
-20
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-1
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+13
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+7
-6
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+3
-4
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+27
-92
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+1
-2
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+11
-6
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+30
-24
No files found.
vllm/config/parallel.py
View file @
b58514dd
...
@@ -299,6 +299,13 @@ class ParallelConfig:
...
@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out.
should only be set by API server scale-out.
"""
"""
enable_lightly_cp
:
bool
=
False
"""Use lightly context parallel."""
enable_lightly_cplb
:
bool
=
False
"""Use lightly context parallel load balancing."""
@
field_validator
(
"disable_nccl_for_dp_synchronization"
,
mode
=
"wrap"
)
@
field_validator
(
"disable_nccl_for_dp_synchronization"
,
mode
=
"wrap"
)
@
classmethod
@
classmethod
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
def
_skip_none_validation
(
cls
,
value
:
Any
,
handler
:
Callable
)
->
Any
:
...
...
vllm/config/vllm.py
View file @
b58514dd
...
@@ -1061,6 +1061,12 @@ class VllmConfig:
...
@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs
# Handle the KV connector configs
self
.
_post_init_kv_transfer_config
()
self
.
_post_init_kv_transfer_config
()
if
self
.
parallel_config
.
enable_lightly_cp
and
not
self
.
model_config
.
enforce_eager
:
raise
ValueError
(
"Lightly context parallel currently only supports the eager mode!!!"
)
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
def
update_sizes_for_sequence_parallelism
(
self
,
possible_sizes
:
list
)
->
list
:
# remove the sizes that not multiple of tp_size when
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism
# enable sequence parallelism
...
@@ -1186,8 +1192,7 @@ class VllmConfig:
...
@@ -1186,8 +1192,7 @@ class VllmConfig:
if
(
if
(
self
.
parallel_config
.
tensor_parallel_size
>
1
self
.
parallel_config
.
tensor_parallel_size
>
1
and
(
self
.
compilation_config
.
pass_config
.
enable_sp
)
and
self
.
compilation_config
.
pass_config
.
enable_sp
#or envs.VLLM_MLA_CP)
):
):
cudagraph_capture_sizes
=
self
.
update_sizes_for_sequence_parallelism
(
cudagraph_capture_sizes
=
self
.
update_sizes_for_sequence_parallelism
(
cudagraph_capture_sizes
cudagraph_capture_sizes
...
...
vllm/engine/arg_utils.py
View file @
b58514dd
...
@@ -582,6 +582,9 @@ class EngineArgs:
...
@@ -582,6 +582,9 @@ class EngineArgs:
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
kv_offloading_backend
:
KVOffloadingBackend
=
CacheConfig
.
kv_offloading_backend
tokens_only
:
bool
=
False
tokens_only
:
bool
=
False
enable_lightly_cp
:
bool
=
ParallelConfig
.
enable_lightly_cp
enable_lightly_cplb
:
bool
=
ParallelConfig
.
enable_lightly_cplb
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# support `EngineArgs(compilation_config={...})`
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# without having to manually construct a
...
@@ -899,6 +902,15 @@ class EngineArgs:
...
@@ -899,6 +902,15 @@ class EngineArgs:
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
"--worker-extension-cls"
,
**
parallel_kwargs
[
"worker_extension_cls"
]
)
)
parallel_group
.
add_argument
(
"--enable-lightly-cp"
,
**
parallel_kwargs
[
"enable_lightly_cp"
],
)
parallel_group
.
add_argument
(
"--enable-lightly-cplb"
,
**
parallel_kwargs
[
"enable_lightly_cplb"
],
)
# KV cache arguments
# KV cache arguments
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_kwargs
=
get_kwargs
(
CacheConfig
)
cache_group
=
parser
.
add_argument_group
(
cache_group
=
parser
.
add_argument_group
(
...
@@ -1500,20 +1512,6 @@ class EngineArgs:
...
@@ -1500,20 +1512,6 @@ class EngineArgs:
data_parallel_external_lb
=
(
data_parallel_external_lb
=
(
self
.
data_parallel_external_lb
or
self
.
data_parallel_rank
is
not
None
self
.
data_parallel_external_lb
or
self
.
data_parallel_rank
is
not
None
)
)
if
(
envs
.
VLLM_MLA_CP
and
self
.
max_num_batched_tokens
is
not
None
and
self
.
max_num_batched_tokens
<
self
.
tensor_parallel_size
**
3
):
raise
ValueError
(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger
.
info
(
"[MLACP] VLLM_MLA_CP is %s"
,
envs
.
VLLM_MLA_CP
)
logger
.
info
(
"[MLACP] VLLM_MLA_CPLB is %s"
,
envs
.
VLLM_MLA_CPLB
)
# Local DP rank = 1, use pure-external LB.
# Local DP rank = 1, use pure-external LB.
if
data_parallel_external_lb
:
if
data_parallel_external_lb
:
assert
self
.
data_parallel_rank
is
not
None
,
(
assert
self
.
data_parallel_rank
is
not
None
,
(
...
@@ -1644,6 +1642,8 @@ class EngineArgs:
...
@@ -1644,6 +1642,8 @@ class EngineArgs:
cp_kv_cache_interleave_size
=
self
.
cp_kv_cache_interleave_size
,
cp_kv_cache_interleave_size
=
self
.
cp_kv_cache_interleave_size
,
_api_process_count
=
self
.
_api_process_count
,
_api_process_count
=
self
.
_api_process_count
,
_api_process_rank
=
self
.
_api_process_rank
,
_api_process_rank
=
self
.
_api_process_rank
,
enable_lightly_cp
=
self
.
enable_lightly_cp
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
,
)
)
speculative_config
=
self
.
create_speculative_config
(
speculative_config
=
self
.
create_speculative_config
(
...
...
vllm/envs.py
View file @
b58514dd
...
@@ -324,8 +324,9 @@ if TYPE_CHECKING:
...
@@ -324,8 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_TOPK
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_DISABLE_DSA
:
bool
=
False
VLLM_MLA_CP
:
bool
=
False
VLLM_LIGHTLY_CP_THRESHOULD
:
int
=
2048
VLLM_MLA_CPLB
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
return
os
.
getenv
(
return
os
.
getenv
(
"XDG_CACHE_HOME"
,
"XDG_CACHE_HOME"
,
...
@@ -2012,13 +2013,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -2012,13 +2013,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_DSA"
:
"VLLM_DISABLE_DSA"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"False"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_DISABLE_DSA"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# If set to 1/True, enable mla context parallel
"VLLM_MLA_CP"
:
# MLA_CP open threshold
lambda
:
(
os
.
environ
.
get
(
"VLLM_MLA_CP"
,
"False"
).
lower
()
in
"VLLM_LIGHTLY_CP_THRESHOULD"
:
(
"true"
,
"1"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_LIGHTLY_CP_THRESHOULD"
,
"2048"
)),
"VLLM_MLA_CPLB"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_MLA_CPLB"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/forward_context.py
View file @
b58514dd
...
@@ -242,7 +242,8 @@ class ForwardContext:
...
@@ -242,7 +242,8 @@ class ForwardContext:
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
enable_mla_cp
:
bool
=
False
enable_lightly_cp
:
bool
=
False
enable_lightly_cplb
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
...
@@ -279,7 +280,8 @@ def create_forward_context(
...
@@ -279,7 +280,8 @@ def create_forward_context(
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_mla_cp
:
bool
=
False
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
):
):
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
if
vllm_config
.
speculative_config
is
None
:
...
@@ -307,7 +309,8 @@ def create_forward_context(
...
@@ -307,7 +309,8 @@ def create_forward_context(
skip_compiled
=
skip_compiled
,
skip_compiled
=
skip_compiled
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
enable_mla_cp
,
enable_lightly_cp
=
enable_lightly_cp
,
enable_lightly_cplb
=
enable_lightly_cplb
,
additional_kwargs
=
additional_kwargs
or
{},
additional_kwargs
=
additional_kwargs
or
{},
)
)
...
@@ -341,7 +344,8 @@ def set_forward_context(
...
@@ -341,7 +344,8 @@ def set_forward_context(
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_mla_cp
:
bool
=
False
,
enable_lightly_cp
:
bool
=
False
,
enable_lightly_cplb
:
bool
=
False
,
):
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc.
can be attention metadata, etc.
...
@@ -353,7 +357,8 @@ def set_forward_context(
...
@@ -353,7 +357,8 @@ def set_forward_context(
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
DPMetadata
|
None
=
None
dp_metadata
:
DPMetadata
|
None
=
None
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
(
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
\
envs
.
VLLM_ALL2ALL_BACKEND
!=
"deepep_low_latency"
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
attn_metadata
is
not
None
or
num_tokens
is
not
None
):
):
# If num_tokens_across_dp hasn't already been initialized, then
# If num_tokens_across_dp hasn't already been initialized, then
...
@@ -404,7 +409,8 @@ def set_forward_context(
...
@@ -404,7 +409,8 @@ def set_forward_context(
skip_compiled
,
skip_compiled
,
scatter_indexes_tensor
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
gather_indexes_tensor
,
enable_mla_cp
enable_lightly_cp
,
enable_lightly_cplb
)
)
try
:
try
:
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
b58514dd
...
@@ -205,20 +205,6 @@ def moe_grouped_gemm(
...
@@ -205,20 +205,6 @@ def moe_grouped_gemm(
return
output
return
output
def native_w8a8_perChannel_batch_matmul(
    q_a1_all, weight13, qa1_scale_all, w13_scale, output_dtype
):
    """Reference batched matmul for quantized operands with per-channel scales.

    Both operands are upcast to float32, multiplied batch-wise as
    ``A @ B^T``, rescaled by the activation and weight scales, and finally
    cast to ``output_dtype``.

    Args:
        q_a1_all: Quantized activations. Must be 3-D (``torch.bmm``
            requirement); its last dimension is the contraction dimension.
        weight13: Quantized weights. Must be 3-D with the same last-dimension
            size as ``q_a1_all``; it is transposed on dims (1, 2) before the
            product.
        qa1_scale_all: Activation scale, broadcast against the product.
            Presumably shaped [E, M, 1] -- TODO confirm against the caller.
        w13_scale: Weight scale; dims 1 and 2 are swapped before it is
            broadcast against the product. Presumably stored as [E, N, 1]
            -- TODO confirm against the caller.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape [E, M, N] (N = ``weight13.shape[1]``) in
        ``output_dtype``.

    Raises:
        AssertionError: if the contraction dimensions do not match.
    """
    A = q_a1_all.to(torch.float32)
    B = weight13.to(torch.float32)
    assert A.shape[-1] == B.shape[-1], "Dimension mismatch"
    # [E, M, K] x [E, K, N] -> [E, M, N]
    C = torch.bmm(A, B.transpose(1, 2))
    # Broadcast per-column scale
    C = qa1_scale_all * C * w13_scale.transpose(1, 2)
    C = C.to(output_dtype)
    return C
def
scales_shape_stride_dtype
(
def
scales_shape_stride_dtype
(
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
...
@@ -589,7 +575,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -589,7 +575,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
**
_
**
_
):
):
assert
expert_tokens_meta
is
not
None
assert
expert_tokens_meta
is
not
None
...
@@ -612,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -612,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
expected_m
=
self
.
estimate_expected_m
(
# expected_m = self.estimate_expected_m(
global_num_experts
=
global_num_experts
,
# global_num_experts=global_num_experts,
max_tokens_per_expert
=
max_num_tokens
,
# max_tokens_per_expert=max_num_tokens,
topk
=
topk_ids
.
size
(
-
1
),
# topk=topk_ids.size(-1),
)
# )
expected_m
=
self
.
get_expected_m
()
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
fp8_m_grouped_gemm_nt_masked
(
fp8_m_grouped_gemm_nt_masked
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
b58514dd
...
@@ -854,7 +854,7 @@ class FusedMoE(CustomOp):
...
@@ -854,7 +854,7 @@ class FusedMoE(CustomOp):
def
use_dp_chunking
(
self
)
->
bool
:
def
use_dp_chunking
(
self
)
->
bool
:
return
(
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
#
or self.moe_parallel_config.use_deepep_ll_kernels
or
self
.
moe_parallel_config
.
use_mori_kernels
or
self
.
moe_parallel_config
.
use_mori_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
b58514dd
...
@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
max_num_tokens
=
max_num_tokens
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
self
.
num_dispatchers
=
num_dispatchers
self
.
expected_m
=
max_num_tokens
@
staticmethod
@
staticmethod
def
expects_unquantized_inputs
(
def
expects_unquantized_inputs
(
...
@@ -775,6 +776,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -775,6 +776,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def set_expected_m(self, expected_m: int) -> None:
    """Record the expected-M value for this experts implementation.

    The value is stored on ``self.expected_m`` and is what
    ``get_expected_m`` hands back afterwards.

    Args:
        expected_m: Value to store. The visible caller passes an integer
            obtained by floor division.
    """
    self.expected_m = expected_m
def get_expected_m(self):
    """Return the value currently held in ``self.expected_m``."""
    return self.expected_m
def
_slice_scales
(
def
_slice_scales
(
scales
:
torch
.
Tensor
|
None
,
start
:
int
,
end
:
int
scales
:
torch
.
Tensor
|
None
,
start
:
int
,
end
:
int
...
@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
that handles DBO and async.
"""
"""
expected_m
=
(
hidden_states
.
shape
[
0
]
*
self
.
fused_experts
.
num_dispatchers
*
topk_ids
.
shape
[
1
]
+
global_num_experts
)
//
global_num_experts
self
.
fused_experts
.
set_expected_m
(
expected_m
)
if
not
self
.
prepare_finalize
.
supports_async
():
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# support async prepare/finalize
...
...
vllm/model_executor/layers/mla.py
View file @
b58514dd
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
torch
import
torch
from
vllm.attention.layer
import
MLAAttention
from
vllm.attention.layer
import
MLAAttention
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -115,6 +114,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -115,6 +114,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
self
.
prefix
=
prefix
self
.
prefix
=
prefix
def
forward
(
def
forward
(
self
,
self
,
positions
:
torch
.
Tensor
,
positions
:
torch
.
Tensor
,
...
@@ -189,11 +189,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -189,11 +189,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if
llama_4_scaling
is
not
None
:
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
q
*=
llama_4_scaling
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
# if not use_fused_rms_rope_concat:
# if not use_fused_rms_rope_concat:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
kv_c_normed
=
tensor_model_parallel_all_gather
(
kv_c_normed
=
tensor_model_parallel_all_gather
(
kv_c_normed
.
contiguous
(),
0
kv_c_normed
.
contiguous
(),
0
)
)
...
@@ -202,7 +203,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -202,7 +203,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
)
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
en
vs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
if
en
able_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
# Reorder kv after pcp allgather.
kv_c_normed
=
torch
.
index_select
(
kv_c_normed
,
0
,
gather_indexes_tensor
)
kv_c_normed
=
torch
.
index_select
(
kv_c_normed
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
...
@@ -243,7 +244,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -243,7 +244,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"expose 'cos_sin_cache'."
"expose 'cos_sin_cache'."
)
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
kv_c
=
tensor_model_parallel_all_gather
(
kv_c
=
tensor_model_parallel_all_gather
(
kv_c
.
contiguous
(),
0
kv_c
.
contiguous
(),
0
)
)
...
@@ -251,7 +252,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -251,7 +252,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe
.
contiguous
(),
0
k_pe
.
contiguous
(),
0
)
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
en
vs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
if
en
able_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
# Reorder kv after pcp allgather.
# Reorder kv after pcp allgather.
kv_c
=
torch
.
index_select
(
kv_c
,
0
,
gather_indexes_tensor
)
kv_c
=
torch
.
index_select
(
kv_c
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
k_pe
=
torch
.
index_select
(
k_pe
,
0
,
gather_indexes_tensor
)
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
b58514dd
...
@@ -198,8 +198,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -198,8 +198,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
if
scatter_indexes_tensor
is
None
:
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
...
@@ -212,7 +212,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -212,7 +212,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
else
:
else
:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
scatter_indexes_tensor
=
torch
.
where
(
scatter_indexes_tensor
==
-
1
,
0
,
scatter_indexes_tensor
)
inputs_embeds
=
torch
.
index_select
(
inputs_embeds
,
0
,
scatter_indexes_tensor
)
inputs_embeds
=
torch
.
index_select
(
inputs_embeds
,
0
,
scatter_indexes_tensor
)
...
@@ -228,7 +227,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -228,7 +227,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx
,
current_step_idx
,
)
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
if
gather_indexes_tensor
is
not
None
:
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
b58514dd
...
@@ -183,10 +183,9 @@ class DeepseekAttention(nn.Module):
...
@@ -183,10 +183,9 @@ class DeepseekAttention(nn.Module):
return
output
return
output
def
eff_2d_
iqis_all_gather
(
def
iqis_all_gather
(
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
],
tp_size
:
int
|
None
=
None
,
tp_size
:
int
|
None
=
None
tp_rank
:
int
|
None
=
None
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
iqis
is
not
None
assert
iqis
is
not
None
iq_tensor
,
is_tensor
=
iqis
iq_tensor
,
is_tensor
=
iqis
...
@@ -221,6 +220,7 @@ def eff_2d_iqis_all_gather(
...
@@ -221,6 +220,7 @@ def eff_2d_iqis_all_gather(
is_gathered
=
is_gathered_int8
.
view
(
torch
.
float32
)
is_gathered
=
is_gathered_int8
.
view
(
torch
.
float32
)
return
(
iq_gathered
,
is_gathered
)
return
(
iq_gathered
,
is_gathered
)
class
DeepseekV2MLP
(
nn
.
Module
):
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -267,15 +267,10 @@ class DeepseekV2MLP(nn.Module):
...
@@ -267,15 +267,10 @@ class DeepseekV2MLP(nn.Module):
x
,
x
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
):
):
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP# and not get_forward_context().draft_model
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
if
iqis
is
not
None
and
iqis
[
0
]
is
not
None
and
iqis
[
1
]
is
not
None
:
if
iqis
is
not
None
and
iqis
[
0
]
is
not
None
and
iqis
[
1
]
is
not
None
:
if
False
:
iqis
=
iqis_all_gather
(
iqis
,
tp_size
=
self
.
tp_size
)
i_q_gahter
=
tensor_model_parallel_all_gather
(
iqis
[
0
].
contiguous
(),
0
)
i_s_gather
=
tensor_model_parallel_all_gather
(
iqis
[
1
].
contiguous
(),
0
)
iqis
=
(
i_q_gahter
,
i_s_gather
)
else
:
iqis
=
eff_2d_iqis_all_gather
(
iqis
,
tp_size
=
self
.
tp_size
,
tp_rank
=
get_tensor_model_parallel_rank
())
else
:
else
:
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
)
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
)
...
@@ -293,7 +288,7 @@ class DeepseekV2MLP(nn.Module):
...
@@ -293,7 +288,7 @@ class DeepseekV2MLP(nn.Module):
x
=
self
.
act_fn
(
gate_up
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
x
,
_
=
self
.
down_proj
(
x
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
x
=
tensor_model_parallel_reduce_scatter
(
x
.
contiguous
(),
dim
=
0
)
x
=
tensor_model_parallel_reduce_scatter
(
x
.
contiguous
(),
dim
=
0
)
return
x
return
x
elif
self
.
tp_size
>
1
:
elif
self
.
tp_size
>
1
:
...
@@ -301,66 +296,6 @@ class DeepseekV2MLP(nn.Module):
...
@@ -301,66 +296,6 @@ class DeepseekV2MLP(nn.Module):
return
x
return
x
class DeepseekV2SharedMLP(nn.Module):
    """Gated MLP: ``gate_up_proj`` -> SiLU-and-mul -> ``down_proj``.

    Only the ``"silu"`` activation is accepted. Depending on the
    ``envs.USE_FUSED_RMS_QUANT`` / ``envs.USE_FUSED_SILU_MUL_QUANT`` flags the
    forward pass either forwards pre-quantized inputs (``iqis``) to the
    projections or runs the plain unfused path.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        """Build the two projections and the activation.

        Args:
            hidden_size: Input size of ``gate_up_proj`` and output size of
                ``down_proj``.
            intermediate_size: Size of each of the two merged gate/up outputs
                and input size of ``down_proj``.
            hidden_act: Activation name; must be ``"silu"``.
            quant_config: Passed through to both linear layers.
            reduce_results: Passed through to ``down_proj``.
            is_sequence_parallel: Passed as ``disable_tp`` to both layers.
            prefix: Name prefix; ``.gate_up_proj`` / ``.down_proj`` are
                appended for the respective layers.

        Raises:
            ValueError: if ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # If is_sequence_parallel, the input and output tensors are sharded
        # across the ranks within the tp_group. In this case the weights are
        # replicated and no collective ops are needed.
        # Otherwise we use standard TP with an allreduce at the end.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        *,
        iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
    ):
        """Apply the gated MLP to ``x``.

        Args:
            x: Input passed to ``gate_up_proj``.
            iqis: Optional tensor pair forwarded to ``gate_up_proj`` when
                ``envs.USE_FUSED_RMS_QUANT`` is set; ignored otherwise.

        Returns:
            The first element of the ``down_proj`` output.
        """
        if envs.USE_FUSED_RMS_QUANT:
            gate_up, _ = self.gate_up_proj(x, iqis=iqis)
            if envs.USE_FUSED_SILU_MUL_QUANT:
                # Imported lazily so the dependency is only needed when the
                # fused SiLU-mul-quant path is enabled.
                from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

                xq, xs = lm_fuse_silu_mul_quant(gate_up)
                # NOTE(review): down_proj is handed gate_up (not an activated
                # tensor) together with the fused (xq, xs) pair; presumably
                # the layer consumes iqis instead -- confirm.
                x, _ = self.down_proj(gate_up, iqis=(xq, xs))
            else:
                x = self.act_fn(gate_up)
                x, _ = self.down_proj(x)
        else:
            gate_up, _ = self.gate_up_proj(x)
            x = self.act_fn(gate_up)
            x, _ = self.down_proj(x)
        return x
class
DeepseekV2MoE
(
nn
.
Module
):
class
DeepseekV2MoE
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -431,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -431,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
else
:
else
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
self
.
shared_experts
=
DeepseekV2
Shared
MLP
(
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
hidden_act
=
config
.
hidden_act
,
hidden_act
=
config
.
hidden_act
,
...
@@ -477,8 +412,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -477,8 +412,8 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP #and not get_forward_context().draft_model
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
0
hidden_states
.
contiguous
(),
0
)
)
...
@@ -553,7 +488,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -553,7 +488,7 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
assert
shared_output
is
not
None
final_hidden_states
+=
shared_output
final_hidden_states
+=
shared_output
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
final_hidden_states
=
tensor_model_parallel_reduce_scatter
(
final_hidden_states
=
tensor_model_parallel_reduce_scatter
(
final_hidden_states
.
contiguous
(),
0
final_hidden_states
.
contiguous
(),
0
)
)
...
@@ -889,13 +824,14 @@ class Indexer(nn.Module):
...
@@ -889,13 +824,14 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
k
=
tensor_model_parallel_all_gather
(
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
k
.
contiguous
(),
0
)
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
envs
.
VLLM_MLA_CPLB
and
gather_indexes_tensor
is
not
None
:
enable_lightly_cplb
=
get_forward_context
().
enable_lightly_cplb
if
enable_lightly_cplb
and
gather_indexes_tensor
is
not
None
:
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
k
=
torch
.
index_select
(
k
,
0
,
gather_indexes_tensor
)
# we only quant q here since k quant is fused with cache insertion
# we only quant q here since k quant is fused with cache insertion
...
@@ -964,8 +900,8 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -964,8 +900,8 @@ class DeepseekV2MLAAttention(nn.Module):
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
num_heads
%
tp_size
==
0
assert
num_heads
%
tp_size
==
0
#
self.num_local_heads = num_heads // tp_size
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
\
self
.
num_local_heads
=
num_heads
//
tp_size
if
not
envs
.
VLLM_MLA_CP
else
self
.
num_heads
vllm_config
.
parallel_config
.
enable_lightly_cp
else
self
.
num_heads
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
scaling
=
self
.
qk_head_dim
**-
0.5
self
.
max_position_embeddings
=
max_position_embeddings
self
.
max_position_embeddings
=
max_position_embeddings
...
@@ -999,7 +935,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -999,7 +935,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
prefix
=
f
"
{
prefix
}
.q_b_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
)
)
else
:
else
:
self
.
q_proj
=
ColumnParallelLinear
(
self
.
q_proj
=
ColumnParallelLinear
(
...
@@ -1008,7 +944,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -1008,7 +944,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
prefix
=
f
"
{
prefix
}
.q_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_a_layernorm
=
RMSNorm
(
self
.
kv_lora_rank
,
eps
=
config
.
rms_norm_eps
)
self
.
kv_b_proj
=
ColumnParallelLinear
(
self
.
kv_b_proj
=
ColumnParallelLinear
(
...
@@ -1017,7 +953,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -1017,7 +953,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
prefix
=
f
"
{
prefix
}
.kv_b_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
num_heads
*
self
.
v_head_dim
,
self
.
num_heads
*
self
.
v_head_dim
,
...
@@ -1025,7 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
...
@@ -1025,7 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias
=
False
,
bias
=
False
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
prefix
=
f
"
{
prefix
}
.o_proj"
,
disable_tp
=
envs
.
VLLM_MLA_CP
,
disable_tp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
,
)
)
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
if
config
.
rope_parameters
[
"rope_type"
]
!=
"default"
:
...
@@ -1262,8 +1198,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1262,8 +1198,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
*=
1.0
/
self
.
routed_scaling_factor
residual
*=
1.0
/
self
.
routed_scaling_factor
# Fully Connected
# Fully Connected
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
enable_lightly_cp
=
get_forward_context
().
enable_lightly_cp
skip_moe_large_batch_size
=
enable_mla_cp
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
update_hs
=
True
if
isinstance
(
self
.
mlp
,
DeepseekV2MoE
)
else
False
assert
self
.
post_attention_layernorm
.
has_weight
is
True
assert
self
.
post_attention_layernorm
.
has_weight
is
True
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
_i_q
,
_i_s
,
residual
=
self
.
post_attention_layernorm
(
x
=
hidden_states
,
...
@@ -1272,7 +1207,7 @@ class DeepseekV2DecoderLayer(nn.Module):
...
@@ -1272,7 +1207,7 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input
=
update_hs
update_input
=
update_hs
)
)
new_resi
=
residual
new_resi
=
residual
if
skip_moe_large_batch_size
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
):
if
enable_lightly_cp
and
isinstance
(
self
.
mlp
,
DeepseekV2MoE
):
hidden_states
=
self
.
mlp
(
hidden_states
)
hidden_states
=
self
.
mlp
(
hidden_states
)
else
:
else
:
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
hidden_states
=
self
.
mlp
(
hidden_states
,
iqis
=
(
_i_q
,
_i_s
))
...
@@ -1437,8 +1372,8 @@ class DeepseekV2Model(nn.Module):
...
@@ -1437,8 +1372,8 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
enable_
mla
_cp
=
get_forward_context
().
enable_
mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_
lightly
_cp
=
get_forward_context
().
enable_
lightly_cp
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
scatter_indexes_tensor
=
get_forward_context
().
scatter_indexes_tensor
if
scatter_indexes_tensor
is
None
:
if
scatter_indexes_tensor
is
None
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
...
@@ -1481,7 +1416,7 @@ class DeepseekV2Model(nn.Module):
...
@@ -1481,7 +1416,7 @@ class DeepseekV2Model(nn.Module):
)
)
if
not
get_pp_group
().
is_last_rank
:
if
not
get_pp_group
().
is_last_rank
:
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
residual
=
tensor_model_parallel_all_gather
(
residual
.
contiguous
(),
dim
=
0
)
residual
=
tensor_model_parallel_all_gather
(
residual
.
contiguous
(),
dim
=
0
)
return
IntermediateTensors
(
return
IntermediateTensors
(
...
@@ -1490,7 +1425,7 @@ class DeepseekV2Model(nn.Module):
...
@@ -1490,7 +1425,7 @@ class DeepseekV2Model(nn.Module):
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
hidden_states
,
_
=
self
.
norm
(
hidden_states
,
residual
)
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
gather_indexes_tensor
=
get_forward_context
().
gather_indexes_tensor
if
gather_indexes_tensor
is
not
None
:
if
gather_indexes_tensor
is
not
None
:
...
...
vllm/v1/attention/backend.py
View file @
b58514dd
...
@@ -332,7 +332,6 @@ class CommonAttentionMetadata:
...
@@ -332,7 +332,6 @@ class CommonAttentionMetadata:
"""Number of requests"""
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens
:
int
num_actual_tokens
:
int
"""Total number of tokens in batch"""
"""Total number of tokens in batch"""
max_query_len
:
int
max_query_len
:
int
"""Longest query in batch"""
"""Longest query in batch"""
...
@@ -348,7 +347,7 @@ class CommonAttentionMetadata:
...
@@ -348,7 +347,7 @@ class CommonAttentionMetadata:
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
cp_common_metadata
:
CpCommonAttentionMetadata
|
None
=
None
cp_common_metadata
:
CpCommonAttentionMetadata
|
None
=
None
enable_
mla
_cp
:
bool
=
False
enable_
lightly
_cp
:
bool
=
False
causal
:
bool
=
True
causal
:
bool
=
True
...
...
vllm/v1/spec_decode/eagle.py
View file @
b58514dd
...
@@ -78,7 +78,8 @@ class SpecDecodeBaseProposer:
...
@@ -78,7 +78,8 @@ class SpecDecodeBaseProposer:
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
# The drafter can get longer sequences than the target model.
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
envs
.
VLLM_MLA_CPLB
\
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
\
vllm_config
.
parallel_config
.
enable_lightly_cplb
\
else
vllm_config
.
scheduler_config
.
max_num_seqs
*
2
else
vllm_config
.
scheduler_config
.
max_num_seqs
*
2
self
.
max_num_tokens
=
(
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
...
@@ -224,7 +225,10 @@ class SpecDecodeBaseProposer:
...
@@ -224,7 +225,10 @@ class SpecDecodeBaseProposer:
self
.
scatter_indexes_tensor
=
None
self
.
scatter_indexes_tensor
=
None
self
.
gather_indexes_tensor
=
None
self
.
gather_indexes_tensor
=
None
if
envs
.
VLLM_MLA_CP
:
self
.
enable_lightly_cp
=
vllm_config
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
vllm_config
.
parallel_config
.
enable_lightly_cplb
if
self
.
enable_lightly_cp
:
self
.
query_start_loc
=
CpuGpuBuffer
(
self
.
query_start_loc
=
CpuGpuBuffer
(
max_batch_size
+
1
,
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -339,8 +343,8 @@ class SpecDecodeBaseProposer:
...
@@ -339,8 +343,8 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
)
enable_
mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
runner
.
mla
_cp_threshould
enable_
lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly
_cp_threshould
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
num_tokens_dp_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_dp_padded
)
num_tokens_dp_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_dp_padded
)
common_attn_metadata
=
self
.
_prepare_cp_metadata
(
common_attn_metadata
=
self
.
_prepare_cp_metadata
(
...
@@ -436,7 +440,8 @@ class SpecDecodeBaseProposer:
...
@@ -436,7 +440,8 @@ class SpecDecodeBaseProposer:
),
),
scatter_indexes_tensor
=
self
.
scatter_indexes_tensor
,
scatter_indexes_tensor
=
self
.
scatter_indexes_tensor
,
gather_indexes_tensor
=
self
.
gather_indexes_tensor
,
gather_indexes_tensor
=
self
.
gather_indexes_tensor
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
runner
.
mla_cp_threshould
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
runner
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
):
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
if
not
self
.
model_returns_tuple
():
...
@@ -513,7 +518,7 @@ class SpecDecodeBaseProposer:
...
@@ -513,7 +518,7 @@ class SpecDecodeBaseProposer:
if
batch_size_across_dp
is
not
None
:
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
if
enable_
mla
_cp
:
if
enable_
lightly
_cp
:
common_attn_metadata
=
common_attn_metadata
.
cp_common_metadata
common_attn_metadata
=
common_attn_metadata
.
cp_common_metadata
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
num_actual_tokens
=
batch_size
...
...
vllm/v1/worker/dp_utils.py
View file @
b58514dd
...
@@ -6,6 +6,7 @@ import numpy as np
...
@@ -6,6 +6,7 @@ import numpy as np
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
]
"""
"""
if
parallel_config
.
data_parallel_size
==
1
:
if
parallel_config
.
data_parallel_size
==
1
or
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
:
# Early exit.
# Early exit.
return
False
,
None
,
cudagraph_mode
return
False
,
None
,
cudagraph_mode
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
b58514dd
...
@@ -189,6 +189,7 @@ from .utils import (
...
@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs
,
sanity_check_mm_encoder_outputs
,
)
)
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.utils.torch_utils
import
async_tensor_h2d
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
@@ -382,12 +383,15 @@ class GPUModelRunner(
...
@@ -382,12 +383,15 @@ class GPUModelRunner(
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
dcp_rank
=
0
if
self
.
dcp_world_size
<=
1
else
get_dcp_group
().
rank_in_group
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
self
.
max_num_tokens
=
scheduler_config
.
max_num_batched_tokens
#self.max_num_reqs = scheduler_config.max_num_seqs
#self.max_num_reqs = scheduler_config.max_num_seqs
self
.
enable_lightly_cp
=
self
.
parallel_config
.
enable_lightly_cp
self
.
enable_lightly_cplb
=
self
.
enable_lightly_cp
and
self
.
parallel_config
.
enable_lightly_cplb
self
.
max_num_reqs
=
(
self
.
max_num_reqs
=
(
scheduler_config
.
max_num_seqs
scheduler_config
.
max_num_seqs
if
not
envs
.
VLLM_MLA_CPLB
if
not
self
.
enable_lightly_cplb
else
scheduler_config
.
max_num_seqs
*
2
else
scheduler_config
.
max_num_seqs
*
2
)
)
self
.
mla
_cp_threshould
=
512
self
.
lightly
_cp_threshould
=
envs
.
VLLM_LIGHTLY_CP_THRESHOULD
# Broadcast PP output for external_launcher (torchrun)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# to make sure we are synced across pp ranks
...
@@ -1525,7 +1529,7 @@ class GPUModelRunner(
...
@@ -1525,7 +1529,7 @@ class GPUModelRunner(
local_scatter_indexes_tensor
=
None
local_scatter_indexes_tensor
=
None
gather_indexes_tensor
=
None
gather_indexes_tensor
=
None
if
envs
.
VLLM_MLA_CPLB
:
if
self
.
enable_lightly_cp
:
rank_tokens
=
0
rank_tokens
=
0
rank_pad_tokens
=
0
rank_pad_tokens
=
0
accu_q_start
=
0
accu_q_start
=
0
...
@@ -1736,7 +1740,7 @@ class GPUModelRunner(
...
@@ -1736,7 +1740,7 @@ class GPUModelRunner(
cp_common_metadata
=
cp_common_metadata
,
cp_common_metadata
=
cp_common_metadata
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_
mla
_cp
=
True
enable_
lightly
_cp
=
True
)
)
return
cm_base
return
cm_base
...
@@ -2040,8 +2044,8 @@ class GPUModelRunner(
...
@@ -2040,8 +2044,8 @@ class GPUModelRunner(
if
self
.
model_config
.
enable_return_routed_experts
:
if
self
.
model_config
.
enable_return_routed_experts
:
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
mla_cp_enable
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
mla
_cp_threshould
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens
>
self
.
lightly
_cp_threshould
if
not
mla_cp_enable
:
if
not
enable_lightly_cp
:
cm_base
=
CommonAttentionMetadata
(
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
...
@@ -2183,19 +2187,19 @@ class GPUModelRunner(
...
@@ -2183,19 +2187,19 @@ class GPUModelRunner(
cm
.
block_table_tensor
=
_get_block_table
(
kv_cache_gid
)
cm
.
block_table_tensor
=
_get_block_table
(
kv_cache_gid
)
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
cm
.
slot_mapping
=
slot_mappings
[
kv_cache_gid
]
if
cm
.
seq_indexes_list
is
not
None
:
if
enable_lightly_cp
and
cm
.
seq_indexes_list
is
not
None
:
cm
.
block_table_tensor
=
cm
.
block_table_tensor
[
cm
.
seq_indexes_list
]
cm
.
block_table_tensor
=
cm
.
block_table_tensor
[
cm
.
seq_indexes_list
]
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
if
mla_cp_enable
:
if
enable_lightly_cp
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
else
:
spec_decode_common_attn_metadata
=
cm
spec_decode_common_attn_metadata
=
cm
#spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
else
:
else
:
if
mla_cp_enable
:
if
enable_lightly_cp
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
else
:
spec_decode_common_attn_metadata
=
cm
spec_decode_common_attn_metadata
=
cm
...
@@ -2230,7 +2234,7 @@ class GPUModelRunner(
...
@@ -2230,7 +2234,7 @@ class GPUModelRunner(
_metadata
.
mm_prefix_range
=
req_doc_ranges
# type: ignore[attr-defined]
_metadata
.
mm_prefix_range
=
req_doc_ranges
# type: ignore[attr-defined]
if
(
if
(
(
not
envs
.
VLLM_MLA_CP
)
(
not
self
.
enable_lightly_cp
)
and
spec_decode_common_attn_metadata
is
not
None
and
spec_decode_common_attn_metadata
is
not
None
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
)
and
(
num_reqs
!=
num_reqs_padded
or
num_tokens
!=
num_tokens_padded
)
):
):
...
@@ -3110,7 +3114,7 @@ class GPUModelRunner(
...
@@ -3110,7 +3114,7 @@ class GPUModelRunner(
# Pad tokens to multiple of tensor_parallel_size when
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
# enabled collective fusion for SP
if
envs
.
VLLM_MLA_CP
and
num_scheduled_tokens
>
self
.
mla
_cp_threshould
:
if
self
.
enable_lightly_cp
and
num_scheduled_tokens
>
self
.
lightly
_cp_threshould
:
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
...
@@ -3808,7 +3812,7 @@ class GPUModelRunner(
...
@@ -3808,7 +3812,7 @@ class GPUModelRunner(
)
)
num_tokens_padded
=
batch_desc
.
num_tokens
num_tokens_padded
=
batch_desc
.
num_tokens
if
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla
_cp_threshould
:
if
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly
_cp_threshould
:
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_reqs_padded
=
(
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
...
@@ -3927,7 +3931,8 @@ class GPUModelRunner(
...
@@ -3927,7 +3931,8 @@ class GPUModelRunner(
skip_compiled
=
has_encoder_input
,
skip_compiled
=
has_encoder_input
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
),
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
self
.
maybe_get_kv_connector_output
(
...
@@ -4421,7 +4426,7 @@ class GPUModelRunner(
...
@@ -4421,7 +4426,7 @@ class GPUModelRunner(
)
)
#total_num_tokens = common_attn_metadata.num_actual_tokens
#total_num_tokens = common_attn_metadata.num_actual_tokens
if
(
if
(
envs
.
VLLM_MLA_CP
self
.
enable_lightly_cp
and
common_attn_metadata
.
cp_common_metadata
is
not
None
and
common_attn_metadata
.
cp_common_metadata
is
not
None
):
):
total_num_tokens
=
(
total_num_tokens
=
(
...
@@ -4952,9 +4957,6 @@ class GPUModelRunner(
...
@@ -4952,9 +4957,6 @@ class GPUModelRunner(
or
cudagraph_runtime_mode
.
valid_runtime_modes
()
or
cudagraph_runtime_mode
.
valid_runtime_modes
()
)
)
# if envs.VLLM_MLA_CP:
# num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
# different graphs and/or modes for mixed prefill-decode batches vs.
...
@@ -5116,9 +5118,6 @@ class GPUModelRunner(
...
@@ -5116,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
model_kwargs
=
self
.
_init_model_kwargs
()
model_kwargs
=
self
.
_init_model_kwargs
()
else
:
else
:
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens_padded
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
None
inputs_embeds
=
None
...
@@ -5159,7 +5158,8 @@ class GPUModelRunner(
...
@@ -5159,7 +5158,8 @@ class GPUModelRunner(
batch_descriptor
=
batch_desc
,
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices_padded
,
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
slot_mapping
=
slot_mappings
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
,
enable_lightly_cp
=
self
.
enable_lightly_cp
and
num_tokens_unpadded
>
self
.
lightly_cp_threshould
,
enable_lightly_cplb
=
self
.
enable_lightly_cplb
),
),
):
):
outputs
=
self
.
model
(
outputs
=
self
.
model
(
...
@@ -5232,9 +5232,15 @@ class GPUModelRunner(
...
@@ -5232,9 +5232,15 @@ class GPUModelRunner(
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices_device
=
torch
.
from_numpy
(
logit_indices
).
to
(
# logit_indices_device = torch.from_numpy(logit_indices).to(
self
.
device
,
non_blocking
=
True
# self.device, non_blocking=True
)
# )
logit_indices
=
logit_indices
.
tolist
()
logit_indices_device
=
async_tensor_h2d
(
logit_indices
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
return
hidden_states
,
hidden_states
[
logit_indices_device
]
return
hidden_states
,
hidden_states
[
logit_indices_device
]
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment