Unverified commit 04244fd0, authored by Giancarlo Delfin and committed by GitHub
Browse files

[Model Runner V2] Spec decode rejection sampler greedy support (#37238)


Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
parent 9482b0b0
...@@ -821,9 +821,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -821,9 +821,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits, logits,
input_batch, input_batch,
# Draft logits are needed for probabilistic rejection sampling. # Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits[input_batch.idx_mapping] self.req_states.draft_logits,
if self.req_states.draft_logits is not None
else None,
) )
# Get the number of sampled and rejected tokens. # Get the number of sampled and rejected tokens.
......
...@@ -68,55 +68,158 @@ def strict_rejection_sample( ...@@ -68,55 +68,158 @@ def strict_rejection_sample(
# Per-(logit row, vocab block) preparation kernel. Each program handles one
# block of BLOCK_SIZE vocabulary entries for one logit row and does exactly
# one of the following, depending on the row's temperature and draft step:
#   * temperature == 0.0 (greedy): writes the block's maximum target logit
#     and the vocabulary index of that maximum to the local_target_max /
#     local_target_argmax buffers at [logit_idx, block_idx].
#   * otherwise, if draft_step_idx < num_speculative_steps: copies this
#     block of draft logits from [req_state_idx, draft_step_idx, :] into
#     out_draft_logits[logit_idx, :].
#   * otherwise: writes nothing for this row.
@triton.jit @triton.jit
def _probabilistic_rejection_sample_kernel( def _gather_draft_logits_and_target_argmax_kernel(
# Output: vocabulary index of each block's maximum target logit, written at
# [logit_idx, block_idx] (greedy rows only).
local_target_argmax_ptr,
local_target_argmax_stride,
# Output: each block's maximum target logit, written at
# [logit_idx, block_idx] (greedy rows only).
local_target_max_ptr,
local_target_max_stride,
# Output: gathered draft logits, one row per logit row.
# [num_logits, V]
out_draft_logits_ptr,
out_draft_logits_stride,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# Input: draft logits, addressed by (req_state_idx, draft_step_idx).
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr,
draft_logits_stride_0,
draft_logits_stride_1,
# Maps each logit row to the index used for temp_ptr and draft_logits_ptr.
# [num_logits]
expanded_idx_mapping_ptr,
# Gives each logit row the draft step index used for draft_logits_ptr.
# [num_logits]
expanded_local_pos_ptr,
# Temperatures, indexed by the value loaded from expanded_idx_mapping_ptr.
# [max_num_reqs]
temp_ptr,
vocab_size,
num_speculative_steps,
BLOCK_SIZE: tl.constexpr,
):
# Grid axis 0 selects the logit row.
logit_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
# Grid axis 1 selects the vocabulary block; lanes at or past vocab_size are
# masked off.
block_idx = tl.program_id(1)
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp == 0.0:
# Greedy sampling. Get the target logits argmax.
# Masked lanes load as -inf so they never win the max.
target_logits = tl.load(
target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
# `idx` is the position of the maximum within this block.
value, idx = tl.max(target_logits, axis=0, return_indices=True)
# Convert the in-block position to a vocabulary index.
token_id = block_idx * BLOCK_SIZE + idx
tl.store(
local_target_argmax_ptr
+ logit_idx * local_target_argmax_stride
+ block_idx,
token_id,
)
tl.store(
local_target_max_ptr + logit_idx * local_target_max_stride + block_idx,
value,
)
elif draft_step_idx < num_speculative_steps:
# Non-greedy row whose draft step index is in range: copy this block of
# draft logits into the row-aligned output buffer.
draft_logits = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ draft_step_idx * draft_logits_stride_1
+ block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
# The store uses the same mask, so the -inf fill values are not written.
tl.store(
out_draft_logits_ptr + logit_idx * out_draft_logits_stride + block_offsets,
draft_logits,
mask=mask,
)
# Accept/reject pass over the draft tokens; req_idx is program_id(0), so each
# program handles one request. The loop runs over the request's
# num_tokens - 1 draft positions and its per-step work is guarded by
# `if accepted:`.
#   * temp == 0.0: the draft token is accepted only if it equals the target
#     argmax, which is recovered by taking the argmax over the per-block
#     maxima and reading the token id stored for that block.
#   * otherwise: the draft token is accepted when
#     target_prob > u * draft_prob, with u drawn via tl.rand(seed, pos).
# The draft token is stored to sampled[req_idx, i], the count of accepted
# steps is stored to rejected_steps[req_idx], and the value of pos at offset
# start_idx + rejected_step is stored to rejected_pos[req_idx].
@triton.jit
def _probabilistic_rejection_kernel(
# [num_reqs, num_speculative_steps + 1] # [num_reqs, num_speculative_steps + 1]
sampled_ptr, sampled_ptr,
sampled_stride, sampled_stride,
# [num_reqs] # [num_reqs]
rejected_steps_ptr, rejected_steps_ptr,
# [num_reqs]
rejected_pos_ptr,
# [num_logits] # [num_logits]
draft_sampled_ptr, draft_sampled_ptr,
# [num_logits, V] # [num_logits, V]
target_probs_ptr, target_probs_ptr,
target_probs_stride, target_probs_stride,
# [num_reqs, num_speculative_steps, V] # [num_logits, V]
draft_probs_ptr, draft_probs_ptr,
draft_probs_stride_0, draft_probs_stride,
draft_probs_stride_1, # [num_logits, num_blocks]
local_target_argmax_ptr,
local_target_argmax_stride,
# [num_logits, num_blocks]
local_target_max_ptr,
local_target_max_stride,
# [num_reqs + 1] # [num_reqs + 1]
cu_num_logits_ptr, cu_num_logits_ptr,
# [num_logits] # [num_logits]
pos_ptr, pos_ptr,
# [num_reqs] # [num_reqs]
idx_mapping_ptr, idx_mapping_ptr,
# [num_reqs] # [max_num_reqs]
temp_ptr,
# [max_num_reqs]
seeds_ptr, seeds_ptr,
# NUM_BLOCKS bounds the valid per-block entries; PADDED_NUM_BLOCKS is the
# length passed to tl.arange when loading them.
NUM_BLOCKS: tl.constexpr,
PADDED_NUM_BLOCKS: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
# cu_num_logits gives this request's first row (start_idx) and, by
# difference with the next entry, its row count (num_tokens).
start_idx = tl.load(cu_num_logits_ptr + req_idx) start_idx = tl.load(cu_num_logits_ptr + req_idx)
num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
# idx_mapping supplies the index used to read the seed and temperature.
seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx)) req_state_idx = tl.load(idx_mapping_ptr + req_idx)
seed = tl.load(seeds_ptr + req_state_idx)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
# rejected_step accumulates `accepted`, so it counts accepted draft steps;
# it equals num_tokens - 1 when every draft token was accepted.
rejected_step = 0 rejected_step = 0
accepted = True accepted = True
for i in range(num_tokens - 1): for i in range(num_tokens - 1):
if accepted: if accepted:
# The draft token checked against row start_idx + i is read from
# draft_sampled at start_idx + i + 1.
draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1) logit_idx = start_idx + i
target_prob = tl.load( draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled if temp == 0.0:
) # Greedy sampling. Only accept the sampled draft token if
draft_prob = tl.load( # it exactly matches the target argmax.
draft_probs_ptr block_offsets = tl.arange(0, PADDED_NUM_BLOCKS)
+ req_idx * draft_probs_stride_0 block_mask = block_offsets < NUM_BLOCKS
+ i * draft_probs_stride_1 local_max = tl.load(
+ draft_sampled local_target_max_ptr
) + logit_idx * local_target_max_stride
pos = tl.load(pos_ptr + start_idx + i) + block_offsets,
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1))) mask=block_mask,
accepted &= target_prob > u * draft_prob other=float("-inf"),
)
# Pick the block holding the largest per-block maximum, then read the
# token id stored for that block.
max_block = tl.argmax(local_max, axis=0)
target_argmax = tl.load(
local_target_argmax_ptr
+ logit_idx * local_target_argmax_stride
+ max_block
)
accepted &= target_argmax == draft_sampled
else:
# Non-greedy: compare the target and draft probabilities of the draft
# token against a random draw keyed by (seed, pos).
target_prob = tl.load(
target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
)
draft_prob = tl.load(
draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
)
pos = tl.load(pos_ptr + logit_idx)
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
accepted &= target_prob > u * draft_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled) tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted rejected_step += accepted
tl.store(rejected_steps_ptr + req_idx, rejected_step) tl.store(rejected_steps_ptr + req_idx, rejected_step)
# Record the position value of the row at offset rejected_step.
pos_val = tl.load(pos_ptr + start_idx + rejected_step)
tl.store(rejected_pos_ptr + req_idx, pos_val)
@triton.jit @triton.jit
...@@ -124,63 +227,60 @@ def _compute_residual_logits_kernel( ...@@ -124,63 +227,60 @@ def _compute_residual_logits_kernel(
# [num_reqs, V] # [num_reqs, V]
residual_logits_ptr, residual_logits_ptr,
residual_logits_stride, residual_logits_stride,
# [num_reqs]
residual_pos_ptr,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_logits, V] # [num_logits, V]
target_probs_ptr, target_probs_ptr,
target_probs_stride, target_probs_stride,
# [num_reqs, num_speculative_steps, V] # [num_logits, V]
draft_probs_ptr, draft_probs_ptr,
draft_probs_stride_0, draft_probs_stride,
draft_probs_stride_1, # [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_reqs] # [num_reqs]
rejected_step_ptr, rejected_step_ptr,
# [num_reqs + 1] # [num_reqs + 1]
cu_num_logits_ptr, cu_num_logits_ptr,
# [num_logits] # [num_reqs]
pos_ptr, idx_mapping_ptr,
# [max_num_reqs]
temp_ptr,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
req_idx = tl.program_id(0) req_idx = tl.program_id(0)
block_idx = tl.program_id(1) block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
start_idx = tl.load(cu_num_logits_ptr + req_idx) start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1) end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
rejected_draft_step = tl.load(rejected_step_ptr + req_idx) rejected_logit_idx = start_idx + tl.load(rejected_step_ptr + req_idx)
rejected_logit_idx = start_idx + rejected_draft_step temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size mask = block_offsets < vocab_size
if rejected_logit_idx < end_idx - 1: if temp == 0.0 or (rejected_logit_idx == end_idx - 1):
# Greedy sampling / bonus token. In either case, use the
# target logits directly to reduce numerical error.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=float("-inf"),
)
else:
target_probs = tl.load( target_probs = tl.load(
target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets, target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
mask=mask, mask=mask,
other=0.0, other=0.0,
) )
draft_probs = tl.load( draft_probs = tl.load(
draft_probs_ptr draft_probs_ptr + rejected_logit_idx * draft_probs_stride + block_offsets,
+ req_idx * draft_probs_stride_0
+ rejected_draft_step * draft_probs_stride_1
+ block_offsets,
mask=mask, mask=mask,
other=0.0, other=0.0,
) )
residual_probs = tl.maximum(target_probs - draft_probs, 0.0) residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
residual_logits = tl.log(residual_probs) residual_logits = tl.log(residual_probs)
else:
# This is a bonus token. Directly return the target logits.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=0.0,
)
tl.store( tl.store(
residual_logits_ptr + req_idx * residual_logits_stride + block_offsets, residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
...@@ -188,18 +288,13 @@ def _compute_residual_logits_kernel( ...@@ -188,18 +288,13 @@ def _compute_residual_logits_kernel(
mask=mask, mask=mask,
) )
# First block computes the residual logit positions.
if block_idx == 0:
pos_val = tl.load(pos_ptr + rejected_logit_idx)
tl.store(residual_pos_ptr + req_idx, pos_val)
def probabilistic_rejection_sample( def probabilistic_rejection_sample(
# [num_draft_tokens + num_reqs, V] # [num_logits, V]
target_logits: torch.Tensor, target_logits: torch.Tensor,
# [num_reqs, num_speculative_steps, V] # [max_num_reqs, num_speculative_steps, V]
draft_logits: torch.Tensor, draft_logits: torch.Tensor,
# [num_draft_tokens + num_reqs] # [num_logits]
draft_sampled: torch.Tensor, draft_sampled: torch.Tensor,
# [num_reqs + 1] # [num_reqs + 1]
cu_num_logits: torch.Tensor, cu_num_logits: torch.Tensor,
...@@ -207,16 +302,53 @@ def probabilistic_rejection_sample( ...@@ -207,16 +302,53 @@ def probabilistic_rejection_sample(
pos: torch.Tensor, pos: torch.Tensor,
# [num_reqs] # [num_reqs]
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
# [num_logits]
expanded_idx_mapping: torch.Tensor,
# [num_logits]
expanded_local_pos: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor, temperature: torch.Tensor,
# [max_num_reqs]
seed: torch.Tensor, seed: torch.Tensor,
num_speculative_steps: int, num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1 num_reqs = cu_num_logits.shape[0] - 1
vocab_size = target_logits.shape[-1] num_logits, vocab_size = target_logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
# Gather draft logits and target argmax for greedy sampling.
gathered_draft_logits = target_logits.new_empty(target_logits.shape)
local_target_argmax = target_logits.new_empty(
num_logits, num_blocks, dtype=torch.int64
)
local_target_max = target_logits.new_empty(
num_logits, num_blocks, dtype=torch.float32
)
_gather_draft_logits_and_target_argmax_kernel[(num_logits, num_blocks)](
local_target_argmax,
local_target_argmax.stride(0),
local_target_max,
local_target_max.stride(0),
gathered_draft_logits,
gathered_draft_logits.stride(0),
target_logits,
target_logits.stride(0),
draft_logits,
draft_logits.stride(0),
draft_logits.stride(1),
expanded_idx_mapping,
expanded_local_pos,
temperature,
vocab_size,
num_speculative_steps,
BLOCK_SIZE=BLOCK_SIZE,
)
# Compute target and draft probs. # Compute target and draft probs.
target_probs = torch.softmax(target_logits, dim=-1) target_probs = torch.softmax(target_logits, dim=-1)
draft_probs = torch.softmax(draft_logits, dim=-1) draft_probs = torch.softmax(gathered_draft_logits, dim=-1)
# Rejection sample. # Rejection sample.
# [num_reqs, num_speculative_steps + 1] # [num_reqs, num_speculative_steps + 1]
...@@ -225,45 +357,49 @@ def probabilistic_rejection_sample( ...@@ -225,45 +357,49 @@ def probabilistic_rejection_sample(
) )
# [num_reqs] # [num_reqs]
rejected_steps = sampled.new_empty(num_reqs) rejected_steps = sampled.new_empty(num_reqs)
_probabilistic_rejection_sample_kernel[(num_reqs,)]( # [num_reqs]
rejected_pos = pos.new_empty(num_reqs)
_probabilistic_rejection_kernel[(num_reqs,)](
sampled, sampled,
sampled.stride(0), sampled.stride(0),
rejected_steps, rejected_steps,
rejected_pos,
draft_sampled, draft_sampled,
target_probs, target_probs,
target_probs.stride(0), target_probs.stride(0),
draft_probs, draft_probs,
draft_probs.stride(0), draft_probs.stride(0),
draft_probs.stride(1), local_target_argmax,
local_target_argmax.stride(0),
local_target_max,
local_target_max.stride(0),
cu_num_logits, cu_num_logits,
pos, pos,
idx_mapping, idx_mapping,
temperature,
seed, seed,
num_warps=1, num_warps=1,
NUM_BLOCKS=num_blocks,
PADDED_NUM_BLOCKS=triton.next_power_of_2(num_blocks),
) )
# Compute the logits and positions to resample the rejected/bonus # Compute the logits and positions to resample the rejected/bonus
# tokens from. # tokens from.
# [num_reqs, vocab_size] # [num_reqs, vocab_size]
residual_logits = target_logits.new_empty(num_reqs, vocab_size) residual_logits = target_logits.new_empty(num_reqs, vocab_size)
# [num_reqs]
residual_pos = pos.new_empty(num_reqs)
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_compute_residual_logits_kernel[(num_reqs, num_blocks)]( _compute_residual_logits_kernel[(num_reqs, num_blocks)](
residual_logits, residual_logits,
residual_logits.stride(0), residual_logits.stride(0),
residual_pos,
target_logits,
target_logits.stride(0),
target_probs, target_probs,
target_probs.stride(0), target_probs.stride(0),
draft_probs, draft_probs,
draft_probs.stride(0), draft_probs.stride(0),
draft_probs.stride(1), target_logits,
target_logits.stride(0),
rejected_steps, rejected_steps,
cu_num_logits, cu_num_logits,
pos, idx_mapping,
temperature,
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
) )
...@@ -274,7 +410,7 @@ def probabilistic_rejection_sample( ...@@ -274,7 +410,7 @@ def probabilistic_rejection_sample(
idx_mapping, idx_mapping,
temperature, temperature,
seed, seed,
residual_pos, rejected_pos,
apply_temperature=False, apply_temperature=False,
) )
sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1)) sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
...@@ -333,6 +469,8 @@ class RejectionSampler: ...@@ -333,6 +469,8 @@ class RejectionSampler:
input_batch.cu_num_logits, input_batch.cu_num_logits,
pos, pos,
input_batch.idx_mapping, input_batch.idx_mapping,
input_batch.expanded_idx_mapping,
input_batch.expanded_local_pos,
self.sampler.sampling_states.temperature.gpu, self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu, self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps, self.num_speculative_steps,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment