Unverified commit dec28688, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Minor refactor for logit_bias (#32209)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 9f430c94
......@@ -119,35 +119,18 @@ class LogitBiasState:
idx_mapping: torch.Tensor,
pos: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = triton.next_power_of_2(
max(
MAX_NUM_ALLOWED_TOKEN_IDS,
MAX_NUM_LOGIT_BIAS_TOKENS,
MAX_NUM_STOP_TOKEN_IDS,
)
)
LOGITS_BLOCK_SIZE = 8192
_bias_kernel[(num_reqs,)](
apply_logit_bias(
logits,
logits.stride(0),
vocab_size,
idx_mapping,
self.num_allowed_token_ids,
self.allowed_token_ids,
self.allowed_token_ids.gpu.stride(0),
self.num_logit_bias,
self.logit_bias_token_ids,
self.logit_bias_token_ids.gpu.stride(0),
self.logit_bias,
self.logit_bias.gpu.stride(0),
pos,
self.min_lens,
self.num_stop_token_ids,
self.stop_token_ids,
self.stop_token_ids.gpu.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
LOGITS_BLOCK_SIZE=LOGITS_BLOCK_SIZE,
self.num_allowed_token_ids.gpu,
self.allowed_token_ids.gpu,
self.num_logit_bias.gpu,
self.logit_bias_token_ids.gpu,
self.logit_bias.gpu,
self.min_lens.gpu,
self.num_stop_token_ids.gpu,
self.stop_token_ids.gpu,
)
......@@ -240,3 +223,48 @@ def _bias_kernel(
-float("inf"),
mask=mask,
)
def apply_logit_bias(
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    pos: torch.Tensor,
    num_allowed_token_ids: torch.Tensor,
    allowed_token_ids: torch.Tensor,
    num_logit_bias: torch.Tensor,
    logit_bias_token_ids: torch.Tensor,
    logit_bias: torch.Tensor,
    min_lens: torch.Tensor,
    num_stop_token_ids: torch.Tensor,
    stop_token_ids: torch.Tensor,
) -> None:
    """Launch ``_bias_kernel`` on ``logits`` with one grid entry per row.

    ``logits`` must be two-dimensional; its first dimension sets the launch
    grid and its second is forwarded to the kernel as the vocabulary size.
    All remaining tensors are handed to the kernel unchanged, and the ones
    whose stride is needed are accompanied by their dimension-0 stride.

    Nothing is returned. NOTE(review): the kernel presumably updates
    ``logits`` in place -- confirm against ``_bias_kernel``.
    """
    row_count, vocab_size = logits.shape

    # The per-request id block has to cover the widest of the three id
    # tables, rounded up to a power of two as required for BLOCK_SIZE.
    widest_table = max(
        table.shape[-1]
        for table in (allowed_token_ids, logit_bias_token_ids, stop_token_ids)
    )
    ids_block = triton.next_power_of_2(widest_table)

    # Fixed block width used by the kernel when walking the logits row
    # (NOTE(review): exact usage lives in ``_bias_kernel`` -- not shown here).
    logits_block = 8192

    launch_grid = (row_count,)
    _bias_kernel[launch_grid](
        logits,
        logits.stride(0),
        vocab_size,
        idx_mapping,
        num_allowed_token_ids,
        allowed_token_ids,
        allowed_token_ids.stride(0),
        num_logit_bias,
        logit_bias_token_ids,
        logit_bias_token_ids.stride(0),
        logit_bias,
        logit_bias.stride(0),
        pos,
        min_lens,
        num_stop_token_ids,
        stop_token_ids,
        stop_token_ids.stride(0),
        BLOCK_SIZE=ids_block,
        LOGITS_BLOCK_SIZE=logits_block,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment