Unverified Commit 16786da7 authored by zhrrr's avatar zhrrr Committed by GitHub
Browse files

[Model Runner V2] support apply penalty for spec decode (#33251)


Signed-off-by: default avatarzhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
parent aaa2efbe
...@@ -40,6 +40,8 @@ class InputBatch: ...@@ -40,6 +40,8 @@ class InputBatch:
idx_mapping_np: np.ndarray idx_mapping_np: np.ndarray
# Identical to idx_mapping except for spec decoding. # Identical to idx_mapping except for spec decoding.
expanded_idx_mapping: torch.Tensor expanded_idx_mapping: torch.Tensor
# [total_num_logits] position within request for each logit
expanded_local_pos: torch.Tensor
# [num_reqs] # [num_reqs]
# batch_idx -> num_scheduled_tokens # batch_idx -> num_scheduled_tokens
...@@ -91,6 +93,7 @@ class InputBatch: ...@@ -91,6 +93,7 @@ class InputBatch:
idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
expanded_idx_mapping = idx_mapping expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(num_reqs, dtype=torch.int32, device=device)
num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32) num_scheduled_tokens = np.full(num_reqs, num_tokens // num_reqs, dtype=np.int32)
num_scheduled_tokens[-1] += num_tokens % num_reqs num_scheduled_tokens[-1] += num_tokens % num_reqs
assert int(num_scheduled_tokens.sum()) == num_tokens assert int(num_scheduled_tokens.sum()) == num_tokens
...@@ -126,6 +129,7 @@ class InputBatch: ...@@ -126,6 +129,7 @@ class InputBatch:
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_after_padding=num_tokens, num_tokens_after_padding=num_tokens,
...@@ -487,6 +491,7 @@ def post_update( ...@@ -487,6 +491,7 @@ def post_update(
def _expand_idx_mapping_kernel( def _expand_idx_mapping_kernel(
idx_mapping_ptr, idx_mapping_ptr,
expanded_idx_mapping_ptr, expanded_idx_mapping_ptr,
expanded_local_pos_ptr,
cu_num_logits_ptr, cu_num_logits_ptr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
...@@ -499,6 +504,7 @@ def _expand_idx_mapping_kernel( ...@@ -499,6 +504,7 @@ def _expand_idx_mapping_kernel(
mask = block < num_tokens mask = block < num_tokens
req_state_idx = tl.load(idx_mapping_ptr + req_idx) req_state_idx = tl.load(idx_mapping_ptr + req_idx)
tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask) tl.store(expanded_idx_mapping_ptr + start_idx + block, req_state_idx, mask=mask)
tl.store(expanded_local_pos_ptr + start_idx + block, block, mask=mask)
def expand_idx_mapping( def expand_idx_mapping(
...@@ -506,13 +512,17 @@ def expand_idx_mapping( ...@@ -506,13 +512,17 @@ def expand_idx_mapping(
total_num_logits: int, total_num_logits: int,
cu_num_logits: torch.Tensor, cu_num_logits: torch.Tensor,
max_expand_len: int, max_expand_len: int,
) -> torch.Tensor: ) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = idx_mapping.shape[0] num_reqs = idx_mapping.shape[0]
expanded_idx_mapping = idx_mapping.new_empty(total_num_logits) expanded_idx_mapping = idx_mapping.new_empty(total_num_logits)
expanded_local_pos = torch.empty(
total_num_logits, dtype=torch.int32, device=idx_mapping.device
)
_expand_idx_mapping_kernel[(num_reqs,)]( _expand_idx_mapping_kernel[(num_reqs,)](
idx_mapping, idx_mapping,
expanded_idx_mapping, expanded_idx_mapping,
expanded_local_pos,
cu_num_logits, cu_num_logits,
BLOCK_SIZE=triton.next_power_of_2(max_expand_len), BLOCK_SIZE=triton.next_power_of_2(max_expand_len),
) )
return expanded_idx_mapping return expanded_idx_mapping, expanded_local_pos
...@@ -152,6 +152,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -152,6 +152,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
vocab_size=self.vocab_size, vocab_size=self.vocab_size,
device=self.device, device=self.device,
logprobs_mode=self.model_config.logprobs_mode, logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
) )
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
...@@ -318,10 +319,22 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -318,10 +319,22 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
idx_mapping_np = np.arange(num_reqs, dtype=np.int32) idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device) pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip # NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible # top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution. # during actual execution.
self.sampler(logits, idx_mapping, idx_mapping_np, idx_mapping_np, pos) self.sampler(
logits,
idx_mapping,
idx_mapping_np,
idx_mapping_np,
pos,
dummy_input_ids,
expanded_local_pos,
)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
...@@ -511,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -511,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs + 1, device=self.device, dtype=torch.int32 num_reqs + 1, device=self.device, dtype=torch.int32
) )
expanded_idx_mapping = idx_mapping expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
)
else: else:
num_draft_tokens = np.array( num_draft_tokens = np.array(
[len(draft_tokens.get(req_id, ())) for req_id in req_ids], [len(draft_tokens.get(req_id, ())) for req_id in req_ids],
...@@ -526,7 +542,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -526,7 +542,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device) cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
max_expand_len = self.num_speculative_steps + 1 max_expand_len = self.num_speculative_steps + 1
expanded_idx_mapping = expand_idx_mapping( expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
idx_mapping, total_num_logits, cu_num_logits, max_expand_len idx_mapping, total_num_logits, cu_num_logits, max_expand_len
) )
...@@ -627,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -627,6 +643,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
idx_mapping=idx_mapping, idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np, idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping, expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
num_scheduled_tokens=num_scheduled_tokens, num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens, num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding, num_tokens_after_padding=num_tokens_after_padding,
...@@ -674,6 +691,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -674,6 +691,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]: ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
sample_pos = input_batch.positions[input_batch.logits_indices] sample_pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None: if grammar_output is not None:
# Apply grammar bitmask to the logits in-place. # Apply grammar bitmask to the logits in-place.
...@@ -691,6 +709,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -691,6 +709,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
input_batch.idx_mapping_np, input_batch.idx_mapping_np,
input_batch.cu_num_logits_np, input_batch.cu_num_logits_np,
sample_pos, sample_pos,
input_ids,
input_batch.expanded_local_pos,
) )
if input_batch.num_draft_tokens == 0: if input_batch.num_draft_tokens == 0:
...@@ -700,7 +720,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -700,7 +720,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) )
else: else:
# Rejection sampling for spec decoding. # Rejection sampling for spec decoding.
input_ids = input_batch.input_ids[input_batch.logits_indices]
sampled_tokens, num_sampled = rejection_sample( sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids, sampler_output.sampled_token_ids,
input_ids, input_ids,
......
...@@ -75,6 +75,9 @@ class PenaltiesState: ...@@ -75,6 +75,9 @@ class PenaltiesState:
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
num_speculative_tokens: int,
) -> None: ) -> None:
if not np.any(self.use_penalty[idx_mapping_np]): if not np.any(self.use_penalty[idx_mapping_np]):
# No request uses penalties. Skip the kernel launch. # No request uses penalties. Skip the kernel launch.
...@@ -83,11 +86,14 @@ class PenaltiesState: ...@@ -83,11 +86,14 @@ class PenaltiesState:
apply_penalties( apply_penalties(
logits, logits,
idx_mapping, idx_mapping,
input_ids,
expanded_local_pos,
self.repetition_penalty.gpu, self.repetition_penalty.gpu,
self.frequency_penalty.gpu, self.frequency_penalty.gpu,
self.presence_penalty.gpu, self.presence_penalty.gpu,
self.prompt_bin_mask, self.prompt_bin_mask,
self.output_bin_counts, self.output_bin_counts,
num_speculative_tokens,
) )
...@@ -96,6 +102,8 @@ def _penalties_kernel( ...@@ -96,6 +102,8 @@ def _penalties_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr, idx_mapping_ptr,
token_ids_ptr,
expanded_local_pos_ptr,
repetition_penalty_ptr, repetition_penalty_ptr,
frequency_penalty_ptr, frequency_penalty_ptr,
presence_penalty_ptr, presence_penalty_ptr,
...@@ -105,9 +113,10 @@ def _penalties_kernel( ...@@ -105,9 +113,10 @@ def _penalties_kernel(
output_bin_counts_stride, output_bin_counts_stride,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
MAX_SPEC_LEN: tl.constexpr,
): ):
batch_idx = tl.program_id(0) token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx) rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx) freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx) pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
...@@ -123,13 +132,27 @@ def _penalties_kernel( ...@@ -123,13 +132,27 @@ def _penalties_kernel(
block_idx = tl.program_id(1) block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32) logits = logits.to(tl.float32)
output_bin_counts = tl.load( base_output_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask, mask=mask,
other=0,
) )
# Compute cumulative draft_counts from previous positions in this request
pos = tl.load(expanded_local_pos_ptr + token_idx)
start_idx = token_idx - pos
draft_counts = tl.zeros((BLOCK_SIZE,), dtype=tl.int32)
for prev_pos in tl.static_range(MAX_SPEC_LEN):
if prev_pos < pos:
prev_token = tl.load(token_ids_ptr + start_idx + prev_pos + 1)
token_match = block == prev_token
draft_counts = draft_counts + token_match.to(tl.int32)
# Total counts = base output counts + cumulative draft counts
output_bin_counts = base_output_counts + draft_counts
output_bin_mask = output_bin_counts > 0 output_bin_mask = output_bin_counts > 0
# Apply repetition penalties. # Apply repetition penalties.
...@@ -138,6 +161,7 @@ def _penalties_kernel( ...@@ -138,6 +161,7 @@ def _penalties_kernel(
packed_mask = tl.load( packed_mask = tl.load(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block, prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32), mask=packed_block < tl.cdiv(vocab_size, 32),
other=0,
) )
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1 prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
prompt_bin_mask = prompt_bin_mask.to(tl.int1) prompt_bin_mask = prompt_bin_mask.to(tl.int1)
...@@ -153,25 +177,30 @@ def _penalties_kernel( ...@@ -153,25 +177,30 @@ def _penalties_kernel(
# Apply presence penalties. # Apply presence penalties.
logits -= pres_penalty * output_bin_mask logits -= pres_penalty * output_bin_mask
# Store back to logits. # Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_penalties( def apply_penalties(
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor, repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor, frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor, presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor, prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor, output_bin_counts: torch.Tensor,
num_speculative_tokens: int,
) -> None: ) -> None:
num_reqs, vocab_size = logits.shape num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192 BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_kernel[(num_reqs, num_blocks)]( _penalties_kernel[(num_tokens, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping, idx_mapping,
token_ids,
expanded_local_pos,
repetition_penalty, repetition_penalty,
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
...@@ -181,6 +210,7 @@ def apply_penalties( ...@@ -181,6 +210,7 @@ def apply_penalties(
output_bin_counts.stride(0), output_bin_counts.stride(0),
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
MAX_SPEC_LEN=num_speculative_tokens,
) )
......
...@@ -25,6 +25,7 @@ class Sampler: ...@@ -25,6 +25,7 @@ class Sampler:
vocab_size: int, vocab_size: int,
device: torch.device, device: torch.device,
logprobs_mode: LogprobsMode = "raw_logprobs", logprobs_mode: LogprobsMode = "raw_logprobs",
num_speculative_tokens: int = 1,
): ):
if logprobs_mode not in ("processed_logprobs", "raw_logprobs"): if logprobs_mode not in ("processed_logprobs", "raw_logprobs"):
raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}") raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")
...@@ -34,6 +35,7 @@ class Sampler: ...@@ -34,6 +35,7 @@ class Sampler:
self.sampling_states = SamplingStates(max_num_reqs, vocab_size) self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device) self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
self.logit_bias_state = LogitBiasState(max_num_reqs, device) self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.num_speculative_tokens = num_speculative_tokens
def add_request( def add_request(
self, req_idx: int, prompt_len: int, sampling_params: SamplingParams self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
...@@ -61,12 +63,19 @@ class Sampler: ...@@ -61,12 +63,19 @@ class Sampler:
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray, cu_num_logits_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> SamplerOutput: ) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature. # that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample( sampled, processed_logits = self.sample(
logits, idx_mapping, idx_mapping_np, pos logits,
idx_mapping,
idx_mapping_np,
pos,
input_ids,
expanded_local_pos,
) )
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np) max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
...@@ -98,6 +107,8 @@ class Sampler: ...@@ -98,6 +107,8 @@ class Sampler:
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor. # Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
...@@ -106,7 +117,14 @@ class Sampler: ...@@ -106,7 +117,14 @@ class Sampler:
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos) self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
# Apply penalties in place. # Apply penalties in place.
self.penalties_state.apply_penalties(logits, idx_mapping, idx_mapping_np) self.penalties_state.apply_penalties(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
self.num_speculative_tokens,
)
# Apply temperature in place. # Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu) apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment