"vscode:/vscode.git/clone" did not exist on "f4b76056ee5c3a3f917527da5be3786e1b8530c6"
Unverified Commit 7a5adad4 authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

[Kernel] Optimize sample_recovered_tokens_kernel (#34974)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent 59c62332
...@@ -11,7 +11,11 @@ from tests.v1.sample.utils import create_allowed_token_ids ...@@ -11,7 +11,11 @@ from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler from vllm.v1.sample.rejection_sampler import (
PLACEHOLDER_TOKEN_ID,
RejectionSampler,
sample_recovered_tokens,
)
from vllm.v1.sample.sampler import Sampler, SamplerOutput from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
...@@ -518,6 +522,70 @@ def estimate_rejection_sampling_pdf( ...@@ -518,6 +522,70 @@ def estimate_rejection_sampling_pdf(
return hist.hist return hist.hist
def native_sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
cu_num_draft_tokens: torch.Tensor, # [batch_size]
draft_token_ids: torch.Tensor, # [num_tokens]
draft_probs: torch.Tensor | None, # [num_tokens, vocab_size]
target_probs: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_()
states = {
i: generator.get_state()
for i, generator in sampling_metadata.generators.items()
}
for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
# In order to generate the same exponential later, reset the CUDA RNG
# state because RNG state advances after each call.
generator.set_state(states[i])
inv_q = q.reciprocal()
out = torch.empty_like(draft_token_ids)
for req_idx in range(batch_size):
start_idx = 0 if req_idx == 0 else int(cu_num_draft_tokens[req_idx - 1].item())
end_idx = int(cu_num_draft_tokens[req_idx].item())
num_tokens = end_idx - start_idx
for pos in range(max_spec_len):
if pos >= num_tokens:
continue
token_idx = start_idx + pos
if draft_probs is None:
# prob is target_probs[token_idx] except draft_token_id is zeroed
prob = target_probs[token_idx].clone()
draft_token_id = draft_token_ids[token_idx]
prob[draft_token_id] = 0.0
else:
prob = (target_probs[token_idx] - draft_probs[token_idx]).clamp_min_(
0.0
)
score = prob * inv_q[req_idx]
recovered_id = torch.argmax(score, dim=-1)
out[token_idx] = recovered_id
return out
def _test_masked_logits( def _test_masked_logits(
rejection_sampler, rejection_sampler,
batch_size: int, batch_size: int,
...@@ -778,3 +846,60 @@ def test_allowed_token_ids(rejection_sampler): ...@@ -778,3 +846,60 @@ def test_allowed_token_ids(rejection_sampler):
device=logits.device, device=logits.device,
) )
assert torch.equal(output.sampled_token_ids, expected) assert torch.equal(output.sampled_token_ids, expected)
@pytest.mark.parametrize("batch_size", [1, 100])
@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
@pytest.mark.parametrize("max_spec_len", [1, 3])
@pytest.mark.parametrize("no_draft_probs", [True, False])
def test_sample_recovered_tokens(
batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
):
num_tokens = batch_size * max_spec_len
# Create random draft probabilities.
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
# Create random target probabilities.
target_logits = torch.rand(
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
)
target_probs = F.softmax(target_logits, dim=-1)
# Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
generators = {
i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
}
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=generators
)
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.reshape(batch_size, max_spec_len).tolist(), target_logits
)
ref_recovered_token_ids = native_sample_recovered_tokens(
max_spec_len,
spec_decode_metadata.num_draft_tokens,
spec_decode_metadata.cu_num_draft_tokens,
draft_token_ids,
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
)
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
spec_decode_metadata.num_draft_tokens,
spec_decode_metadata.cu_num_draft_tokens,
draft_token_ids,
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
)
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
...@@ -623,16 +623,19 @@ def sample_recovered_tokens( ...@@ -623,16 +623,19 @@ def sample_recovered_tokens(
if num_draft_tokens[i] > 0: if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator) q[i].exponential_(generator=generator)
inv_q = q.reciprocal()
recovered_token_ids = torch.empty_like(draft_token_ids) recovered_token_ids = torch.empty_like(draft_token_ids)
BLOCK_SIZE = 8192
sample_recovered_tokens_kernel[(batch_size, max_spec_len)]( sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
recovered_token_ids, recovered_token_ids,
cu_num_draft_tokens, cu_num_draft_tokens,
draft_token_ids, draft_token_ids,
draft_probs, draft_probs,
target_probs, target_probs,
q, inv_q,
vocab_size, vocab_size,
triton.next_power_of_2(vocab_size), BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None, NO_DRAFT_PROBS=draft_probs is None,
) )
return recovered_token_ids return recovered_token_ids
...@@ -776,9 +779,9 @@ def sample_recovered_tokens_kernel( ...@@ -776,9 +779,9 @@ def sample_recovered_tokens_kernel(
draft_token_ids_ptr, # [num_tokens] draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size] target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size] inv_q_ptr, # [batch_size, vocab_size]
vocab_size, vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr, NO_DRAFT_PROBS: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
...@@ -791,33 +794,50 @@ def sample_recovered_tokens_kernel( ...@@ -791,33 +794,50 @@ def sample_recovered_tokens_kernel(
if pos >= num_draft_tokens: if pos >= num_draft_tokens:
return return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE) token_idx = start_idx + pos
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
max_val = float("-inf")
recovered_id = 0
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
vocab_mask = vocab_offset < vocab_size
if NO_DRAFT_PROBS: if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
prob = tl.load( prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)), mask=(vocab_mask & (vocab_offset != draft_token_id)),
other=0, other=0.0,
) )
else: else:
draft_prob = tl.load( draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, draft_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size, mask=vocab_mask,
other=0, other=0.0,
) )
target_prob = tl.load( target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset, target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size, mask=vocab_mask,
other=0, other=0.0,
) )
prob = tl.maximum(target_prob - draft_prob, 0) prob = tl.maximum(target_prob - draft_prob, 0.0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value. # `tl.argmax` will select the maximum value.
q = tl.load( inv_q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset, inv_q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size, mask=vocab_mask,
other=float("-inf"), other=0.0,
) )
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id) # Local tile reduction
score = prob * inv_q
local_max, local_id = tl.max(score, axis=0, return_indices=True)
if local_max > max_val:
max_val = local_max
recovered_id = v + local_id
tl.store(output_token_ids_ptr + token_idx, recovered_id)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment