Unverified Commit 03b5f940 authored by dongbo910220's avatar dongbo910220 Committed by GitHub
Browse files

[V1][Spec Decode] Optimize Medusa proposer to avoid GPU-CPU sync (#29723)


Signed-off-by: default avatardongbo910220 <1275604947@qq.com>
parent 2e7054da
...@@ -38,16 +38,16 @@ class MedusaProposer: ...@@ -38,16 +38,16 @@ class MedusaProposer:
self, self,
target_hidden_states: torch.Tensor, target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> list[list[int]]: ) -> torch.Tensor:
# Generate blocks and compute logits # Generate blocks and compute logits
blocks = self.model(target_hidden_states) blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks) logits = self.model.compute_logits(blocks)
# Get draft tokens and transpose the result # Compute argmax for each Medusa head and stack into a single tensor
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU # Shape: [batch_size, num_heads]
# synchronization. draft_tokens = torch.stack([logit.argmax(dim=-1) for logit in logits], dim=1)
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
return [list(row) for row in zip(*draft_tokens)] return draft_tokens
def load_model(self, target_model: nn.Module) -> None: def load_model(self, target_model: nn.Module) -> None:
from vllm.compilation.backends import set_model_tag from vllm.compilation.backends import set_model_tag
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment