Unverified commit 0a7165fd, authored by Andy Lo and committed by GitHub
Browse files

[ModelRunnerV2] Rename sampler functions and variables for clarity (#35459)


Signed-off-by: Andy Lo <andy@mistral.ai>
parent 6521ccf2
...@@ -72,7 +72,7 @@ class BadWordsState: ...@@ -72,7 +72,7 @@ class BadWordsState:
def apply_bad_words( def apply_bad_words(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
input_ids: torch.Tensor, input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor, expanded_local_pos: torch.Tensor,
...@@ -84,7 +84,7 @@ class BadWordsState: ...@@ -84,7 +84,7 @@ class BadWordsState:
apply_bad_words( apply_bad_words(
logits, logits,
idx_mapping, expanded_idx_mapping,
self.bad_word_token_ids.gpu, self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu, self.bad_word_offsets.gpu,
self.num_bad_words.gpu, self.num_bad_words.gpu,
...@@ -114,17 +114,17 @@ def _bad_words_kernel( ...@@ -114,17 +114,17 @@ def _bad_words_kernel(
input_ids_ptr, input_ids_ptr,
expanded_local_pos_ptr, expanded_local_pos_ptr,
): ):
logit_idx = tl.program_id(0) token_idx = tl.program_id(0)
bw_idx = tl.program_id(1) bw_idx = tl.program_id(1)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
num_bad_words = tl.load(num_bad_words_ptr + req_state_idx) num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
if bw_idx >= num_bad_words: if bw_idx >= num_bad_words:
return return
pos = tl.load(expanded_local_pos_ptr + logit_idx) pos = tl.load(expanded_local_pos_ptr + token_idx)
cur_req_first_pos = logit_idx - pos cur_req_first_pos = token_idx - pos
prompt_len = tl.load(prompt_len_ptr + req_state_idx) prompt_len = tl.load(prompt_len_ptr + req_state_idx)
total_len = tl.load(total_len_ptr + req_state_idx) total_len = tl.load(total_len_ptr + req_state_idx)
...@@ -159,7 +159,7 @@ def _bad_words_kernel( ...@@ -159,7 +159,7 @@ def _bad_words_kernel(
match = match & (expected == actual) match = match & (expected == actual)
if match: if match:
tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf")) tl.store(logits_ptr + token_idx * logits_stride + last_token, -float("inf"))
def apply_bad_words( def apply_bad_words(
...@@ -175,8 +175,8 @@ def apply_bad_words( ...@@ -175,8 +175,8 @@ def apply_bad_words(
expanded_local_pos: torch.Tensor, expanded_local_pos: torch.Tensor,
max_num_bad_words: int, max_num_bad_words: int,
) -> None: ) -> None:
total_num_tokens = logits.shape[0] num_tokens = logits.shape[0]
_bad_words_kernel[(total_num_tokens, max_num_bad_words)]( _bad_words_kernel[(num_tokens, max_num_bad_words)](
logits, logits,
logits.stride(0), logits.stride(0),
expanded_idx_mapping, expanded_idx_mapping,
......
...@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton ...@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _temperature_kernel( def _temperature_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr, expanded_idx_mapping_ptr,
temperature_ptr, temperature_ptr,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32) temperature = tl.load(temperature_ptr + req_state_idx).to(tl.float32)
if temperature == 0.0 or temperature == 1.0: if temperature == 0.0 or temperature == 1.0:
# Early return to avoid loading logits. # Early return to avoid loading logits.
...@@ -25,24 +25,24 @@ def _temperature_kernel( ...@@ -25,24 +25,24 @@ def _temperature_kernel(
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size mask = block < vocab_size
logits = tl.load(logits_ptr + batch_idx * logits_stride + block, mask=mask) logits = tl.load(logits_ptr + token_idx * logits_stride + block, mask=mask)
logits = logits.to(tl.float32) logits = logits.to(tl.float32)
logits = logits / temperature logits = logits / temperature
tl.store(logits_ptr + batch_idx * logits_stride + block, logits, mask=mask) tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_temperature( def apply_temperature(
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
temperature: torch.Tensor, temperature: torch.Tensor,
) -> None: ) -> None:
num_reqs, vocab_size = logits.shape num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 8192 BLOCK_SIZE = 8192
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_temperature_kernel[(num_reqs, num_blocks)]( _temperature_kernel[(num_tokens, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping, expanded_idx_mapping,
temperature, temperature,
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
...@@ -57,7 +57,7 @@ def _gumbel_sample_kernel( ...@@ -57,7 +57,7 @@ def _gumbel_sample_kernel(
local_max_stride, local_max_stride,
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr, expanded_idx_mapping_ptr,
seeds_ptr, seeds_ptr,
pos_ptr, pos_ptr,
temp_ptr, temp_ptr,
...@@ -65,14 +65,14 @@ def _gumbel_sample_kernel( ...@@ -65,14 +65,14 @@ def _gumbel_sample_kernel(
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
APPLY_TEMPERATURE: tl.constexpr, APPLY_TEMPERATURE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block_idx = tl.program_id(1) block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size mask = block < vocab_size
logits = tl.load( logits = tl.load(
logits_ptr + batch_idx * logits_stride + block, logits_ptr + token_idx * logits_stride + block,
mask=mask, mask=mask,
other=float("-inf"), other=float("-inf"),
) )
...@@ -82,7 +82,7 @@ def _gumbel_sample_kernel( ...@@ -82,7 +82,7 @@ def _gumbel_sample_kernel(
if temp != 0.0: if temp != 0.0:
# Calculate the seed for gumbel noise. # Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx) seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx) pos = tl.load(pos_ptr + token_idx)
gumbel_seed = tl.randint(seed, pos) gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise in FP32. # Generate gumbel noise in FP32.
...@@ -101,41 +101,41 @@ def _gumbel_sample_kernel( ...@@ -101,41 +101,41 @@ def _gumbel_sample_kernel(
value, idx = tl.max(logits, axis=0, return_indices=True) value, idx = tl.max(logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx token_id = block_idx * BLOCK_SIZE + idx
tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id) tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value) tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
def gumbel_sample( def gumbel_sample(
logits: torch.Tensor, # [num_reqs, vocab_size] logits: torch.Tensor, # [num_tokens, vocab_size]
idx_mapping: torch.Tensor, # [max_num_reqs] expanded_idx_mapping: torch.Tensor, # [num_tokens]
temperature: torch.Tensor, # [max_num_reqs] temperature: torch.Tensor, # [max_num_reqs]
seed: torch.Tensor, # [max_num_reqs] seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_reqs] pos: torch.Tensor, # [num_tokens]
apply_temperature: bool, apply_temperature: bool,
) -> torch.Tensor: ) -> torch.Tensor:
num_reqs, vocab_size = logits.shape num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = torch.empty( local_argmax = torch.empty(
num_reqs, num_tokens,
num_blocks, num_blocks,
dtype=torch.int64, dtype=torch.int64,
device=logits.device, device=logits.device,
) )
local_max = torch.empty( local_max = torch.empty(
num_reqs, num_tokens,
num_blocks, num_blocks,
dtype=torch.float32, dtype=torch.float32,
device=logits.device, device=logits.device,
) )
_gumbel_sample_kernel[(num_reqs, num_blocks)]( _gumbel_sample_kernel[(num_tokens, num_blocks)](
local_argmax, local_argmax,
local_argmax.stride(0), local_argmax.stride(0),
local_max, local_max,
local_max.stride(0), local_max.stride(0),
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping, expanded_idx_mapping,
seed, seed,
pos, pos,
temperature, temperature,
......
...@@ -121,7 +121,7 @@ class LogitBiasState: ...@@ -121,7 +121,7 @@ class LogitBiasState:
def apply_logit_bias( def apply_logit_bias(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
) -> None: ) -> None:
...@@ -131,7 +131,7 @@ class LogitBiasState: ...@@ -131,7 +131,7 @@ class LogitBiasState:
apply_logit_bias( apply_logit_bias(
logits, logits,
idx_mapping, expanded_idx_mapping,
pos, pos,
self.num_allowed_token_ids.gpu, self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu, self.allowed_token_ids.gpu,
...@@ -149,7 +149,7 @@ def _bias_kernel( ...@@ -149,7 +149,7 @@ def _bias_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
vocab_size, vocab_size,
idx_mapping_ptr, expanded_idx_mapping_ptr,
# Allowed token IDs. # Allowed token IDs.
num_allowed_token_ids_ptr, num_allowed_token_ids_ptr,
allowed_token_ids_ptr, allowed_token_ids_ptr,
...@@ -169,8 +169,8 @@ def _bias_kernel( ...@@ -169,8 +169,8 @@ def _bias_kernel(
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
LOGITS_BLOCK_SIZE: tl.constexpr, LOGITS_BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block = tl.arange(0, BLOCK_SIZE) block = tl.arange(0, BLOCK_SIZE)
...@@ -186,21 +186,21 @@ def _bias_kernel( ...@@ -186,21 +186,21 @@ def _bias_kernel(
mask=mask, mask=mask,
) )
logits = tl.load( logits = tl.load(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, mask=mask logits_ptr + token_idx * logits_stride + allowed_token_ids, mask=mask
) )
# Set logits to -inf for all tokens. # Set logits to -inf for all tokens.
for i in range(0, vocab_size, LOGITS_BLOCK_SIZE): for i in range(0, vocab_size, LOGITS_BLOCK_SIZE):
offset = i + tl.arange(0, LOGITS_BLOCK_SIZE) offset = i + tl.arange(0, LOGITS_BLOCK_SIZE)
tl.store( tl.store(
logits_ptr + batch_idx * logits_stride + offset, logits_ptr + token_idx * logits_stride + offset,
-float("inf"), -float("inf"),
mask=offset < vocab_size, mask=offset < vocab_size,
) )
# Restore logits for allowed token IDs. # Restore logits for allowed token IDs.
tl.store( tl.store(
logits_ptr + batch_idx * logits_stride + allowed_token_ids, logits_ptr + token_idx * logits_stride + allowed_token_ids,
logits, logits,
mask=mask, mask=mask,
) )
...@@ -214,13 +214,13 @@ def _bias_kernel( ...@@ -214,13 +214,13 @@ def _bias_kernel(
mask=mask, mask=mask,
) )
bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask) bias = tl.load(bias_ptr + req_state_idx * bias_stride + block, mask=mask)
logits = tl.load(logits_ptr + batch_idx * logits_stride + token_ids, mask=mask) logits = tl.load(logits_ptr + token_idx * logits_stride + token_ids, mask=mask)
logits += bias logits += bias
tl.store(logits_ptr + batch_idx * logits_stride + token_ids, logits, mask=mask) tl.store(logits_ptr + token_idx * logits_stride + token_ids, logits, mask=mask)
# Apply min tokens. # Apply min tokens.
num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx) num_stop_token_ids = tl.load(num_stop_token_ids_ptr + req_state_idx)
pos = tl.load(pos_ptr + batch_idx) pos = tl.load(pos_ptr + token_idx)
min_len = tl.load(min_lens_ptr + req_state_idx) min_len = tl.load(min_lens_ptr + req_state_idx)
if num_stop_token_ids > 0 and pos < min_len: if num_stop_token_ids > 0 and pos < min_len:
mask = block < num_stop_token_ids mask = block < num_stop_token_ids
...@@ -229,7 +229,7 @@ def _bias_kernel( ...@@ -229,7 +229,7 @@ def _bias_kernel(
mask=mask, mask=mask,
) )
tl.store( tl.store(
logits_ptr + batch_idx * logits_stride + stop_token_ids, logits_ptr + token_idx * logits_stride + stop_token_ids,
-float("inf"), -float("inf"),
mask=mask, mask=mask,
) )
...@@ -237,7 +237,7 @@ def _bias_kernel( ...@@ -237,7 +237,7 @@ def _bias_kernel(
def apply_logit_bias( def apply_logit_bias(
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
pos: torch.Tensor, pos: torch.Tensor,
num_allowed_token_ids: torch.Tensor, num_allowed_token_ids: torch.Tensor,
allowed_token_ids: torch.Tensor, allowed_token_ids: torch.Tensor,
...@@ -248,7 +248,7 @@ def apply_logit_bias( ...@@ -248,7 +248,7 @@ def apply_logit_bias(
num_stop_token_ids: torch.Tensor, num_stop_token_ids: torch.Tensor,
stop_token_ids: torch.Tensor, stop_token_ids: torch.Tensor,
) -> None: ) -> None:
num_reqs, vocab_size = logits.shape num_tokens, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2( BLOCK_SIZE = triton.next_power_of_2(
max( max(
allowed_token_ids.shape[-1], allowed_token_ids.shape[-1],
...@@ -257,11 +257,11 @@ def apply_logit_bias( ...@@ -257,11 +257,11 @@ def apply_logit_bias(
) )
) )
LOGITS_BLOCK_SIZE = 8192 LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)]( _bias_kernel[(num_tokens,)](
logits, logits,
logits.stride(0), logits.stride(0),
vocab_size, vocab_size,
idx_mapping, expanded_idx_mapping,
num_allowed_token_ids, num_allowed_token_ids,
allowed_token_ids, allowed_token_ids,
allowed_token_ids.stride(0), allowed_token_ids.stride(0),
......
...@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton ...@@ -9,13 +9,13 @@ from vllm.triton_utils import tl, triton
def _min_p_kernel( def _min_p_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr, expanded_idx_mapping_ptr,
min_p_ptr, min_p_ptr,
vocab_size, vocab_size,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
req_idx = tl.program_id(0) token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32) min_p = tl.load(min_p_ptr + req_state_idx).to(tl.float32)
if min_p == 0.0: if min_p == 0.0:
return return
...@@ -25,7 +25,9 @@ def _min_p_kernel( ...@@ -25,7 +25,9 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size mask = block < vocab_size
logits = tl.load( logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf") logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
) )
max_val = tl.max(tl.maximum(logits, max_val)) max_val = tl.max(tl.maximum(logits, max_val))
max_val = max_val.to(tl.float32) # type: ignore max_val = max_val.to(tl.float32) # type: ignore
...@@ -35,21 +37,23 @@ def _min_p_kernel( ...@@ -35,21 +37,23 @@ def _min_p_kernel(
block = i + tl.arange(0, BLOCK_SIZE) block = i + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size mask = block < vocab_size
logits = tl.load( logits = tl.load(
logits_ptr + req_idx * logits_stride + block, mask=mask, other=float("-inf") logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
) )
logits = tl.where(logits < threshold, float("-inf"), logits) logits = tl.where(logits < threshold, float("-inf"), logits)
tl.store(logits_ptr + req_idx * logits_stride + block, logits, mask=mask) tl.store(logits_ptr + token_idx * logits_stride + block, logits, mask=mask)
def apply_min_p( def apply_min_p(
logits: torch.Tensor, idx_mapping: torch.Tensor, min_p: torch.Tensor logits: torch.Tensor, expanded_idx_mapping: torch.Tensor, min_p: torch.Tensor
) -> None: ) -> None:
num_reqs, vocab_size = logits.shape num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
_min_p_kernel[(num_reqs,)]( _min_p_kernel[(num_tokens,)](
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping, expanded_idx_mapping,
min_p, min_p,
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
......
...@@ -82,7 +82,7 @@ class PenaltiesState: ...@@ -82,7 +82,7 @@ class PenaltiesState:
def apply_penalties( def apply_penalties(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
input_ids: torch.Tensor, input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor, expanded_local_pos: torch.Tensor,
...@@ -94,7 +94,7 @@ class PenaltiesState: ...@@ -94,7 +94,7 @@ class PenaltiesState:
apply_penalties( apply_penalties(
logits, logits,
idx_mapping, expanded_idx_mapping,
input_ids, input_ids,
expanded_local_pos, expanded_local_pos,
self.repetition_penalty.gpu, self.repetition_penalty.gpu,
...@@ -110,7 +110,7 @@ class PenaltiesState: ...@@ -110,7 +110,7 @@ class PenaltiesState:
def _penalties_kernel( def _penalties_kernel(
logits_ptr, logits_ptr,
logits_stride, logits_stride,
idx_mapping_ptr, expanded_idx_mapping_ptr,
token_ids_ptr, token_ids_ptr,
expanded_local_pos_ptr, expanded_local_pos_ptr,
repetition_penalty_ptr, repetition_penalty_ptr,
...@@ -125,7 +125,7 @@ def _penalties_kernel( ...@@ -125,7 +125,7 @@ def _penalties_kernel(
MAX_SPEC_LEN: tl.constexpr, MAX_SPEC_LEN: tl.constexpr,
): ):
token_idx = tl.program_id(0) token_idx = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + token_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx) rep_penalty = tl.load(repetition_penalty_ptr + req_state_idx)
freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx) freq_penalty = tl.load(frequency_penalty_ptr + req_state_idx)
pres_penalty = tl.load(presence_penalty_ptr + req_state_idx) pres_penalty = tl.load(presence_penalty_ptr + req_state_idx)
...@@ -191,7 +191,7 @@ def _penalties_kernel( ...@@ -191,7 +191,7 @@ def _penalties_kernel(
def apply_penalties( def apply_penalties(
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
token_ids: torch.Tensor, token_ids: torch.Tensor,
expanded_local_pos: torch.Tensor, expanded_local_pos: torch.Tensor,
repetition_penalty: torch.Tensor, repetition_penalty: torch.Tensor,
...@@ -207,7 +207,7 @@ def apply_penalties( ...@@ -207,7 +207,7 @@ def apply_penalties(
_penalties_kernel[(num_tokens, num_blocks)]( _penalties_kernel[(num_tokens, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
idx_mapping, expanded_idx_mapping,
token_ids, token_ids,
expanded_local_pos, expanded_local_pos,
repetition_penalty, repetition_penalty,
...@@ -225,7 +225,7 @@ def apply_penalties( ...@@ -225,7 +225,7 @@ def apply_penalties(
@triton.jit @triton.jit
def _bincount_kernel( def _bincount_kernel(
idx_mapping_ptr, expanded_idx_mapping_ptr,
all_token_ids_ptr, all_token_ids_ptr,
all_token_ids_stride, all_token_ids_stride,
prompt_len_ptr, prompt_len_ptr,
...@@ -236,9 +236,9 @@ def _bincount_kernel( ...@@ -236,9 +236,9 @@ def _bincount_kernel(
output_bin_counts_stride, output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
batch_idx = tl.program_id(0) token_idx = tl.program_id(0)
block_idx = tl.program_id(1) block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx) req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx) prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len: if block_idx * BLOCK_SIZE >= prefill_len:
...@@ -276,7 +276,7 @@ def _bincount_kernel( ...@@ -276,7 +276,7 @@ def _bincount_kernel(
def bincount( def bincount(
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor, all_token_ids: torch.Tensor,
prompt_len: torch.Tensor, prompt_len: torch.Tensor,
prefill_len: torch.Tensor, prefill_len: torch.Tensor,
...@@ -284,13 +284,13 @@ def bincount( ...@@ -284,13 +284,13 @@ def bincount(
output_bin_counts: torch.Tensor, output_bin_counts: torch.Tensor,
max_prefill_len: int, max_prefill_len: int,
) -> None: ) -> None:
prompt_bin_mask[idx_mapping] = 0 prompt_bin_mask[expanded_idx_mapping] = 0
output_bin_counts[idx_mapping] = 0 output_bin_counts[expanded_idx_mapping] = 0
num_reqs = idx_mapping.shape[0] num_tokens = expanded_idx_mapping.shape[0]
BLOCK_SIZE = 1024 BLOCK_SIZE = 1024
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE) num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)]( _bincount_kernel[(num_tokens, num_blocks)](
idx_mapping, expanded_idx_mapping,
all_token_ids, all_token_ids,
all_token_ids.stride(0), all_token_ids.stride(0),
prompt_len, prompt_len,
......
...@@ -56,7 +56,7 @@ class Sampler: ...@@ -56,7 +56,7 @@ class Sampler:
def __call__( def __call__(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray, cu_num_logits_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
...@@ -68,7 +68,7 @@ class Sampler: ...@@ -68,7 +68,7 @@ class Sampler:
num_nans = get_num_nans(logits) if self.compute_nans else None num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample( sampled, processed_logits = self.sample(
logits, logits,
idx_mapping, expanded_idx_mapping,
idx_mapping_np, idx_mapping_np,
pos, pos,
input_ids, input_ids,
...@@ -101,7 +101,7 @@ class Sampler: ...@@ -101,7 +101,7 @@ class Sampler:
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -111,12 +111,14 @@ class Sampler: ...@@ -111,12 +111,14 @@ class Sampler:
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place. # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos) self.logit_bias_state.apply_logit_bias(
logits, expanded_idx_mapping, idx_mapping_np, pos
)
# Apply penalties in place. # Apply penalties in place.
self.penalties_state.apply_penalties( self.penalties_state.apply_penalties(
logits, logits,
idx_mapping, expanded_idx_mapping,
idx_mapping_np, idx_mapping_np,
input_ids, input_ids,
expanded_local_pos, expanded_local_pos,
...@@ -126,27 +128,29 @@ class Sampler: ...@@ -126,27 +128,29 @@ class Sampler:
# Apply bad words masking in place. # Apply bad words masking in place.
self.bad_words_state.apply_bad_words( self.bad_words_state.apply_bad_words(
logits, logits,
idx_mapping, expanded_idx_mapping,
idx_mapping_np, idx_mapping_np,
input_ids, input_ids,
expanded_local_pos, expanded_local_pos,
) )
# Apply temperature in place. # Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np) self.sampling_states.apply_temperature(
logits, expanded_idx_mapping, idx_mapping_np
)
# Apply min_p in place. # Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np) self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor. # Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p( logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np logits, expanded_idx_mapping, idx_mapping_np
) )
# Sample the next token. # Sample the next token.
sampled = gumbel_sample( sampled = gumbel_sample(
logits, logits,
idx_mapping, expanded_idx_mapping,
self.sampling_states.temperature.gpu, self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu, self.sampling_states.seeds.gpu,
pos, pos,
......
...@@ -64,7 +64,7 @@ class SamplingStates: ...@@ -64,7 +64,7 @@ class SamplingStates:
def apply_temperature( def apply_temperature(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
) -> None: ) -> None:
temp_np = self.temperature.np[idx_mapping_np] temp_np = self.temperature.np[idx_mapping_np]
...@@ -72,23 +72,23 @@ class SamplingStates: ...@@ -72,23 +72,23 @@ class SamplingStates:
# No request requires temperature. Skip the kernel launch. # No request requires temperature. Skip the kernel launch.
return return
apply_temperature(logits, idx_mapping, self.temperature.gpu) apply_temperature(logits, expanded_idx_mapping, self.temperature.gpu)
def apply_min_p( def apply_min_p(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
) -> None: ) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0): if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch. # No request uses min_p. Skip the kernel launch.
return return
apply_min_p(logits, idx_mapping, self.min_p.gpu) apply_min_p(logits, expanded_idx_mapping, self.min_p.gpu)
def apply_top_k_top_p( def apply_top_k_top_p(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray, idx_mapping_np: np.ndarray,
) -> torch.Tensor: ) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size) do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
...@@ -96,8 +96,8 @@ class SamplingStates: ...@@ -96,8 +96,8 @@ class SamplingStates:
if not (do_top_k or do_top_p): if not (do_top_k or do_top_p):
return logits return logits
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None top_k = self.top_k.gpu[expanded_idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None top_p = self.top_p.gpu[expanded_idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p) return apply_top_k_top_p(logits, top_k, top_p)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int: def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment