Unverified Commit 2bcf71b9 authored by qizixi's avatar qizixi Committed by GitHub
Browse files

[Spec Decode] Reduce TP communication for speculative decoding draft token generation (#34049)


Signed-off-by: default avatarqizixi <qizixi@meta.com>
Co-authored-by: default avatarLu Fang <30275821+houseroad@users.noreply.github.com>
parent b7892a3b
...@@ -109,6 +109,11 @@ class SpeculativeConfig: ...@@ -109,6 +109,11 @@ class SpeculativeConfig:
speculative input batches can contain sequences of different lengths, speculative input batches can contain sequences of different lengths,
which may only be supported by certain attention backends. This currently which may only be supported by certain attention backends. This currently
only affects the EAGLE method of speculation.""" only affects the EAGLE method of speculation."""
use_local_argmax_reduction: bool = False
"""Use vocab-parallel local argmax instead of all-gathering full logits
for draft token generation. Reduces communication from O(vocab_size) to
O(2 * tp_size) per token. Only applies to greedy draft selection in
non-tree speculation."""
# Ngram proposer configuration # Ngram proposer configuration
prompt_lookup_max: int | None = Field(default=None, ge=1) prompt_lookup_max: int | None = Field(default=None, ge=1)
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import torch import torch
from vllm.distributed import ( from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_gather, tensor_model_parallel_gather,
) )
...@@ -102,6 +103,58 @@ class LogitsProcessor(CustomOp): ...@@ -102,6 +103,58 @@ class LogitsProcessor(CustomOp):
logits = logits[..., : self.org_vocab_size] logits = logits[..., : self.org_vocab_size]
return logits return logits
def get_top_tokens(
self,
lm_head: VocabParallelEmbedding,
hidden_states: torch.Tensor,
embedding_bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""Vocab-parallel argmax without all-gathering full logits.
Each TP rank computes local argmax, then only the (value, index) pairs
are gathered and reduced. Communication: O(batch * 2 * tp_size) vs
O(batch * vocab_size).
"""
if self.scale <= 0.0 and self.scale != 1.0:
raise ValueError(
"The local argmax reduction optimization is not supported for "
"non-positive logit scaling factors."
)
tp_size = get_tensor_model_parallel_world_size()
logits = lm_head.quant_method.apply(lm_head, hidden_states, bias=embedding_bias)
if self.soft_cap is not None:
logits = torch.tanh(logits / self.soft_cap) * self.soft_cap
if self.scale != 1.0:
logits = logits * self.scale
# Mask out padding entries beyond org_vocab_size on this shard.
num_pad = lm_head.shard_indices.num_org_vocab_padding
if num_pad > 0:
logits[..., -num_pad:] = -float("inf")
local_max_vals, local_max_indices = logits.max(dim=-1)
# Convert shard-local indices to global vocab indices.
vocab_start = lm_head.shard_indices.org_vocab_start_index
global_indices = local_max_indices + vocab_start
if tp_size == 1:
return global_indices
# All-gather (value, index) pairs, then reduce to global argmax.
# Use float32 to avoid bf16 precision loss on large vocab indices.
local_pair = torch.stack(
[local_max_vals.float(), global_indices.float()], dim=-1
)
# [batch, 2] -> [batch, 2 * tp_size]
gathered = tensor_model_parallel_all_gather(local_pair, dim=-1)
# [batch, tp_size, 2] where [:, :, 0]=values, [:, :, 1]=indices
gathered = gathered.view(hidden_states.shape[0], tp_size, 2)
max_rank_idx = gathered[:, :, 0].argmax(dim=-1, keepdim=True)
top_tokens = gathered[:, :, 1].gather(dim=-1, index=max_rank_idx)
return top_tokens.squeeze(-1).to(torch.int64)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"vocab_size={self.vocab_size}" s = f"vocab_size={self.vocab_size}"
s += f", org_vocab_size={self.org_vocab_size}" s += f", org_vocab_size={self.org_vocab_size}"
......
...@@ -208,6 +208,23 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM): ...@@ -208,6 +208,23 @@ class EagleLlama4ForCausalLM(Llama4ForCausalLM):
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states, inputs_embeds) return self.model(input_ids, positions, hidden_states, inputs_embeds)
def get_top_tokens(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
"""Vocab-parallel argmax without all-gathering full logits.
Falls back to full logits when draft_id_to_target_id remapping is
active, since the shared lm_head covers the full target vocab but
the draft model only predicts over a subset (draft_vocab_size).
"""
if (
hasattr(self, "draft_id_to_target_id")
and self.draft_id_to_target_id is not None
):
return self.compute_logits(hidden_states).argmax(dim=-1)
return self.logits_processor.get_top_tokens(self.lm_head, hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
def transform(inputs): def transform(inputs):
name, loaded_weight = inputs name, loaded_weight = inputs
......
...@@ -99,6 +99,9 @@ class SpecDecodeBaseProposer: ...@@ -99,6 +99,9 @@ class SpecDecodeBaseProposer:
self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None self.parallel_drafting_hidden_state_tensor: torch.Tensor | None = None
if self.parallel_drafting: if self.parallel_drafting:
self._init_parallel_drafting_params() self._init_parallel_drafting_params()
self.use_local_argmax_reduction: bool = (
self.speculative_config.use_local_argmax_reduction
)
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_batch_size = vllm_config.scheduler_config.max_num_seqs
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
...@@ -369,6 +372,12 @@ class SpecDecodeBaseProposer: ...@@ -369,6 +372,12 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def _greedy_sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""Greedy-sample draft tokens from hidden states."""
if self.use_local_argmax_reduction:
return self.model.get_top_tokens(hidden_states)
return self.model.compute_logits(hidden_states).argmax(dim=-1)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -491,11 +500,10 @@ class SpecDecodeBaseProposer: ...@@ -491,11 +500,10 @@ class SpecDecodeBaseProposer:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[token_indices_to_sample] sample_hidden_states = last_hidden_states[token_indices_to_sample]
logits = self.model.compute_logits(sample_hidden_states)
# Early exit if there is only one draft token to be generated. # Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1 or self.parallel_drafting: if self.num_speculative_tokens == 1 or self.parallel_drafting:
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = self._greedy_sample(sample_hidden_states)
return draft_token_ids.view(-1, self.num_speculative_tokens) return draft_token_ids.view(-1, self.num_speculative_tokens)
if self.uses_mrope: if self.uses_mrope:
...@@ -513,7 +521,8 @@ class SpecDecodeBaseProposer: ...@@ -513,7 +521,8 @@ class SpecDecodeBaseProposer:
hidden_states = hidden_states[token_indices_to_sample] hidden_states = hidden_states[token_indices_to_sample]
if isinstance(attn_metadata, TreeAttentionMetadata): if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention. # Draft using tree attention - requires full logits for top-k
logits = self.model.compute_logits(sample_hidden_states)
draft_token_ids_list = self.propose_tree( draft_token_ids_list = self.propose_tree(
batch_size=batch_size, batch_size=batch_size,
logits=logits, logits=logits,
...@@ -525,7 +534,7 @@ class SpecDecodeBaseProposer: ...@@ -525,7 +534,7 @@ class SpecDecodeBaseProposer:
# [batch_size, num_tree_tokens] # [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1) return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1) draft_token_ids = self._greedy_sample(sample_hidden_states)
if self.allowed_attn_types is not None and not isinstance( if self.allowed_attn_types is not None and not isinstance(
attn_metadata, self.allowed_attn_types attn_metadata, self.allowed_attn_types
...@@ -690,8 +699,7 @@ class SpecDecodeBaseProposer: ...@@ -690,8 +699,7 @@ class SpecDecodeBaseProposer:
last_hidden_states, hidden_states = ret_hidden_states last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size] hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size]) draft_token_ids = self._greedy_sample(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids) draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens] # [batch_size, num_speculative_tokens]
...@@ -1521,6 +1529,31 @@ class SpecDecodeBaseProposer: ...@@ -1521,6 +1529,31 @@ class SpecDecodeBaseProposer:
"Shared target model lm_head with MTP shared_head.head." "Shared target model lm_head with MTP shared_head.head."
) )
if self.use_local_argmax_reduction:
if not hasattr(self.model, "get_top_tokens"):
raise ValueError(
"use_local_argmax_reduction is enabled but draft model "
f"{self.model.__class__.__name__} does not implement "
"get_top_tokens()."
)
# Warn if draft model has vocab remapping, which forces fallback
# to the full-logits path (negating the optimization).
if (
hasattr(self.model, "draft_id_to_target_id")
and self.model.draft_id_to_target_id is not None
):
logger.warning(
"use_local_argmax_reduction is enabled but draft model "
"uses draft_id_to_target_id vocab remapping. The "
"optimization will be bypassed (falling back to full "
"logits gather + argmax)."
)
else:
logger.info(
"Using local argmax reduction for draft token generation "
"(communication: O(2*tp_size) vs O(vocab_size))."
)
@torch.inference_mode() @torch.inference_mode()
def dummy_run( def dummy_run(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment