Commit ca35113a authored by 王敏
Browse files

[feat]初步实现PP+MTP功能,精度还有问题

parent 0dc059af
...@@ -341,6 +341,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP): ...@@ -341,6 +341,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys()) model_has_indexer = any("indexer" in param_name for param_name in params_dict.keys())
for name, loaded_weight in weights: for name, loaded_weight in weights:
if "embed_tokens" in name:
for local_name in params_dict.keys():
if "embed_tokens" in local_name:
param = params_dict[local_name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
break
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
......
...@@ -100,6 +100,8 @@ class Scheduler(SchedulerInterface): ...@@ -100,6 +100,8 @@ class Scheduler(SchedulerInterface):
# Scheduling constraints. # Scheduling constraints.
self.max_num_running_reqs = self.scheduler_config.max_num_seqs self.max_num_running_reqs = self.scheduler_config.max_num_seqs
self.max_num_running_reqs = self.scheduler_config.max_num_seqs * self.vllm_config.parallel_config.pipeline_parallel_size
self.max_num_per_batch = self.scheduler_config.max_num_seqs
self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens self.max_num_scheduled_tokens = self.scheduler_config.max_num_batched_tokens
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
self.enable_kv_cache_events = ( self.enable_kv_cache_events = (
...@@ -234,6 +236,10 @@ class Scheduler(SchedulerInterface): ...@@ -234,6 +236,10 @@ class Scheduler(SchedulerInterface):
self.use_pp = self.parallel_config.pipeline_parallel_size > 1 self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER
self.is_mtp_kv_consumer = self.vllm_config.speculative_config is not None and \
self.vllm_config.kv_transfer_config is not None \
and self.vllm_config.kv_transfer_config.is_kv_consumer
def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool: def has_mamba_layers(kv_cache_config: KVCacheConfig) -> bool:
return any( return any(
isinstance(group_spec.kv_cache_spec, MambaSpec) isinstance(group_spec.kv_cache_spec, MambaSpec)
...@@ -352,17 +358,21 @@ class Scheduler(SchedulerInterface): ...@@ -352,17 +358,21 @@ class Scheduler(SchedulerInterface):
while req_index < len(self.running) and token_budget > 0: while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index] request = self.running[req_index]
current_batch_size = len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
if current_batch_size == self.max_num_per_batch:
break
# do not schedule another step for the same request while it still has # do not schedule another step for the same request while it still has
# output placeholders for PP. # output placeholders for PP.
# TODO: support PP + async scheduling without this limit # TODO: support PP + async scheduling without this limit
if self.use_pp: # if self.use_pp:
if (envs.VLLM_USE_PP_BALANCE and # if (envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs) # len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+ len(scheduled_running_reqs) >= max_batch_running): # + len(scheduled_running_reqs) >= max_batch_running):
break # break
if request.num_output_placeholders > 0: # if request.num_output_placeholders > 0:
req_index += 1 # req_index += 1
continue # continue
if ( if (
request.num_output_placeholders > 0 request.num_output_placeholders > 0
...@@ -418,7 +428,7 @@ class Scheduler(SchedulerInterface): ...@@ -418,7 +428,7 @@ class Scheduler(SchedulerInterface):
request, num_new_tokens request, num_new_tokens
) )
if num_new_tokens == 0: if num_new_tokens <= 0:
# The request cannot be scheduled because one of the following # The request cannot be scheduled because one of the following
# reasons: # reasons:
# 1. No new tokens to schedule. This may happen when # 1. No new tokens to schedule. This may happen when
...@@ -549,7 +559,9 @@ class Scheduler(SchedulerInterface): ...@@ -549,7 +559,9 @@ class Scheduler(SchedulerInterface):
# Next, schedule the WAITING requests. # Next, schedule the WAITING requests.
if not preempted_reqs: if not preempted_reqs:
while self.waiting and token_budget > 0: while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs: #if len(self.running) == self.max_num_running_reqs:
current_batch_size = len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(scheduled_running_reqs)
if len(self.running) == self.max_num_running_reqs or current_batch_size == self.max_num_per_batch:
break break
if (self.use_pp and envs.VLLM_USE_PP_BALANCE and if (self.use_pp and envs.VLLM_USE_PP_BALANCE and
len(scheduled_new_reqs) + len(scheduled_resumed_reqs) len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
...@@ -667,6 +679,11 @@ class Scheduler(SchedulerInterface): ...@@ -667,6 +679,11 @@ class Scheduler(SchedulerInterface):
# We use `request.num_tokens` instead of # We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed # `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens. # requests, which have output tokens.
#num_new_tokens = request.num_tokens - num_computed_tokens
if self.is_mtp_kv_consumer:
num_new_tokens = (request.num_tokens_with_spec -
num_computed_tokens)
else:
num_new_tokens = request.num_tokens - num_computed_tokens num_new_tokens = request.num_tokens - num_computed_tokens
threshold = self.scheduler_config.long_prefill_token_threshold threshold = self.scheduler_config.long_prefill_token_threshold
if 0 < threshold < num_new_tokens: if 0 < threshold < num_new_tokens:
...@@ -771,6 +788,20 @@ class Scheduler(SchedulerInterface): ...@@ -771,6 +788,20 @@ class Scheduler(SchedulerInterface):
self._update_connector_prefix_cache_stats(request) self._update_connector_prefix_cache_stats(request)
# Speculative decode related.
if (self.is_mtp_kv_consumer or not self.vllm_config.kv_transfer_config) and request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
else:
# Prefill request: spec tokens not applicable yet.
request.spec_token_ids = []
self.running.append(request) self.running.append(request)
if self.log_stats: if self.log_stats:
request.record_event( request.record_event(
...@@ -1617,7 +1648,11 @@ class Scheduler(SchedulerInterface): ...@@ -1617,7 +1648,11 @@ class Scheduler(SchedulerInterface):
for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)): for idx, req in enumerate(itertools.chain(running_reqs, resumed_reqs)):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
if self.use_pp: #if self.use_pp:
# NOTE: In PP+async scheduling, we consume token ids via a direct GPU
# broadcast path (`input_batch.prev_sampled_token_ids`), so we can
# omit this payload.
if self.use_pp and not self.scheduler_config.async_scheduling:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
...@@ -1842,6 +1877,7 @@ class Scheduler(SchedulerInterface): ...@@ -1842,6 +1877,7 @@ class Scheduler(SchedulerInterface):
model_runner_output: ModelRunnerOutput, model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]: ) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens num_scheduled_tokens = scheduler_output.num_scheduled_tokens
...@@ -1978,6 +2014,26 @@ class Scheduler(SchedulerInterface): ...@@ -1978,6 +2014,26 @@ class Scheduler(SchedulerInterface):
if num_nans_in_logits is not None and req_id in num_nans_in_logits: if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id] request.num_nans_in_logits = num_nans_in_logits[req_id]
# NOTE: Use `new_token_ids` (from this output) instead of
# `request.is_prefill_chunk` (from current schedule step) to
# decide whether this was a decode step. In batch_queue mode
# (PP>1), update_from_output processes output from step T-N,
# but is_prefill_chunk reflects step T's state — using it
# causes stale spec_token_ids to be set on prefill chunks.
if spec_token_ids:
if not new_token_ids:
# Non-final prefill chunk: no tokens generated,
# clear any stale spec_token_ids.
if request.spec_token_ids:
request.spec_token_ids = []
else:
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens(
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]
# Get prompt logprobs for this request. # Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if ( if (
......
...@@ -1002,7 +1002,7 @@ class EngineCoreProc(EngineCore): ...@@ -1002,7 +1002,7 @@ class EngineCoreProc(EngineCore):
for output in outputs.items() if outputs else (): for output in outputs.items() if outputs else ():
self.output_queue.put_nowait(output) self.output_queue.put_nowait(output)
# Post-step hook. # Post-step hook.
self.post_step(model_executed) #self.post_step(model_executed)
# If no model execution happened but there are waiting requests # If no model execution happened but there are waiting requests
# (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow # (e.g., WAITING_FOR_REMOTE_KVS), yield the GIL briefly to allow
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple, TypeAlias from typing import TYPE_CHECKING, NamedTuple, TypeAlias, Optional
import numpy as np import numpy as np
import torch import torch
...@@ -161,6 +161,9 @@ class ModelRunnerOutput: ...@@ -161,6 +161,9 @@ class ModelRunnerOutput:
# each request due to speculative/jump decoding. # each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]] = field(default_factory=list) sampled_token_ids: list[list[int]] = field(default_factory=list)
# num_reqs x num_spec_tokens
spec_token_ids: Optional[list[list[int]]] = None
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1]
# [num_reqs] # [num_reqs]
...@@ -244,8 +247,9 @@ def make_empty_encoder_model_runner_output( ...@@ -244,8 +247,9 @@ def make_empty_encoder_model_runner_output(
req_ids=req_ids, req_ids=req_ids,
req_id_to_index=req_id_to_index, req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids, sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
pooler_output=pooler_output, pooler_output=pooler_output,
) )
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}) EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}, spec_token_ids=None)
...@@ -931,6 +931,8 @@ class GPUModelRunner( ...@@ -931,6 +931,8 @@ class GPUModelRunner(
The SamplingMetadata is updated and copied to the GPU if there is a The SamplingMetadata is updated and copied to the GPU if there is a
new/resumed/paused/finished request in the batch. new/resumed/paused/finished request in the batch.
""" """
if scheduler_output.total_num_scheduled_tokens == 0:
return
# Remove finished requests from the cached states. # Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None) self.requests.pop(req_id, None)
...@@ -975,6 +977,14 @@ class GPUModelRunner( ...@@ -975,6 +977,14 @@ class GPUModelRunner(
self.input_batch.remove_request(req_id) self.input_batch.remove_request(req_id)
reqs_to_add: list[CachedRequestState] = [] reqs_to_add: list[CachedRequestState] = []
# Track re-added requests on non-last ranks that need token_ids_cpu
# fix-up after add_request. On non-last ranks, output_token_ids
# does NOT include accepted draft tokens, so add_request() places
# tokens at wrong positions. We save (new_token_ids, num_computed)
# here and fix up token_ids_cpu right after add_request.
fix_tokens_map: dict[str, tuple[list[int], int]] = {}
# Add new requests to the cached states. # Add new requests to the cached states.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
req_id = new_req_data.req_id req_id = new_req_data.req_id
...@@ -1082,9 +1092,13 @@ class GPUModelRunner( ...@@ -1082,9 +1092,13 @@ class GPUModelRunner(
req_state.num_computed_tokens = num_computed_tokens req_state.num_computed_tokens = num_computed_tokens
if not is_last_rank: if not is_last_rank:
# When using PP, the scheduler sends the sampled tokens back, if not req_data.new_token_ids:
# because there's no direct communication between the first- # Async scheduled PP: Sampled tokens propagated via GPU broadcast.
# stage worker and the last-stage worker. new_token_ids: list[int] = []
else:
# Non-async scheduling with PP: The scheduler sends
# sampled token ids back because there's no direct communication
# between the first-stage worker and the last-stage worker.
new_token_ids = req_data.new_token_ids[i] new_token_ids = req_data.new_token_ids[i]
# Add the sampled token(s) from the previous step (if any). # Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec tokens. # This doesn't include "unverified" tokens like spec tokens.
...@@ -1095,7 +1109,9 @@ class GPUModelRunner( ...@@ -1095,7 +1109,9 @@ class GPUModelRunner(
# Avoid slicing list in most common case. # Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1]) req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0: elif num_new_tokens > 0:
req_state.output_token_ids.extend(new_token_ids[-num_new_tokens:]) req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:]
)
elif num_output_tokens < len(req_state.output_token_ids): elif num_output_tokens < len(req_state.output_token_ids):
# Some output tokens were discarded due to a sync-KV-load # Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state. # failure. Align the cached state.
...@@ -1131,6 +1147,13 @@ class GPUModelRunner( ...@@ -1131,6 +1147,13 @@ class GPUModelRunner(
resumed_token_ids = req_data.all_token_ids[req_id] resumed_token_ids = req_data.all_token_ids[req_id]
req_state.output_token_ids = resumed_token_ids[-num_output_tokens:] req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
# On non-last ranks with PP + spec decode, output_token_ids
# doesn't include accepted draft tokens. Save the fix-up
# data so we can correct token_ids_cpu after add_request.
if not is_last_rank and new_token_ids:
fix_tokens_map[req_id] = (
list(new_token_ids), num_computed_tokens)
reqs_to_add.append(req_state) reqs_to_add.append(req_state)
continue continue
...@@ -1157,7 +1180,26 @@ class GPUModelRunner( ...@@ -1157,7 +1180,26 @@ class GPUModelRunner(
# The smaller empty indices are filled first. # The smaller empty indices are filled first.
for request in reqs_to_add: for request in reqs_to_add:
self.input_batch.add_request(request) self.input_batch.add_request(request)
self.input_batch.update_req_spec_token_ids(request, scheduled_spec_tokens) req_id = request.req_id
req_index = self.input_batch.req_id_to_index[req_id]
# Fix token_ids_cpu for re-added requests on non-last PP ranks.
# add_request() copies output_token_ids to token_ids_cpu, but on
# non-last ranks output_token_ids does NOT include accepted draft
# tokens, causing tokens to land at wrong positions. Overwrite
# the new tokens at the correct position (num_computed_tokens)
# and adjust num_tokens_no_spec before placing spec tokens.
fix_data = fix_tokens_map.get(req_id)
if fix_data is not None:
new_toks, n_computed = fix_data
start = n_computed
end = start + len(new_toks)
self.input_batch.token_ids_cpu[req_index, start:end] = new_toks
self.input_batch.num_tokens_no_spec[req_index] = end
# Place spec tokens at the (now-correct) num_tokens_no_spec offset.
self.input_batch.update_req_spec_token_ids(
request, scheduled_spec_tokens)
# Condense the batched states if there are gaps left by removed requests # Condense the batched states if there are gaps left by removed requests
self.input_batch.condense() self.input_batch.condense()
...@@ -4028,7 +4070,9 @@ class GPUModelRunner( ...@@ -4028,7 +4070,9 @@ class GPUModelRunner(
self.kv_connector_output = None self.kv_connector_output = None
if self.execute_model_state is None: if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used. # receive sampled token ids from the last PP rank.
if self.use_async_scheduling and get_pp_group().world_size > 1:
self._pp_receive_prev_sampled_token_ids_to_input_batch()
if not kv_connector_output: if not kv_connector_output:
return None # type: ignore[return-value] return None # type: ignore[return-value]
...@@ -4070,6 +4114,13 @@ class GPUModelRunner( ...@@ -4070,6 +4114,13 @@ class GPUModelRunner(
sampler_output.sampled_token_ids, scheduler_output sampler_output.sampled_token_ids, scheduler_output
) )
if self.use_async_scheduling:
pp = get_pp_group()
if pp.world_size > 1 and pp.is_last_rank:
self._pp_broadcast_prev_sampled_token_ids(
sampler_output.sampled_token_ids
)
self._draft_token_ids = None self._draft_token_ids = None
self._draft_token_req_ids = None self._draft_token_req_ids = None
self.input_batch.prev_sampled_token_ids = None self.input_batch.prev_sampled_token_ids = None
...@@ -4160,6 +4211,28 @@ class GPUModelRunner( ...@@ -4160,6 +4211,28 @@ class GPUModelRunner(
self.eplb_step() self.eplb_step()
with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"): with record_function_or_nullcontext("gpu_model_runner: ModelRunnerOutput"):
# Get draft token ids if available
output_spec_token_ids = None
if self._draft_token_ids is not None:
# Use synchronous copy to avoid NPU async stream/event
# synchronization issues. _get_draft_token_ids_cpu relies on
# event.synchronize() which may not properly wait for the
# async copy on NPU, resulting in stale data.
if torch.is_tensor(self._draft_token_ids):
num_reqs = self._draft_token_ids.shape[0]
draft_ids_list = self._draft_token_ids[:num_reqs].cpu().tolist()
draft_req_ids = self._draft_token_req_ids
else:
draft_ids_list = self._draft_token_ids
draft_req_ids = self.input_batch.req_ids
if draft_ids_list and draft_req_ids:
draft_by_req_id = dict(
zip(draft_req_ids, draft_ids_list))
output_spec_token_ids = [
draft_by_req_id.get(req_id, [])
for req_id in req_ids_output_copy
]
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
capturer = RoutedExpertsCapturer.get_instance() capturer = RoutedExpertsCapturer.get_instance()
if capturer is not None: if capturer is not None:
...@@ -4171,6 +4244,7 @@ class GPUModelRunner( ...@@ -4171,6 +4244,7 @@ class GPUModelRunner(
req_ids=req_ids_output_copy, req_ids=req_ids_output_copy,
req_id_to_index=req_id_to_index_output_copy, req_id_to_index=req_id_to_index_output_copy,
sampled_token_ids=valid_sampled_token_ids, sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=output_spec_token_ids,
logprobs=logprobs_lists, logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict, prompt_logprobs_dict=prompt_logprobs_dict,
kv_connector_output=kv_connector_output, kv_connector_output=kv_connector_output,
...@@ -4207,6 +4281,45 @@ class GPUModelRunner( ...@@ -4207,6 +4281,45 @@ class GPUModelRunner(
return async_output return async_output
def _pp_broadcast_prev_sampled_token_ids(
    self, sampled_token_ids: torch.Tensor
) -> None:
    """Broadcast the sampled token ids (GPU) from the last PP stage.

    Must be called on the last pipeline-parallel rank only. The tensor is
    handed to ``torch.distributed.broadcast`` on the PP device group with
    this rank as the source.

    Args:
        sampled_token_ids: Token ids sampled in this step; asserted to be
            two-dimensional with a trailing dimension of 1, i.e. shape
            ``[num_reqs, 1]``.
    """
    pp_group = get_pp_group()
    assert pp_group.is_last_rank

    # `prev_sampled_token_ids` is expected to have shape [num_reqs, 1];
    # the receiving side allocates its buffer with exactly that layout.
    assert sampled_token_ids.ndim == 2 and sampled_token_ids.size(-1) == 1, (
        "PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
    )

    torch.distributed.broadcast(
        sampled_token_ids,
        src=pp_group.rank,
        group=pp_group.device_group,
    )
def _pp_receive_prev_sampled_token_ids_to_input_batch(self) -> None:
    """Receive the sampled token ids broadcast from the last PP stage.

    Must be called on a non-last pipeline-parallel rank. The received
    tensor is stored as ``input_batch.prev_sampled_token_ids`` and a
    matching ``input_batch.prev_req_id_to_index`` mapping is rebuilt from
    the current batch rows, skipping rows flagged in
    ``discard_request_mask``. For every kept request that has a cached
    state, a ``-1`` placeholder is appended to its ``output_token_ids``.
    """
    pp_group = get_pp_group()
    assert not pp_group.is_last_rank

    batch = self.input_batch
    num_reqs = batch.num_reqs

    # Receive buffer: `prev_sampled_token_ids` is expected to have shape
    # [num_reqs, 1].
    received = torch.empty(
        (num_reqs, 1), dtype=torch.int32, device=self.device
    )
    torch.distributed.broadcast(
        received, src=pp_group.last_rank, group=pp_group.device_group
    )
    batch.prev_sampled_token_ids = received

    # Build `prev_req_id_to_index` here so `_prepare_input_ids` can map
    # req_id -> previous batch row.
    discarded_rows = set(
        np.nonzero(self.discard_request_mask.np[:num_reqs])[0]
    )
    row_of_req: dict[str, int] = {}
    for row, req_id in enumerate(batch.req_ids):
        if row not in discarded_rows:
            row_of_req[req_id] = row
            # PP+async scheduling: advance the per-request local cached
            # output length by appending a placeholder (-1) token id.
            cached_state = self.requests.get(req_id)
            if cached_state is not None:
                cached_state.output_token_ids.append(-1)
    batch.prev_req_id_to_index = row_of_req
def take_draft_token_ids(self) -> DraftTokenIds | None: def take_draft_token_ids(self) -> DraftTokenIds | None:
if not self.num_spec_tokens or not self._draft_token_req_ids: if not self.num_spec_tokens or not self._draft_token_req_ids:
return None return None
...@@ -5889,7 +6002,8 @@ class GPUModelRunner( ...@@ -5889,7 +6002,8 @@ class GPUModelRunner(
) )
# Initialize eagle's cudagraph dispatcher if using eagle spec decode. # Initialize eagle's cudagraph dispatcher if using eagle spec decode.
if self.speculative_config and self.speculative_config.use_eagle() and hasattr(self, "drafter"): if self.speculative_config and self.speculative_config.use_eagle() and hasattr(self, "drafter") \
and get_pp_group().is_last_rank:
assert isinstance(self.drafter, EagleProposer) assert isinstance(self.drafter, EagleProposer)
self.drafter.initialize_cudagraph_keys(cudagraph_mode) self.drafter.initialize_cudagraph_keys(cudagraph_mode)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment