"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "7920e9b1c5e168fe6218d2d147bdb9acf6bc993d"
Unverified commit 8980001c, authored by caozuoba and committed by GitHub
Browse files

[perf] v1/spec_decode: skip softmax for all-greedy rejection sampling (#32852)


Signed-off-by: hdj <1293066020@qq.com>
parent 527bcd14
...@@ -136,8 +136,6 @@ class RejectionSampler(nn.Module): ...@@ -136,8 +136,6 @@ class RejectionSampler(nn.Module):
metadata.cu_num_draft_tokens, metadata.cu_num_draft_tokens,
sampling_metadata, sampling_metadata,
) )
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
output_token_ids = rejection_sample( output_token_ids = rejection_sample(
metadata.draft_token_ids, metadata.draft_token_ids,
...@@ -145,7 +143,7 @@ class RejectionSampler(nn.Module): ...@@ -145,7 +143,7 @@ class RejectionSampler(nn.Module):
metadata.max_spec_len, metadata.max_spec_len,
metadata.cu_num_draft_tokens, metadata.cu_num_draft_tokens,
draft_probs, draft_probs,
target_probs, target_logits,
bonus_token_ids, bonus_token_ids,
sampling_metadata, sampling_metadata,
) )
...@@ -353,7 +351,7 @@ def rejection_sample( ...@@ -353,7 +351,7 @@ def rejection_sample(
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
draft_probs: torch.Tensor | None, draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size] # [num_tokens, vocab_size]
target_probs: torch.Tensor, target_logits: torch.Tensor,
# [batch_size, 1] # [batch_size, 1]
bonus_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
...@@ -361,17 +359,16 @@ def rejection_sample( ...@@ -361,17 +359,16 @@ def rejection_sample(
assert draft_token_ids.ndim == 1 assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2 assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1 assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2 assert target_logits.ndim == 2
batch_size = len(num_draft_tokens) batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0] num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1] vocab_size = target_logits.shape[-1]
device = target_probs.device device = target_logits.device
assert draft_token_ids.is_contiguous() assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous() assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous() assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size) assert target_logits.shape == (num_tokens, vocab_size)
# Create output buffer. # Create output buffer.
output_token_ids = torch.full( output_token_ids = torch.full(
...@@ -387,7 +384,7 @@ def rejection_sample( ...@@ -387,7 +384,7 @@ def rejection_sample(
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random: if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests. # Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1) target_argmax = target_logits.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size,)]( rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids, output_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
...@@ -400,6 +397,10 @@ def rejection_sample( ...@@ -400,6 +397,10 @@ def rejection_sample(
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
return output_token_ids return output_token_ids
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()
# Generate uniform probabilities for rejection sampling. # Generate uniform probabilities for rejection sampling.
# [num_tokens] # [num_tokens]
uniform_probs = generate_uniform_probs( uniform_probs = generate_uniform_probs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment