Unverified Commit e9c83cdc authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Skip kernel launch for penalties & logit_bias (#32634)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent b75e85de
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
...@@ -49,12 +50,18 @@ class LogitBiasState: ...@@ -49,12 +50,18 @@ class LogitBiasState:
device=device, device=device,
) )
# Using any of the above.
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
def add_request( def add_request(
self, self,
req_idx: int, req_idx: int,
prompt_len: int, prompt_len: int,
sampling_params: SamplingParams, sampling_params: SamplingParams,
) -> None: ) -> None:
# Using any logit bias.
use_logit_bias = False
# Allowed token IDs. # Allowed token IDs.
allowed_token_ids = sampling_params.allowed_token_ids allowed_token_ids = sampling_params.allowed_token_ids
if allowed_token_ids: if allowed_token_ids:
...@@ -66,6 +73,7 @@ class LogitBiasState: ...@@ -66,6 +73,7 @@ class LogitBiasState:
) )
self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids) self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
use_logit_bias = True
else: else:
self.num_allowed_token_ids.np[req_idx] = 0 self.num_allowed_token_ids.np[req_idx] = 0
...@@ -81,6 +89,7 @@ class LogitBiasState: ...@@ -81,6 +89,7 @@ class LogitBiasState:
self.num_logit_bias.np[req_idx] = num_logit_bias self.num_logit_bias.np[req_idx] = num_logit_bias
self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys()) self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
self.logit_bias.stage_write(req_idx, 0, logit_bias.values()) self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
use_logit_bias = True
else: else:
self.num_logit_bias.np[req_idx] = 0 self.num_logit_bias.np[req_idx] = 0
...@@ -89,7 +98,7 @@ class LogitBiasState: ...@@ -89,7 +98,7 @@ class LogitBiasState:
min_len = prompt_len + min_tokens min_len = prompt_len + min_tokens
self.min_lens.np[req_idx] = min_len self.min_lens.np[req_idx] = min_len
stop_token_ids = sampling_params.all_stop_token_ids stop_token_ids = sampling_params.all_stop_token_ids
if stop_token_ids: if min_tokens > 0 and stop_token_ids:
num_stop_token_ids = len(stop_token_ids) num_stop_token_ids = len(stop_token_ids)
if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS: if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
raise ValueError( raise ValueError(
...@@ -98,9 +107,12 @@ class LogitBiasState: ...@@ -98,9 +107,12 @@ class LogitBiasState:
) )
self.num_stop_token_ids.np[req_idx] = num_stop_token_ids self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids) self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
use_logit_bias = True
else: else:
self.num_stop_token_ids.np[req_idx] = 0 self.num_stop_token_ids.np[req_idx] = 0
self.use_logit_bias[req_idx] = use_logit_bias
def apply_staged_writes(self) -> None: def apply_staged_writes(self) -> None:
self.num_allowed_token_ids.copy_to_uva() self.num_allowed_token_ids.copy_to_uva()
self.allowed_token_ids.apply_write() self.allowed_token_ids.apply_write()
...@@ -117,8 +129,13 @@ class LogitBiasState: ...@@ -117,8 +129,13 @@ class LogitBiasState:
self, self,
logits: torch.Tensor, logits: torch.Tensor,
idx_mapping: torch.Tensor, idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor, pos: torch.Tensor,
) -> None: ) -> None:
if not np.any(self.use_logit_bias[idx_mapping_np]):
# No request uses logit bias. Skip the kernel launch.
return
apply_logit_bias( apply_logit_bias(
logits, logits,
idx_mapping, idx_mapping,
......
...@@ -18,6 +18,7 @@ class PenaltiesState: ...@@ -18,6 +18,7 @@ class PenaltiesState:
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.use_penalty = np.zeros(max_num_reqs, dtype=bool)
# Initialize repetition penalty manually because 0 is an invalid value for it. # Initialize repetition penalty manually because 0 is an invalid value for it.
self.repetition_penalty.np.fill(1.0) self.repetition_penalty.np.fill(1.0)
...@@ -42,7 +43,10 @@ class PenaltiesState: ...@@ -42,7 +43,10 @@ class PenaltiesState:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
do_penalty = use_penalty(sampling_params)
self.use_penalty[req_idx] = do_penalty
if do_penalty:
self._penalties_reqs.append(req_idx) self._penalties_reqs.append(req_idx)
def apply_staged_writes( def apply_staged_writes(
...@@ -66,7 +70,16 @@ class PenaltiesState: ...@@ -66,7 +70,16 @@ class PenaltiesState:
self.frequency_penalty.copy_to_uva() self.frequency_penalty.copy_to_uva()
self.presence_penalty.copy_to_uva() self.presence_penalty.copy_to_uva()
def apply_penalties(self, logits: torch.Tensor, idx_mapping: torch.Tensor) -> None: def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if not np.any(self.use_penalty[idx_mapping_np]):
# No request uses penalties. Skip the kernel launch.
return
apply_penalties( apply_penalties(
logits, logits,
idx_mapping, idx_mapping,
......
...@@ -104,10 +104,10 @@ class Sampler: ...@@ -104,10 +104,10 @@ class Sampler:
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place. # Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos) self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
# Apply penalties in place. # Apply penalties in place.
self.penalties_state.apply_penalties(logits, idx_mapping) self.penalties_state.apply_penalties(logits, idx_mapping, idx_mapping_np)
# Apply temperature in place. # Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu) apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment