Commit 19f117d8 authored by 王敏
Browse files

优化pp+mtp代码

parent dad7d083
...@@ -266,7 +266,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP): ...@@ -266,7 +266,7 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def set_moe_parameters(self): def set_moe_parameters(self):
self.expert_weights = [] self.expert_weights = []
...@@ -341,16 +341,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP): ...@@ -341,16 +341,6 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys()) model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "embed_tokens" in name:
for local_name in params_dict.keys():
if "embed_tokens" in local_name:
param = params_dict[local_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
break
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
...@@ -358,12 +348,25 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP): ...@@ -358,12 +348,25 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
if "indexer" in name and not model_has_indexer: if "indexer" in name and not model_has_indexer:
logger.info(f"Skipping indexer weight (DSA disabled): {name}") logger.info(f"Skipping indexer weight (DSA disabled): {name}")
continue continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is None: if spec_layer is None:
# Load the embed_tokens weight from the target model when the MTP weights do not include embed_tokens.
if "embed_tokens" in name:
for local_name in params_dict.keys():
if "embed_tokens" in local_name:
param = params_dict[local_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
break
continue continue
is_fusion_moe_shared_experts_layer = ( is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name) rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
) )
name = self._rewrite_spec_layer_name(spec_layer, name) name = self._rewrite_spec_layer_name(spec_layer, name)
for param_name, weight_name, shard_id in stacked_params_mapping: for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below). # Skip non-stacked layers and experts (experts handled below).
......
...@@ -100,8 +100,6 @@ class Scheduler(SchedulerInterface): ...@@ -100,8 +100,6 @@ class Scheduler(SchedulerInterface):
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_running_reqs = self.scheduler_config.max_num_seqs * self.vllm_config.parallel_config.pipeline_parallel_size
self.max_num_per_batch = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.enable_kv_cache_events = ( self.enable_kv_cache_events = (
...@@ -358,10 +356,6 @@ class Scheduler(SchedulerInterface): ...@@ -358,10 +356,6 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
current_batch_size = len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
if current_batch_size == self.max_num_per_batch:
break
# do not schedule another step for the same request while it still has # do not schedule another step for the same request while it still has
# output placeholders for PP. # output placeholders for PP.
# TODO: support PP + async scheduling without this limit # TODO: support PP + async scheduling without this limit
...@@ -370,7 +364,7 @@ class Scheduler(SchedulerInterface): ...@@ -370,7 +364,7 @@ class Scheduler(SchedulerInterface):
len(scheduled_new_reqs) + len(scheduled_resumed_reqs) len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running): + len(scheduled_running_reqs) >= max_batch_running):
break break
if request.num_output_placeholders > 0: if request.num_output_placeholders > 0 and self.scheduler_config.async_scheduling:
req_index += 1 req_index += 1
continue continue
...@@ -559,9 +553,7 @@ class Scheduler(SchedulerInterface): ...@@ -559,9 +553,7 @@ class Scheduler(SchedulerInterface):
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs:
while self.waiting and token_budget > 0: while self.waiting and token_budget > 0:
#if len(self.running) == self.max_num_running_reqs: if len(self.running) == self.max_num_running_reqs:
current_batch_size = len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
if len(self.running) == self.max_num_running_reqs or current_batch_size == self.max_num_per_batch:
break break
if (self.use_pp and envs.VLLM_USE_PP_BALANCE and if (self.use_pp and envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs) len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
...@@ -679,7 +671,6 @@ class Scheduler(SchedulerInterface): ...@@ -679,7 +671,6 @@ class Scheduler(SchedulerInterface):
# We use `request.num_tokens` instead of # We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed # `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens. # requests, which have output tokens.
#num_new_tokens = request.num_tokens - num_computed_tokens
if self.is_mtp_kv_consumer: if self.is_mtp_kv_consumer:
num_new_tokens = (request.num_tokens_with_spec - num_new_tokens = (request.num_tokens_with_spec -
num_computed_tokens) num_computed_tokens)
...@@ -1086,7 +1077,11 @@ class Scheduler(SchedulerInterface): ...@@ -1086,7 +1077,11 @@ class Scheduler(SchedulerInterface):
# We use `request.num_tokens` instead of # We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed # `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens. # requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens if self.is_mtp_kv_consumer:
num_new_tokens = (request.num_tokens_with_spec -
num_computed_tokens)
else:
num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens: if 0 < threshold < num_new_tokens:
num_new_tokens = threshold num_new_tokens = threshold
...@@ -1190,6 +1185,20 @@ class Scheduler(SchedulerInterface): ...@@ -1190,6 +1185,20 @@ class Scheduler(SchedulerInterface):
self._update_connector_prefix_cache_stats(request) self._update_connector_prefix_cache_stats(request)
# Speculative decode related.
if (self.is_mtp_kv_consumer or not self.vllm_config.kv_transfer_config) and request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
else:
# Prefill request: spec tokens not applicable yet.
request.spec_token_ids = []
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
request.record_event( request.record_event(
...@@ -1242,7 +1251,7 @@ class Scheduler(SchedulerInterface): ...@@ -1242,7 +1251,7 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has # do not schedule another step for the same request while it still has
# output placeholders for PP. # output placeholders for PP.
# TODO: support PP + async scheduling without this limit # TODO: support PP + async scheduling without this limit
if self.use_pp and request.num_output_placeholders > 0: if self.use_pp and request.num_output_placeholders > 0 and self.scheduler_config.async_scheduling:
req_index += 1 req_index += 1
continue continue
...@@ -1300,7 +1309,7 @@ class Scheduler(SchedulerInterface): ...@@ -1300,7 +1309,7 @@ class Scheduler(SchedulerInterface):
request, num_new_tokens request, num_new_tokens
) )
if num_new_tokens == 0: if num_new_tokens <= 0:
# The request cannot be scheduled because one of the following # The request cannot be scheduled because one of the following
# reasons: # reasons:
# 1. No new tokens to schedule. This may happen when # 1. No new tokens to schedule. This may happen when
......
...@@ -4204,7 +4204,7 @@ class GPUModelRunner( ...@@ -4204,7 +4204,7 @@ class GPUModelRunner(
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
# Get draft token ids if available # Get draft token ids if available
output_spec_token_ids = None output_spec_token_ids = None
if self._draft_token_ids is not None: if not self.use_async_scheduling and self._draft_token_ids is not None:
# Use synchronous copy to avoid NPU async stream/event # Use synchronous copy to avoid NPU async stream/event
# synchronization issues. _get_draft_token_ids_cpu relies on # synchronization issues. _get_draft_token_ids_cpu relies on
# event.synchronize() which may not properly wait for the # event.synchronize() which may not properly wait for the
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment