Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f28328d2
Commit
f28328d2
authored
Apr 07, 2026
by
王敏
Browse files
[feat]pcp支持去掉torch compile后的精度验证
parent
6c6c9c0d
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
40 additions
and
23 deletions
+40
-23
vllm/config/vllm.py
vllm/config/vllm.py
+2
-2
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
...uted/kv_transfer/kv_connector/v1/du/du_swift_connector.py
+3
-3
vllm/forward_context.py
vllm/forward_context.py
+5
-0
vllm/model_executor/layers/mla.py
vllm/model_executor/layers/mla.py
+2
-1
vllm/model_executor/layers/sparse_attn_indexer.py
vllm/model_executor/layers/sparse_attn_indexer.py
+9
-4
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+6
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+13
-9
No files found.
vllm/config/vllm.py
View file @
f28328d2
...
@@ -1186,8 +1186,8 @@ class VllmConfig:
...
@@ -1186,8 +1186,8 @@ class VllmConfig:
if
(
if
(
self
.
parallel_config
.
tensor_parallel_size
>
1
self
.
parallel_config
.
tensor_parallel_size
>
1
and
(
self
.
compilation_config
.
pass_config
.
enable_sp
and
(
self
.
compilation_config
.
pass_config
.
enable_sp
)
or
envs
.
VLLM_MLA_CP
)
#
or envs.VLLM_MLA_CP)
):
):
cudagraph_capture_sizes
=
self
.
update_sizes_for_sequence_parallelism
(
cudagraph_capture_sizes
=
self
.
update_sizes_for_sequence_parallelism
(
cudagraph_capture_sizes
cudagraph_capture_sizes
...
...
vllm/distributed/kv_transfer/kv_connector/v1/du/du_swift_connector.py
View file @
f28328d2
...
@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -209,7 +209,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
request_id (str): request id for log
request_id (str): request id for log
"""
"""
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
dst_kv_cache_layer_shape
=
dst_kv_cache_layer
.
shape
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
()):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
all
(
isinstance
(
value
,
MLACommonMetadata
)
for
value
in
attn_metadata
.
values
())
or
dst_kv_cache_layer
.
ndim
==
3
:
num_pages
=
dst_kv_cache_layer_shape
[
0
]
num_pages
=
dst_kv_cache_layer_shape
[
0
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
page_size
=
dst_kv_cache_layer_shape
[
1
]
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
dst_kv_cache_layer
=
dst_kv_cache_layer
.
reshape
(
...
@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -379,7 +379,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
assert
self
.
du_swift_engine
is
not
None
assert
self
.
du_swift_engine
is
not
None
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
is_mla
=
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
kv_layer
.
ndim
==
3
def
extract_kv_from_layer
(
def
extract_kv_from_layer
(
layer
:
torch
.
Tensor
,
layer
:
torch
.
Tensor
,
...
@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
...
@@ -390,7 +390,7 @@ class DuSwiftConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
"""
if
isinstance
(
attn_metadata
,
MLACommonMetadata
):
if
isinstance
(
attn_metadata
,
MLACommonMetadata
)
or
layer
.
ndim
==
3
:
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
num_pages
,
page_size
=
layer
.
shape
[
0
],
layer
.
shape
[
1
]
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
return
layer
.
reshape
(
num_pages
*
page_size
,
-
1
)[
slot_mapping
,
...]
...]
...
...
vllm/forward_context.py
View file @
f28328d2
...
@@ -242,6 +242,7 @@ class ForwardContext:
...
@@ -242,6 +242,7 @@ class ForwardContext:
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
enable_mla_cp
:
bool
=
False
def
__post_init__
(
self
):
def
__post_init__
(
self
):
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
assert
self
.
cudagraph_runtime_mode
.
valid_runtime_modes
(),
(
...
@@ -278,6 +279,7 @@ def create_forward_context(
...
@@ -278,6 +279,7 @@ def create_forward_context(
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_mla_cp
:
bool
=
False
):
):
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
compilation_config
.
fast_moe_cold_start
:
if
vllm_config
.
speculative_config
is
None
:
if
vllm_config
.
speculative_config
is
None
:
...
@@ -305,6 +307,7 @@ def create_forward_context(
...
@@ -305,6 +307,7 @@ def create_forward_context(
skip_compiled
=
skip_compiled
,
skip_compiled
=
skip_compiled
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
enable_mla_cp
,
additional_kwargs
=
additional_kwargs
or
{},
additional_kwargs
=
additional_kwargs
or
{},
)
)
...
@@ -338,6 +341,7 @@ def set_forward_context(
...
@@ -338,6 +341,7 @@ def set_forward_context(
skip_compiled
:
bool
=
False
,
skip_compiled
:
bool
=
False
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
scatter_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
gather_indexes_tensor
:
torch
.
Tensor
|
None
=
None
,
enable_mla_cp
:
bool
=
False
,
):
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc.
can be attention metadata, etc.
...
@@ -400,6 +404,7 @@ def set_forward_context(
...
@@ -400,6 +404,7 @@ def set_forward_context(
skip_compiled
,
skip_compiled
,
scatter_indexes_tensor
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
gather_indexes_tensor
,
enable_mla_cp
)
)
try
:
try
:
...
...
vllm/model_executor/layers/mla.py
View file @
f28328d2
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
from
vllm.attention.layer
import
MLAAttention
from
vllm.attention.layer
import
MLAAttention
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.custom_op
import
PluggableLayer
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.distributed
import
(
from
vllm.distributed
import
(
...
@@ -187,7 +188,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
...
@@ -187,7 +188,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if
llama_4_scaling
is
not
None
:
if
llama_4_scaling
is
not
None
:
q
*=
llama_4_scaling
q
*=
llama_4_scaling
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#
envs.VLLM_MLA_CP # and not get_forward_context().draft_model
# if not use_fused_rms_rope_concat:
# if not use_fused_rms_rope_concat:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
if
not
envs
.
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT
:
...
...
vllm/model_executor/layers/sparse_attn_indexer.py
View file @
f28328d2
...
@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
...
@@ -20,6 +20,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import indexer_k_bf16_cache_tri
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
vllm.v1.worker.workspace
import
current_workspace_manager
from
lightop
import
op
,
gemmopt
from
lightop
import
op
,
gemmopt
from
vllm.attention.utils.kv_transfer_utils
import
(
maybe_transfer_kv_layer
,
)
if
current_platform
.
is_cuda_alike
():
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
elif
current_platform
.
is_xpu
():
...
@@ -28,9 +32,10 @@ elif current_platform.is_xpu():
...
@@ -28,9 +32,10 @@ elif current_platform.is_xpu():
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
maybe_transfer_kv_layer
def
sparse_attn_indexer
(
def
sparse_attn_indexer
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
layer_name
:
str
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
@@ -56,7 +61,7 @@ def sparse_attn_indexer(
...
@@ -56,7 +61,7 @@ def sparse_attn_indexer(
)
)
return
sparse_attn_indexer_fake
(
return
sparse_attn_indexer_fake
(
hidden_states
,
hidden_states
,
k_cache_prefix
,
layer_name
,
kv_cache
,
kv_cache
,
q_fp8
,
q_fp8
,
k
,
k
,
...
@@ -69,7 +74,7 @@ def sparse_attn_indexer(
...
@@ -69,7 +74,7 @@ def sparse_attn_indexer(
total_seq_lens
,
total_seq_lens
,
topk_indices_buffer
,
topk_indices_buffer
,
)
)
attn_metadata
=
attn_metadata
[
k_cache_prefix
]
attn_metadata
=
attn_metadata
[
layer_name
]
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
assert
isinstance
(
attn_metadata
,
DeepseekV32IndexerMetadata
)
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
slot_mapping
=
attn_metadata
.
slot_mapping
[:
attn_metadata
.
num_kv_actual_tokens
]
has_decode
=
attn_metadata
.
num_decodes
>
0
has_decode
=
attn_metadata
.
num_decodes
>
0
...
@@ -282,7 +287,7 @@ def sparse_attn_indexer(
...
@@ -282,7 +287,7 @@ def sparse_attn_indexer(
def
sparse_attn_indexer_fake
(
def
sparse_attn_indexer_fake
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
k_cache_prefix
:
str
,
layer_name
:
str
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
q_fp8
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
f28328d2
...
@@ -46,6 +46,7 @@ from vllm.distributed import (
...
@@ -46,6 +46,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
)
)
from
vllm.forward_context
import
get_forward_context
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
,
tensor_model_parallel_reduce_scatter
from
vllm.distributed.communication_op
import
tensor_model_parallel_all_reduce
,
tensor_model_parallel_reduce_scatter
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.activation
import
SiluAndMul
...
@@ -228,7 +229,7 @@ class DeepseekV2MLP(nn.Module):
...
@@ -228,7 +229,7 @@ class DeepseekV2MLP(nn.Module):
x
,
x
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
):
):
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#
envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if
enable_mla_cp
:
if
enable_mla_cp
:
x
=
tensor_model_parallel_all_gather
(
x
=
tensor_model_parallel_all_gather
(
x
.
contiguous
(),
0
x
.
contiguous
(),
0
...
@@ -249,6 +250,7 @@ class DeepseekV2MLP(nn.Module):
...
@@ -249,6 +250,7 @@ class DeepseekV2MLP(nn.Module):
if
enable_mla_cp
:
if
enable_mla_cp
:
x
=
tensor_model_parallel_reduce_scatter
(
x
.
contiguous
(),
dim
=
0
)
x
=
tensor_model_parallel_reduce_scatter
(
x
.
contiguous
(),
dim
=
0
)
return
x
elif
self
.
tp_size
>
1
:
elif
self
.
tp_size
>
1
:
x
=
tensor_model_parallel_all_reduce
(
x
)
x
=
tensor_model_parallel_all_reduce
(
x
)
return
x
return
x
...
@@ -430,7 +432,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -430,7 +432,7 @@ class DeepseekV2MoE(nn.Module):
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
*
,
iqis
:
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
|
None
=
None
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
enable_mla_cp
=
envs
.
VLLM_MLA_CP
#and not get_forward_context().draft_model
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#
envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if
enable_mla_cp
:
if
enable_mla_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
0
hidden_states
.
contiguous
(),
0
...
@@ -839,7 +841,7 @@ class Indexer(nn.Module):
...
@@ -839,7 +841,7 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
k
=
torch
.
cat
([
k_pe
.
squeeze
(
-
2
),
k_nope
],
dim
=-
1
)
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#
envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_mla_cp
:
if
enable_mla_cp
:
k
=
tensor_model_parallel_all_gather
(
k
=
tensor_model_parallel_all_gather
(
k
.
contiguous
(),
0
k
.
contiguous
(),
0
...
@@ -1376,7 +1378,7 @@ class DeepseekV2Model(nn.Module):
...
@@ -1376,7 +1378,7 @@ class DeepseekV2Model(nn.Module):
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
residual
=
intermediate_tensors
[
"residual"
]
residual
=
intermediate_tensors
[
"residual"
]
enable_mla_cp
=
envs
.
VLLM_MLA_CP
# and not get_forward_context().draft_model
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#
envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_mla_cp
:
if
enable_mla_cp
:
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states_per_rank
=
torch
.
chunk
(
hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
hidden_states
=
hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
hidden_states
=
hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
f28328d2
...
@@ -387,6 +387,7 @@ class GPUModelRunner(
...
@@ -387,6 +387,7 @@ class GPUModelRunner(
if
not
envs
.
VLLM_MLA_CPLB
if
not
envs
.
VLLM_MLA_CPLB
else
scheduler_config
.
max_num_seqs
*
2
else
scheduler_config
.
max_num_seqs
*
2
)
)
self
.
mla_cp_threshould
=
512
# Broadcast PP output for external_launcher (torchrun)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# to make sure we are synced across pp ranks
...
@@ -2027,7 +2028,7 @@ class GPUModelRunner(
...
@@ -2027,7 +2028,7 @@ class GPUModelRunner(
if
self
.
model_config
.
enable_return_routed_experts
:
if
self
.
model_config
.
enable_return_routed_experts
:
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
if
not
envs
.
VLLM_MLA_CP
or
num_tokens
<=
tp_size
*
tp_size
:
if
not
envs
.
VLLM_MLA_CP
or
num_tokens
<=
self
.
mla_cp_threshould
:
cm_base
=
CommonAttentionMetadata
(
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
...
@@ -3074,16 +3075,17 @@ class GPUModelRunner(
...
@@ -3074,16 +3075,17 @@ class GPUModelRunner(
def
_pad_for_mla_cp
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
def
_pad_for_mla_cp
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
num_scheduled_tokens
<=
tp_size
*
tp_size
:
# if num_scheduled_tokens <= tp_size * tp_size:
return
num_scheduled_tokens
*
tp_size
# return num_scheduled_tokens
else
:
# else:
# return round_up(num_scheduled_tokens, tp_size)
return
round_up
(
num_scheduled_tokens
,
tp_size
)
return
round_up
(
num_scheduled_tokens
,
tp_size
)
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
def
_pad_for_sequence_parallelism
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
# Pad tokens to multiple of tensor_parallel_size when
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
# enabled collective fusion for SP
if
envs
.
VLLM_MLA_CP
:
if
envs
.
VLLM_MLA_CP
and
num_scheduled_tokens
>
self
.
mla_cp_threshould
:
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
return
self
.
_pad_for_mla_cp
(
num_scheduled_tokens
)
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
if
self
.
compilation_config
.
pass_config
.
enable_sp
and
tp_size
>
1
:
...
@@ -3781,7 +3783,7 @@ class GPUModelRunner(
...
@@ -3781,7 +3783,7 @@ class GPUModelRunner(
)
)
num_tokens_padded
=
batch_desc
.
num_tokens
num_tokens_padded
=
batch_desc
.
num_tokens
if
envs
.
VLLM_MLA_CP
:
if
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
:
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_tokens_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_unpadded
)
num_reqs_padded
=
(
num_reqs_padded
=
(
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
batch_desc
.
num_reqs
if
batch_desc
.
num_reqs
is
not
None
else
num_reqs
...
@@ -3899,6 +3901,7 @@ class GPUModelRunner(
...
@@ -3899,6 +3901,7 @@ class GPUModelRunner(
skip_compiled
=
has_encoder_input
,
skip_compiled
=
has_encoder_input
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
,
),
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
record_function_or_nullcontext
(
"gpu_model_runner: forward"
),
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
self
.
maybe_get_kv_connector_output
(
scheduler_output
)
as
kv_connector_output
,
...
@@ -4918,8 +4921,8 @@ class GPUModelRunner(
...
@@ -4918,8 +4921,8 @@ class GPUModelRunner(
or
cudagraph_runtime_mode
.
valid_runtime_modes
()
or
cudagraph_runtime_mode
.
valid_runtime_modes
()
)
)
if
envs
.
VLLM_MLA_CP
:
#
if envs.VLLM_MLA_CP:
num_tokens
=
max
(
self
.
tp_size
,
num_tokens
)
#
num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# cudagraph_mode.separate_routine(). This means that we are using
...
@@ -5125,6 +5128,7 @@ class GPUModelRunner(
...
@@ -5125,6 +5128,7 @@ class GPUModelRunner(
batch_descriptor
=
batch_desc
,
batch_descriptor
=
batch_desc
,
ubatch_slices
=
ubatch_slices_padded
,
ubatch_slices
=
ubatch_slices_padded
,
slot_mapping
=
slot_mappings
,
slot_mapping
=
slot_mappings
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens_unpadded
>
self
.
mla_cp_threshould
),
),
):
):
outputs
=
self
.
model
(
outputs
=
self
.
model
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment