Unverified Commit 6c01ffb8 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Decouple temperature from penalties (#32629)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 7b7cdce9
...@@ -5,6 +5,50 @@ import torch ...@@ -5,6 +5,50 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
@triton.jit
def _temperature_kernel(
logits_ptr,
logits_stride,
idx_mapping_ptr,
temperature_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits.
return
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32)
logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_temperature(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
@triton.jit @triton.jit
def _gumbel_sample_kernel( def _gumbel_sample_kernel(
local_argmax_ptr, local_argmax_ptr,
...@@ -48,7 +92,7 @@ def _gumbel_sample_kernel( ...@@ -48,7 +92,7 @@ def _gumbel_sample_kernel(
# Apply temperature. # Apply temperature.
if APPLY_TEMPERATURE: if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _penalties_and_temperature_kernel. # NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too. # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp logits = logits / temp
......
...@@ -66,16 +66,10 @@ class PenaltiesState: ...@@ -66,16 +66,10 @@ class PenaltiesState:
self.frequency_penalty.copy_to_uva() self.frequency_penalty.copy_to_uva()
self.presence_penalty.copy_to_uva() self.presence_penalty.copy_to_uva()
def apply_penalties_and_temperature( def apply_penalties(self, logits: torch.Tensor, idx_mapping: torch.Tensor) -> None:
self, apply_penalties(
logits: torch.Tensor,
idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
apply_penalties_and_temperature(
logits, logits,
idx_mapping, idx_mapping,
temperature,
self.repetition_penalty.gpu, self.repetition_penalty.gpu,
self.frequency_penalty.gpu, self.frequency_penalty.gpu,
self.presence_penalty.gpu, self.presence_penalty.gpu,
...@@ -85,14 +79,13 @@ class PenaltiesState: ...@@ -85,14 +79,13 @@ class PenaltiesState:
@triton.jit @triton.jit
def _penalties_and_temperature_kernel( def _penalties_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr, idx_mapping_ptr,
repetition_penalty_ptr, repetition_penalty_ptr,
frequency_penalty_ptr, frequency_penalty_ptr,
presence_penalty_ptr, presence_penalty_ptr,
temperature_ptr,
prompt_bin_mask_ptr, prompt_bin_mask_ptr,
prompt_bin_mask_stride, prompt_bin_mask_stride,
output_bin_counts_ptr, output_bin_counts_ptr,
...@@ -105,15 +98,12 @@ def _penalties_and_temperature_kernel( ...@@ -105,15 +98,12 @@ def _penalties_and_temperature_kernel(
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx) rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx) freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx) pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
temperature = tl.load(temperature_ptr + req_state_idx)
temperature = tl.where(temperature == 0.0, 1.0, temperature)
use_rep_penalty = rep_penalty != 1.0 use_rep_penalty = rep_penalty != 1.0
use_freq_penalty = freq_penalty != 0.0 use_freq_penalty = freq_penalty != 0.0
use_pres_penalty = pres_penalty != 0.0 use_pres_penalty = pres_penalty != 0.0
use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty use_penalty = use_rep_penalty or use_freq_penalty or use_pres_penalty
use_temperature = temperature != 1.0 if not use_penalty:
if not (use_penalty or use_temperature):
# Early return to avoid loading logits. # Early return to avoid loading logits.
return return
...@@ -123,7 +113,6 @@ def _penalties_and_temperature_kernel( ...@@ -123,7 +113,6 @@ def _penalties_and_temperature_kernel(
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32) logits = logits.to(tl.float32)
if use_penalty:
output_bin_counts = tl.load( output_bin_counts = tl.load(
output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block, output_bin_counts_ptr + req_state_idx * output_bin_counts_stride + block,
mask=mask, mask=mask,
...@@ -134,9 +123,7 @@ def _penalties_and_temperature_kernel( ...@@ -134,9 +123,7 @@ def _penalties_and_temperature_kernel(
if use_rep_penalty: if use_rep_penalty:
packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32) packed_block = block_idx * BLOCK_SIZE // 32 + tl.arange(0, BLOCK_SIZE // 32)
packed_mask = tl.load( packed_mask = tl.load(
prompt_bin_mask_ptr prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + packed_block,
+ req_state_idx * prompt_bin_mask_stride
+ packed_block,
mask=packed_block < tl.cdiv(vocab_size, 32), mask=packed_block < tl.cdiv(vocab_size, 32),
) )
prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1 prompt_bin_mask = (packed_mask[:, None] >> (tl.arange(0, 32)[None, :])) & 1
...@@ -152,18 +139,13 @@ def _penalties_and_temperature_kernel( ...@@ -152,18 +139,13 @@ def _penalties_and_temperature_kernel(
logits -= freq_penalty * output_bin_counts logits -= freq_penalty * output_bin_counts
# Apply presence penalties. # Apply presence penalties.
logits -= pres_penalty * output_bin_mask logits -= pres_penalty * output_bin_mask
# Apply temperature.
logits = logits / temperature
# Store back to logits. # Store back to logits.
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask)
def apply_penalties_and_temperature( def apply_penalties(
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
temperature: torch.Tensor,
repetition_penalty: torch.Tensor, repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor, frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor, presence_penalty: torch.Tensor,
...@@ -173,14 +155,13 @@ def apply_penalties_and_temperature( ...@@ -173,14 +155,13 @@ def apply_penalties_and_temperature(
num_reqs, vocab_size = logits.shape num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192 BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_penalties_and_temperature_kernel[(num_reqs, num_blocks)]( _penalties_kernel[(num_reqs, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping, idx_mapping,
repetition_penalty, repetition_penalty,
frequency_penalty, frequency_penalty,
presence_penalty, presence_penalty,
temperature,
prompt_bin_mask, prompt_bin_mask,
prompt_bin_mask.stride(0), prompt_bin_mask.stride(0),
output_bin_counts, output_bin_counts,
......
...@@ -9,7 +9,7 @@ from vllm.config.model import LogprobsMode ...@@ -9,7 +9,7 @@ from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.min_p import apply_min_p from vllm.v1.worker.gpu.sample.min_p import apply_min_p
...@@ -106,10 +106,11 @@ class Sampler: ...@@ -106,10 +106,11 @@ class Sampler:
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place. # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos) self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
# Apply penalties and temperature in place. # Apply penalties in place.
self.penalties_state.apply_penalties_and_temperature( self.penalties_state.apply_penalties(logits, idx_mapping)
logits, idx_mapping, self.sampling_states.temperature.gpu
) # Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
# Apply min_p in place if any request has a non-zero min_p. # Apply min_p in place if any request has a non-zero min_p.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np) do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment