Unverified Commit 507df79a authored by Francesco Fusco's avatar Francesco Fusco Committed by GitHub
Browse files

[Hybrid] Simplify accepted token counting in spec decode for hybrid models (#38372)

parent 1696c864
...@@ -1428,26 +1428,11 @@ class GPUModelRunner( ...@@ -1428,26 +1428,11 @@ class GPUModelRunner(
# TODO: Remove .cpu() sync to enable fully async for hybrid model; # TODO: Remove .cpu() sync to enable fully async for hybrid model;
# Use num_computed_tokens.gpu instead of req.num_computed_tokens to # Use num_computed_tokens.gpu instead of req.num_computed_tokens to
# support aligned mamba cache mode. # support aligned mamba cache mode.
# Find the number of accepted tokens for each sequence. # Count the number of accepted tokens for each sequence.
# Valid tokens are contiguous from position 0, so counting non-(-1)
# tokens gives us the first -1 position (i.e., number of accepted).
num_reqs = output_token_ids.size(0) num_reqs = output_token_ids.size(0)
self.num_accepted_tokens.gpu[:num_reqs] = ( self.num_accepted_tokens.gpu[:num_reqs] = (output_token_ids != -1).sum(dim=1)
(
torch.cat(
[
output_token_ids,
torch.full(
(num_reqs, 1),
-1,
device=output_token_ids.device,
),
],
dim=1,
)
== -1
)
.int()
.argmax(-1)
)
if self.cache_config.mamba_cache_mode == "align": if self.cache_config.mamba_cache_mode == "align":
for i, num_tokens in enumerate( for i, num_tokens in enumerate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment