# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch

from vllm.triton_utils import tl, triton


@triton.jit
def _temperature_kernel(
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    temperature_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Divide one BLOCK_SIZE slice of one logits row by its temperature.

    Grid axis 0 selects the token (logits row); grid axis 1 selects the
    vocabulary block within that row. The temperature is looked up through
    ``expanded_idx_mapping_ptr[token_idx]``. The logits are updated in place.
    """
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
    temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
    if temperature == 0.0 or temperature == 1.0:
        # Early return to avoid loading logits.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size

    logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
    logits = logits.to(tl.float32)
    logits = logits / temperature
    tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)


def apply_temperature(
    logits: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    temperature: torch.Tensor,
) -> None:
    """Scale ``logits`` in place by the per-request temperature.

    Args:
        logits: 2-D tensor; unpacked as ``(num_tokens, vocab_size)``.
        expanded_idx_mapping: For each token row, the index into
            ``temperature`` to use.
        temperature: Temperature values. Rows whose temperature is exactly
            0.0 or 1.0 are left untouched by the kernel.
    """
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 8192
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    _temperature_kernel[(num_tokens, num_blocks)](
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )


@triton.jit
def tl_rand64(seed, offset, includes_zero: tl.constexpr):
    """Return float64 uniform random numbers in [0, 1) (or (0, 1)).

    Two 32-bit outputs of ``tl.randint4x`` are combined into 64 random bits,
    of which the top 53 are used so that the conversion to float64 is exact.
    When ``includes_zero`` is false, zero is replaced by the smallest positive
    normal float64 so that callers can safely take ``log(u)``.
    """
    lo, hi, _, _ = tl.randint4x(seed, offset)
    lo = lo.to(tl.uint32, bitcast=True).to(tl.uint64)
    hi = hi.to(tl.uint32, bitcast=True).to(tl.uint64)
    r = (hi << 32) | lo

    # Keep only the top 53 bits. Converting all 64 bits to float64 rounds to
    # nearest, so values close to 2**64 would round up and produce u == 1.0,
    # which turns -log(-log(u)) into +inf in the Gumbel transform.
    r = r >> 11

    # 1 / 2**53 (a power of two, so the scaling is exact).
    scale = 1.1102230246251565e-16
    u = r.to(tl.float64) * scale
    if not includes_zero:
        u = tl.maximum(u, 2.2250738585072014e-308)  # float64 tiny
    return u


@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    processed_logits_ptr,
    processed_logits_stride,
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    """Compute the per-block Gumbel-max candidate for one token.

    Each program handles one (token, vocabulary block) pair. It optionally
    applies temperature, optionally stores the resulting logits, adds Gumbel
    noise when the temperature is non-zero, and writes the block-local maximum
    value and the corresponding token id. The host reduces across blocks.
    """
    token_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + token_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    if temp != 0.0 and APPLY_TEMPERATURE:
        # Apply temperature.
        # NOTE(woosuk): Match the behavior of _temperature_kernel.
        # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
        logits = logits / temp

    # Store the temperature-applied logits.
    if processed_logits_ptr is not None:
        tl.store(
            processed_logits_ptr + req_state_idx * processed_logits_stride + block,
            logits,
            mask=mask,
        )

    logits = logits.to(tl.float64)
    if temp != 0.0:
        # Calculate the seed for gumbel noise.
        seed = tl.load(seeds_ptr + req_state_idx)
        pos = tl.load(pos_ptr + token_idx)
        gumbel_seed = tl.randint(seed, pos)

        # tl.rand returns fp32, so build a true fp64 uniform from 64 random
        # bits before applying the double-log transform.
        u = tl_rand64(gumbel_seed, block, includes_zero=False)
        gumbel_noise = -tl.log(-tl.log(u))

        # Apply gumbel noise. Out-of-range lanes are forced back to -inf so
        # they can never win the max below.
        logits = tl.where(mask, logits + gumbel_noise, float("-inf"))

    # With temp == 0.0 no noise is added, so this is a plain argmax.
    value, idx = tl.max(logits, axis=0, return_indices=True)
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)


def gumbel_sample(
    logits: torch.Tensor,  # [num_tokens, vocab_size]
    expanded_idx_mapping: torch.Tensor,  # [num_tokens]
    temperature: torch.Tensor,  # [max_num_reqs]
    seed: torch.Tensor,  # [max_num_reqs]
    pos: torch.Tensor,  # [num_tokens]
    apply_temperature: bool,
    processed_logits_out: torch.Tensor | None = None,  # [num_reqs, vocab_size]
) -> torch.Tensor:
    """Sample one token id per row of ``logits`` using the Gumbel-max trick.

    The kernel produces a (max value, token id) candidate per vocabulary
    block; the final token is the candidate of the block with the largest
    value. Rows whose temperature is 0.0 receive no noise.

    Args:
        logits: Input logits, one row per token.
        expanded_idx_mapping: Per-token index used to look up ``temperature``
            and ``seed`` (and the row of ``processed_logits_out``).
        temperature: Per-request temperature.
        seed: Per-request random seed.
        pos: Per-token value mixed with the seed to derive the noise seed.
        apply_temperature: Whether the kernel divides the logits by the
            temperature before sampling.
        processed_logits_out: Optional buffer that receives the logits after
            the optional temperature scaling.

    Returns:
        A 1-D int64 tensor with one sampled token id per row of ``logits``.
    """
    num_tokens, vocab_size = logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
    local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float64)
    _gumbel_sample_kernel[(num_tokens, num_blocks)](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        processed_logits_out,
        processed_logits_out.stride(0) if processed_logits_out is not None else 0,
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
        APPLY_TEMPERATURE=apply_temperature,
    )
    # NOTE(woosuk): Use int64 for later indexing.
    max_block_idx = local_max.argmax(dim=-1, keepdim=True)
    sampled = local_argmax.gather(dim=-1, index=max_block_idx).view(-1)
    return sampled