Commit 9ca768c7 (unverified), authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Minor cleanup for Sampler (#34563)


Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
parent d5fe3f70
...@@ -7,12 +7,10 @@ import torch ...@@ -7,12 +7,10 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config.model import LogprobsMode from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
...@@ -127,20 +125,15 @@ class Sampler: ...@@ -127,20 +125,15 @@ class Sampler:
) )
# Apply temperature in place. # Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu) self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
# Apply min_p in place if any request has a non-zero min_p. # Apply min_p in place.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np) self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
if do_min_p:
apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu) # Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
# Apply top_k and/or top_p. This might return a new tensor. logits, idx_mapping, idx_mapping_np
do_top_k = self.sampling_states.do_top_k(idx_mapping_np) )
top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
if do_top_k or do_top_p:
logits = apply_top_k_top_p(logits, top_k, top_p)
# Sample the next token. # Sample the next token.
sampled = gumbel_sample( sampled = gumbel_sample(
......
...@@ -4,7 +4,10 @@ import numpy as np ...@@ -4,7 +4,10 @@ import numpy as np
import torch import torch
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
NO_LOGPROBS = -1 NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min _NP_INT64_MIN = np.iinfo(np.int64).min
...@@ -58,14 +61,44 @@ class SamplingStates: ...@@ -58,14 +61,44 @@ class SamplingStates:
self.min_p.copy_to_uva() self.min_p.copy_to_uva()
self.seeds.copy_to_uva() self.seeds.copy_to_uva()
def do_min_p(self, idx_mapping_np: np.ndarray) -> bool: def apply_temperature(
return np.any(self.min_p.np[idx_mapping_np] != 0.0) self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
if np.all((temp_np == 0.0) | (temp_np == 1.0)):
# No request requires temperature. Skip the kernel launch.
return
def do_top_k(self, idx_mapping_np: np.ndarray) -> bool: apply_temperature(logits, idx_mapping, self.temperature.gpu)
return np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
def do_top_p(self, idx_mapping_np: np.ndarray) -> bool: def apply_min_p(
return np.any(self.top_p.np[idx_mapping_np] != 1.0) self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
def apply_top_k_top_p(
    self,
    logits: torch.Tensor,
    idx_mapping: torch.Tensor,
    idx_mapping_np: np.ndarray,
) -> torch.Tensor:
    """Apply top-k and/or top-p filtering to ``logits`` when any request needs it.

    Whether filtering is needed is decided from the host-side ``.np`` copies
    of the per-request parameters, selected with ``idx_mapping_np``:
    a ``top_k`` entry different from ``self.vocab_size`` enables top-k, and a
    ``top_p`` entry different from ``1.0`` enables top-p.

    Args:
        logits: Logits to filter.
        idx_mapping: Index tensor used to gather the ``.gpu`` parameter
            tensors (presumably the device-side twin of ``idx_mapping_np``
            -- confirm against the caller).
        idx_mapping_np: NumPy index array used to read the ``.np`` copies.

    Returns:
        ``logits`` itself when neither top-k nor top-p is enabled for the
        selected requests; otherwise whatever the module-level
        ``apply_top_k_top_p`` helper returns, which may or may not be a
        new tensor.
    """
    needs_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
    needs_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)

    if needs_top_k or needs_top_p:
        # Gather only the parameters that are actually in use; ``None``
        # is passed to the helper for a filter that no request enables.
        k = self.top_k.gpu[idx_mapping] if needs_top_k else None
        p = self.top_p.gpu[idx_mapping] if needs_top_p else None
        # Inside the method body this name resolves to the module-level
        # helper function, not to this method.
        return apply_top_k_top_p(logits, k, p)

    # Nothing to filter: hand the input back without touching the
    # ``.gpu`` tensors or calling the helper.
    return logits
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int: def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
return int(np.max(self.num_logprobs[idx_mapping_np])) return int(np.max(self.num_logprobs[idx_mapping_np]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment