Commit 19f117d8 authored by 王敏's avatar 王敏
Browse files

优化pp+mtp代码

parent dad7d083
......@@ -341,6 +341,17 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys())
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# 跳过加载"indexer"权重
if "indexer" in name and not model_has_indexer:
logger.info(f"Skipping indexer weight (DSA disabled): {name}")
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
# load embed_tokens weight from target model if mtp weights missing embed_tokens
if "embed_tokens" in name:
for local_name in params_dict.keys():
if "embed_tokens" in local_name:
......@@ -350,20 +361,12 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
)
weight_loader(param, loaded_weight)
break
if "rotary_emb.inv_freq" in name:
continue
# 跳过加载"indexer"权重
if "indexer" in name and not model_has_indexer:
logger.info(f"Skipping indexer weight (DSA disabled): {name}")
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None:
continue
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
......
......@@ -100,8 +100,6 @@ class Scheduler(SchedulerInterface):
# Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_running_reqs = self.scheduler_config.max_num_seqs * self.vllm_config.parallel_config.pipeline_parallel_size
self.max_num_per_batch = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len
self.enable_kv_cache_events = (
......@@ -358,10 +356,6 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
current_batch_size = len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
if current_batch_size == self.max_num_per_batch:
break
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
......@@ -370,7 +364,7 @@ class Scheduler(SchedulerInterface):
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running):
break
if request.num_output_placeholders > 0:
if request.num_output_placeholders > 0 and self.scheduler_config.async_scheduling:
req_index += 1
continue
......@@ -559,9 +553,7 @@ class Scheduler(SchedulerInterface):
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
#if len(self.running) == self.max_num_running_reqs:
current_batch_size = len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
if len(self.running) == self.max_num_running_reqs or current_batch_size == self.max_num_per_batch:
if len(self.running) == self.max_num_running_reqs:
break
if (self.use_pp and envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
......@@ -679,7 +671,6 @@ class Scheduler(SchedulerInterface):
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
#num_new_tokens = request.num_tokens - num_computed_tokens
if self.is_mtp_kv_consumer:
num_new_tokens = (request.num_tokens_with_spec -
num_computed_tokens)
......@@ -1086,6 +1077,10 @@ class Scheduler(SchedulerInterface):
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
if self.is_mtp_kv_consumer:
num_new_tokens = (request.num_tokens_with_spec -
num_computed_tokens)
else:
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens:
......@@ -1190,6 +1185,20 @@ class Scheduler(SchedulerInterface):
self._update_connector_prefix_cache_stats(request)
# Speculative decode related.
if (self.is_mtp_kv_consumer or not self.vllm_config.kv_transfer_config) and request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
else:
# Prefill request: spec tokens not applicable yet.
request.spec_token_ids = []
self.running.append(request)
if self.log_stats:
request.record_event(
......@@ -1242,7 +1251,7 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
if self.use_pp and request.num_output_placeholders > 0:
if self.use_pp and request.num_output_placeholders > 0 and self.scheduler_config.async_scheduling:
req_index += 1
continue
......@@ -1300,7 +1309,7 @@ class Scheduler(SchedulerInterface):
request, num_new_tokens
)
if num_new_tokens == 0:
if num_new_tokens <= 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when
......
......@@ -4204,7 +4204,7 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
# Get draft token ids if available
output_spec_token_ids = None
if self._draft_token_ids is not None:
if not self.use_async_scheduling and self._draft_token_ids is not None:
# Use synchronous copy to avoid NPU async stream/event
# synchronization issues. _get_draft_token_ids_cpu relies on
# event.synchronize() which may not properly wait for the
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment