Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0beafe40
Commit
0beafe40
authored
Apr 10, 2026
by
王敏
Browse files
[Feat]支持pcp+mtp
parent
09f318c1
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
224 additions
and
11 deletions
+224
-11
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+25
-1
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+17
-0
vllm/v1/spec_decode/eagle.py
vllm/v1/spec_decode/eagle.py
+158
-7
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+24
-3
No files found.
vllm/model_executor/models/deepseek_mtp.py
View file @
0beafe40
...
@@ -11,6 +11,7 @@ import torch
...
@@ -11,6 +11,7 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
from
vllm.forward_context
import
get_forward_context
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -36,6 +37,9 @@ from .deepseek_v2 import (
...
@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE
,
DeepseekV2MoE
,
get_spec_layer_idx_from_weight_name
,
get_spec_layer_idx_from_weight_name
,
)
)
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
.utils
import
maybe_prefix
from
.utils
import
maybe_prefix
from
.interfaces
import
SupportsPP
from
.interfaces
import
SupportsPP
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
...
@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
embed_input_ids
(
self
,
input_ids
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
self
.
embed_tokens
(
input_ids
)
return
self
.
embed_tokens
(
input_ids
)
...
@@ -191,7 +198,19 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -191,7 +198,19 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
current_step_idx
=
spec_step_idx
%
self
.
num_mtp_layers
return
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
enable_mla_cp
=
get_forward_context
().
enable_mla_cp
#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if
enable_mla_cp
:
inputs_embeds_per_rank
=
torch
.
chunk
(
inputs_embeds
,
chunks
=
self
.
tp_size
,
dim
=
0
)
inputs_embeds
=
inputs_embeds_per_rank
[
self
.
tp_rank
].
contiguous
()
previous_hidden_states_per_rank
=
torch
.
chunk
(
previous_hidden_states
,
chunks
=
self
.
tp_size
,
dim
=
0
)
previous_hidden_states
=
previous_hidden_states_per_rank
[
self
.
tp_rank
].
contiguous
()
if
positions
is
not
None
:
positions_per_rank
=
torch
.
chunk
(
positions
,
chunks
=
self
.
tp_size
,
dim
=
0
)
positions
=
positions_per_rank
[
self
.
tp_rank
].
contiguous
()
hidden_states
=
self
.
layers
[
str
(
self
.
mtp_start_layer_idx
+
current_step_idx
)](
input_ids
,
input_ids
,
positions
,
positions
,
previous_hidden_states
,
previous_hidden_states
,
...
@@ -199,6 +218,11 @@ class DeepSeekMultiTokenPredictor(nn.Module):
...
@@ -199,6 +218,11 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx
,
current_step_idx
,
)
)
if
enable_mla_cp
:
hidden_states
=
tensor_model_parallel_all_gather
(
hidden_states
.
contiguous
(),
dim
=
0
)
return
hidden_states
def
compute_logits
(
def
compute_logits
(
self
,
self
,
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
...
vllm/v1/attention/backend.py
View file @
0beafe40
...
@@ -293,9 +293,26 @@ class CpCommonAttentionMetadata:
...
@@ -293,9 +293,26 @@ class CpCommonAttentionMetadata:
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
_seq_lens_cpu
:
torch
.
Tensor
_seq_lens_cpu
:
torch
.
Tensor
num_actual_tokens
:
int
num_actual_tokens
:
int
num_kv_actual_tokens
:
int
max_query_len
:
int
max_query_len
:
int
max_seq_len
:
int
num_reqs
:
int
num_reqs
:
int
req_ids
:
list
[
str
]
req_ids
:
list
[
str
]
block_table_tensor
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
_num_computed_tokens_cpu
:
torch
.
Tensor
dcp_local_seq_lens
:
torch
.
Tensor
|
None
=
None
dcp_local_seq_lens_cpu
:
torch
.
Tensor
|
None
=
None
def
batch_size
(
self
)
->
int
:
return
self
.
seq_lens
.
shape
[
0
]
@
property
def
seq_lens_cpu
(
self
)
->
torch
.
Tensor
:
if
self
.
_seq_lens_cpu
is
None
:
self
.
_seq_lens_cpu
=
self
.
seq_lens
.
to
(
"cpu"
)
return
self
.
_seq_lens_cpu
@
dataclass
@
dataclass
...
...
vllm/v1/spec_decode/eagle.py
View file @
0beafe40
...
@@ -14,7 +14,7 @@ from vllm.config import (
...
@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig
,
VllmConfig
,
get_layers_from_vllm_config
,
get_layers_from_vllm_config
,
)
)
from
vllm.distributed.parallel_state
import
get_pp_group
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tensor_model_parallel_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
from
vllm.model_executor.layers.attention_layer_base
import
AttentionLayerBase
...
@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
...
@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from
vllm.v1.attention.backend
import
(
from
vllm.v1.attention.backend
import
(
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
CpCommonAttentionMetadata
,
)
)
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.registry
import
AttentionBackendEnum
from
vllm.v1.attention.backends.tree_attn
import
(
from
vllm.v1.attention.backends.tree_attn
import
(
...
@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
...
@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.utils
import
CpuGpuBuffer
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.dp_utils
import
coordinate_batch_across_dp
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
CachedRequestState
,
InputBatch
from
vllm.utils.math_utils
import
cdiv
,
round_up
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -76,7 +78,8 @@ class SpecDecodeBaseProposer:
...
@@ -76,7 +78,8 @@ class SpecDecodeBaseProposer:
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
self
.
num_speculative_tokens
=
self
.
speculative_config
.
num_speculative_tokens
# The drafter can get longer sequences than the target model.
# The drafter can get longer sequences than the target model.
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
max_batch_size
=
vllm_config
.
scheduler_config
.
max_num_seqs
if
not
envs
.
VLLM_MLA_CPLB
\
else
vllm_config
.
scheduler_config
.
max_num_seqs
*
2
self
.
max_num_tokens
=
(
self
.
max_num_tokens
=
(
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
vllm_config
.
scheduler_config
.
max_num_batched_tokens
+
max_batch_size
)
)
...
@@ -219,6 +222,25 @@ class SpecDecodeBaseProposer:
...
@@ -219,6 +222,25 @@ class SpecDecodeBaseProposer:
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
1
,
len
(
self
.
tree_choices
)
+
1
,
device
=
device
,
dtype
=
torch
.
int32
).
repeat
(
max_batch_size
,
1
)
).
repeat
(
max_batch_size
,
1
)
if
envs
.
VLLM_MLA_CP
:
self
.
scatter_indexes_tensor
=
None
self
.
gather_indexes_tensor
=
None
self
.
query_start_loc
=
CpuGpuBuffer
(
max_batch_size
+
1
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
,
)
self
.
seq_lens
=
CpuGpuBuffer
(
max_batch_size
,
dtype
=
torch
.
int32
,
pin_memory
=
is_pin_memory_available
(),
device
=
device
,
with_numpy
=
True
,
)
def
_get_positions
(
self
,
num_tokens
:
int
):
def
_get_positions
(
self
,
num_tokens
:
int
):
if
self
.
uses_mrope
:
if
self
.
uses_mrope
:
return
self
.
mrope_positions
[:,
:
num_tokens
]
return
self
.
mrope_positions
[:,
:
num_tokens
]
...
@@ -270,6 +292,10 @@ class SpecDecodeBaseProposer:
...
@@ -270,6 +292,10 @@ class SpecDecodeBaseProposer:
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
eagle_cudagraph_mode
)
self
.
cudagraph_dispatcher
.
initialize_cudagraph_keys
(
eagle_cudagraph_mode
)
def
_pad_for_mla_cp
(
self
,
num_scheduled_tokens
:
int
)
->
int
:
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
return
round_up
(
num_scheduled_tokens
,
tp_size
)
def
propose
(
def
propose
(
self
,
self
,
# [num_tokens]
# [num_tokens]
...
@@ -309,6 +335,31 @@ class SpecDecodeBaseProposer:
...
@@ -309,6 +335,31 @@ class SpecDecodeBaseProposer:
)
)
)
)
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
runner
.
mla_cp_threshould
if
enable_mla_cp
:
num_tokens_dp_padded
=
self
.
_pad_for_mla_cp
(
num_tokens_dp_padded
)
common_attn_metadata
=
self
.
_prepare_cp_metadata
(
num_reqs_padded
=
common_attn_metadata
.
num_reqs
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
max_seq_len
=
common_attn_metadata
.
seq_lens_cpu
.
max
().
item
(),
num_tokens
=
num_tokens
,
block_table_gid_0
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping_gid_0
=
common_attn_metadata
.
slot_mapping
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
,
num_computed_tokens_cpu
=
common_attn_metadata
.
_num_computed_tokens_cpu
,
)
self
.
scatter_indexes_tensor
=
common_attn_metadata
.
scatter_indexes_tensor
self
.
gather_indexes_tensor
=
common_attn_metadata
.
gather_indexes_tensor
assert
self
.
runner
is
not
None
assert
self
.
runner
is
not
None
if
self
.
attn_metadata_builder
is
None
:
if
self
.
attn_metadata_builder
is
None
:
...
@@ -339,10 +390,6 @@ class SpecDecodeBaseProposer:
...
@@ -339,10 +390,6 @@ class SpecDecodeBaseProposer:
assert
draft_indexer_metadata
is
not
None
assert
draft_indexer_metadata
is
not
None
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
per_layer_attn_metadata
[
layer_name
]
=
draft_indexer_metadata
num_tokens_dp_padded
,
num_tokens_across_dp
=
self
.
_pad_batch_across_dp
(
num_tokens_unpadded
=
num_tokens
,
num_tokens_padded
=
num_tokens
)
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
cudagraph_runtime_mode
,
batch_desc
=
self
.
cudagraph_dispatcher
.
dispatch
(
num_tokens_dp_padded
num_tokens_dp_padded
)
)
...
@@ -387,6 +434,9 @@ class SpecDecodeBaseProposer:
...
@@ -387,6 +434,9 @@ class SpecDecodeBaseProposer:
slot_mapping
=
self
.
_get_slot_mapping
(
slot_mapping
=
self
.
_get_slot_mapping
(
num_input_tokens
,
common_attn_metadata
.
slot_mapping
num_input_tokens
,
common_attn_metadata
.
slot_mapping
),
),
scatter_indexes_tensor
=
self
.
scatter_indexes_tensor
,
gather_indexes_tensor
=
self
.
gather_indexes_tensor
,
enable_mla_cp
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
runner
.
mla_cp_threshould
,
):
):
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
ret_hidden_states
=
self
.
model
(
**
model_kwargs
)
if
not
self
.
model_returns_tuple
():
if
not
self
.
model_returns_tuple
():
...
@@ -463,6 +513,9 @@ class SpecDecodeBaseProposer:
...
@@ -463,6 +513,9 @@ class SpecDecodeBaseProposer:
if
batch_size_across_dp
is
not
None
:
if
batch_size_across_dp
is
not
None
:
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
batch_size_across_dp
[
self
.
dp_rank
]
=
input_batch_size
if
enable_mla_cp
:
common_attn_metadata
=
common_attn_metadata
.
cp_common_metadata
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
num_actual_tokens
=
batch_size
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
max_query_len
=
1
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
common_attn_metadata
.
query_start_loc
=
self
.
arange
[:
batch_size
+
1
]
...
@@ -990,6 +1043,104 @@ class SpecDecodeBaseProposer:
...
@@ -990,6 +1043,104 @@ class SpecDecodeBaseProposer:
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
total_num_drafts
=
self
.
cu_drafts_per_level
[
level
+
1
]
return
draft_token_ids_list
return
draft_token_ids_list
def
_prepare_cp_metadata
(
self
,
num_reqs_padded
,
max_query_len
,
max_seq_len
,
num_tokens
,
block_table_gid_0
,
slot_mapping_gid_0
,
query_start_loc
,
query_start_loc_cpu
,
seq_lens
,
seq_lens_cpu
,
num_computed_tokens_cpu
,
):
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_rank
=
get_tensor_model_parallel_rank
()
cp_common_metadata
=
CpCommonAttentionMetadata
(
query_start_loc
=
query_start_loc
.
clone
(),
query_start_loc_cpu
=
query_start_loc_cpu
.
clone
(),
seq_lens
=
seq_lens
.
clone
(),
_seq_lens_cpu
=
seq_lens_cpu
.
clone
(),
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
num_reqs
=
num_reqs_padded
,
req_ids
=
self
.
runner
.
input_batch
.
req_ids
,
num_actual_tokens
=
num_tokens
,
num_kv_actual_tokens
=
num_tokens
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
)
q_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
kv_lens_cpu
=
seq_lens_cpu
total_q_len
=
num_tokens
total_kv_len
=
num_tokens
(
total_q_len
,
q_lens_cpu
,
seq_count
,
kv_lens_cpu
,
local_req_ids
,
scatter_indexes_tensor
,
gather_indexes_tensor
,
seq_indexes_list
,
)
=
self
.
runner
.
_distribute_tokens_to_cp_ranks
(
total_q_len
,
q_lens_cpu
,
kv_lens_cpu
,
tp_rank
,
tp_size
,
self
.
runner
.
input_batch
.
req_ids
,
)
num_reqs
=
seq_count
cu_num_tokens
=
np
.
cumsum
(
q_lens_cpu
)
self
.
query_start_loc
.
np
[
0
]
=
0
self
.
query_start_loc
.
np
[
1
:
num_reqs
+
1
]
=
cu_num_tokens
self
.
query_start_loc
.
np
[
num_reqs
+
1
:].
fill
(
cu_num_tokens
[
-
1
])
self
.
query_start_loc
.
copy_to_gpu
()
q_acc_lens
=
self
.
query_start_loc
.
gpu
[:
num_reqs
+
1
]
q_acc_lens_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs
+
1
]
max_q_len
=
max
(
q_acc_lens_cpu
)
self
.
seq_lens
.
np
[:
num_reqs
]
=
kv_lens_cpu
self
.
seq_lens
.
np
[
num_reqs
:].
fill
(
0
)
self
.
seq_lens
.
copy_to_gpu
()
kv_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs
]
kv_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs
]
max_kv_len
=
max
(
kv_lens_cpu
)
num_computed_tokens_cpu
=
kv_lens_cpu
-
q_acc_lens_cpu
[
1
:]
blk_table_tensor
=
block_table_gid_0
[
seq_indexes_list
]
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
q_acc_lens
,
query_start_loc_cpu
=
q_acc_lens_cpu
,
seq_lens
=
kv_lens
,
_seq_lens_cpu
=
kv_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
num_reqs
=
num_reqs
,
num_actual_tokens
=
total_q_len
,
max_query_len
=
max_q_len
,
max_seq_len
=
max_kv_len
,
block_table_tensor
=
blk_table_tensor
,
slot_mapping
=
slot_mapping_gid_0
,
causal
=
True
,
num_kv_actual_tokens
=
total_kv_len
,
seq_indexes_list
=
seq_indexes_list
,
cp_common_metadata
=
cp_common_metadata
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
)
return
cm_base
def
prepare_inputs
(
def
prepare_inputs
(
self
,
self
,
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
0beafe40
...
@@ -1644,9 +1644,11 @@ class GPUModelRunner(
...
@@ -1644,9 +1644,11 @@ class GPUModelRunner(
self
,
self
,
num_reqs_padded
,
num_reqs_padded
,
max_query_len
,
max_query_len
,
max_seq_len
,
num_tokens
,
num_tokens
,
block_table_gid_0
,
block_table_gid_0
,
slot_mapping_gid_0
,
slot_mapping_gid_0
,
num_computed_tokens_cpu
):
):
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_size
=
self
.
vllm_config
.
parallel_config
.
tensor_parallel_size
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
...
@@ -1657,9 +1659,14 @@ class GPUModelRunner(
...
@@ -1657,9 +1659,14 @@ class GPUModelRunner(
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
].
clone
(),
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
].
clone
(),
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
].
clone
(),
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
].
clone
(),
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
num_reqs
=
num_reqs_padded
,
num_reqs
=
num_reqs_padded
,
req_ids
=
self
.
input_batch
.
req_ids
,
req_ids
=
self
.
input_batch
.
req_ids
,
num_actual_tokens
=
num_tokens
,
num_actual_tokens
=
num_tokens
,
num_kv_actual_tokens
=
num_tokens
,
block_table_tensor
=
block_table_gid_0
,
slot_mapping
=
slot_mapping_gid_0
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
)
)
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
]
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
]
...
@@ -1725,6 +1732,7 @@ class GPUModelRunner(
...
@@ -1725,6 +1732,7 @@ class GPUModelRunner(
cp_common_metadata
=
cp_common_metadata
,
cp_common_metadata
=
cp_common_metadata
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
scatter_indexes_tensor
=
scatter_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
gather_indexes_tensor
=
gather_indexes_tensor
,
enable_mla_cp
=
True
)
)
return
cm_base
return
cm_base
...
@@ -2028,7 +2036,8 @@ class GPUModelRunner(
...
@@ -2028,7 +2036,8 @@ class GPUModelRunner(
if
self
.
model_config
.
enable_return_routed_experts
:
if
self
.
model_config
.
enable_return_routed_experts
:
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
self
.
slot_mapping
=
slot_mapping_gid_0
[:
num_tokens
].
cpu
().
numpy
()
if
not
envs
.
VLLM_MLA_CP
or
num_tokens
<=
self
.
mla_cp_threshould
:
mla_cp_enable
=
envs
.
VLLM_MLA_CP
and
num_tokens
>
self
.
mla_cp_threshould
if
not
mla_cp_enable
:
cm_base
=
CommonAttentionMetadata
(
cm_base
=
CommonAttentionMetadata
(
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
...
@@ -2050,9 +2059,13 @@ class GPUModelRunner(
...
@@ -2050,9 +2059,13 @@ class GPUModelRunner(
cm_base
=
self
.
_prepare_cp_metadata
(
cm_base
=
self
.
_prepare_cp_metadata
(
num_reqs_padded
,
num_reqs_padded
,
max_query_len
,
max_query_len
,
max_seq_len
,
num_tokens
,
num_tokens
,
block_table_gid_0
,
block_table_gid_0
,
slot_mapping_gid_0
,
slot_mapping_gid_0
,
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
],
)
)
scatter_indexes_tensor
=
cm_base
.
scatter_indexes_tensor
scatter_indexes_tensor
=
cm_base
.
scatter_indexes_tensor
gather_indexes_tensor
=
cm_base
.
gather_indexes_tensor
gather_indexes_tensor
=
cm_base
.
gather_indexes_tensor
...
@@ -2172,9 +2185,17 @@ class GPUModelRunner(
...
@@ -2172,9 +2185,17 @@ class GPUModelRunner(
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
and
hasattr
(
self
,
"drafter"
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
if
mla_cp_enable
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
spec_decode_common_attn_metadata
=
cm
spec_decode_common_attn_metadata
=
cm
#spec_decode_common_attn_metadata = cm
else
:
if
mla_cp_enable
:
spec_decode_common_attn_metadata
=
cm
.
cp_common_metadata
else
:
else
:
spec_decode_common_attn_metadata
=
cm
spec_decode_common_attn_metadata
=
cm
#spec_decode_common_attn_metadata = cm
for
attn_gid
in
range
(
len
(
self
.
attn_groups
[
kv_cache_gid
])):
for
attn_gid
in
range
(
len
(
self
.
attn_groups
[
kv_cache_gid
])):
if
ubatch_slices
is
not
None
:
if
ubatch_slices
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment