Unverified commit 90c08369, authored by Woosuk Kwon, committed by GitHub
Browse files

[Model Runner V2] Refactor Sampler (#32245)


Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 8ef50d9a
...@@ -49,7 +49,6 @@ from vllm.v1.worker.gpu.input_batch import ( ...@@ -49,7 +49,6 @@ from vllm.v1.worker.gpu.input_batch import (
) )
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.spec_decode import init_speculator from vllm.v1.worker.gpu.spec_decode import init_speculator
...@@ -139,7 +138,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -139,7 +138,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
) )
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode) self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
logprobs_mode=self.model_config.logprobs_mode,
)
# CUDA graphs. # CUDA graphs.
self.cudagraph_manager = CudaGraphManager( self.cudagraph_manager = CudaGraphManager(
...@@ -310,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -310,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> None: ) -> None:
num_reqs = hidden_states.shape[0] num_reqs = hidden_states.shape[0]
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
logits = self.model.compute_logits(hidden_states) logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata) idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(logits, idx_mapping, idx_mapping_np, pos)
@torch.inference_mode() @torch.inference_mode()
def profile_run(self) -> None: def profile_run(self) -> None:
...@@ -401,9 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -401,9 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert new_req_data.prefill_token_ids is not None assert new_req_data.prefill_token_ids is not None
assert new_req_data.sampling_params is not None assert new_req_data.sampling_params is not None
req_id = new_req_data.req_id req_id = new_req_data.req_id
prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request( self.req_states.add_request(
req_id=req_id, req_id=req_id,
prompt_len=len(new_req_data.prompt_token_ids), prompt_len=prompt_len,
prefill_token_ids=new_req_data.prefill_token_ids, prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens, num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params, sampling_params=new_req_data.sampling_params,
...@@ -423,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -423,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.block_tables.append_block_ids( self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True req_index, new_req_data.block_ids, overwrite=True
) )
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
# Add new blocks for the existing requests. # Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
...@@ -436,6 +446,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -436,6 +446,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.apply_staged_writes() self.req_states.apply_staged_writes()
self.block_tables.apply_staged_writes() self.block_tables.apply_staged_writes()
self.sampler.apply_staged_writes(
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len,
)
if self.uses_mrope: if self.uses_mrope:
self.mrope_states.apply_staged_writes() self.mrope_states.apply_staged_writes()
...@@ -612,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -612,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_batch: InputBatch, input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None, grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]: ) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices] sample_hidden_states = hidden_states[input_batch.logits_indices]
sample_pos = input_batch.positions[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None: if grammar_output is not None:
# Apply grammar bitmask to the logits in-place. # Apply grammar bitmask to the logits in-place.
...@@ -627,7 +642,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -627,7 +642,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
# Sample tokens and compute logprobs (if needed). # Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(logits, sampling_metadata) sampler_output = self.sampler(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
sample_pos,
)
if input_batch.num_draft_tokens == 0: if input_batch.num_draft_tokens == 0:
# No draft tokens (common case). # No draft tokens (common case).
...@@ -766,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -766,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch.idx_mapping, input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu, self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens, self.req_states.last_sampled_tokens,
self.req_states.output_bin_counts, self.sampler.penalties_state.output_bin_counts,
sampled_tokens, sampled_tokens,
num_sampled, num_sampled,
num_rejected, num_rejected,
...@@ -786,7 +806,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -786,7 +806,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_draft( def propose_draft(
self, self,
input_batch: InputBatch, input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
last_hidden_states: torch.Tensor, last_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None, aux_hidden_states: list[torch.Tensor] | None,
num_sampled: torch.Tensor, num_sampled: torch.Tensor,
...@@ -801,13 +820,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -801,13 +820,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
] ]
draft_tokens = self.speculator.propose( draft_tokens = self.speculator.propose(
input_batch, input_batch,
sampling_metadata,
last_hidden_states, last_hidden_states,
aux_hidden_states, aux_hidden_states,
num_sampled, num_sampled,
num_rejected, num_rejected,
last_sampled_tokens, last_sampled_tokens,
next_prefill_tokens, next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
) )
return draft_tokens return draft_tokens
...@@ -893,12 +913,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -893,12 +913,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output, scheduler_output,
num_tokens_after_padding, num_tokens_after_padding,
) )
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
)
if self.lora_config: if self.lora_config:
# Activate LoRA adapters. # Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs( lora_inputs = self.req_states.make_lora_inputs(
...@@ -917,7 +931,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -917,7 +931,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device=self.device, device=self.device,
) )
self.prepare_dummy_attn_metadata(input_batch) self.prepare_dummy_attn_metadata(input_batch)
sampling_metadata = None
# Run model. # Run model.
if cudagraph_mode == CUDAGraphMode.FULL: if cudagraph_mode == CUDAGraphMode.FULL:
...@@ -946,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -946,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions, positions=positions,
) )
self.execute_model_state = hidden_states, input_batch, sampling_metadata self.execute_model_state = hidden_states, input_batch
return None return None
@torch.inference_mode() @torch.inference_mode()
...@@ -955,12 +968,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -955,12 +968,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output: GrammarOutput | None, grammar_output: GrammarOutput | None,
) -> AsyncOutput | ModelRunnerOutput: ) -> AsyncOutput | ModelRunnerOutput:
assert self.execute_model_state is not None assert self.execute_model_state is not None
hidden_states, input_batch, sampling_metadata = self.execute_model_state hidden_states, input_batch = self.execute_model_state
self.execute_model_state = None # type: ignore self.execute_model_state = None # type: ignore
assert sampling_metadata is not None
sampler_output, num_sampled, num_rejected = self.sample( sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output hidden_states, input_batch, grammar_output
) )
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch) prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
...@@ -992,7 +1004,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -992,7 +1004,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.do_spec_decode: if self.do_spec_decode:
draft_tokens = self.propose_draft( draft_tokens = self.propose_draft(
input_batch, input_batch,
sampling_metadata,
hidden_states, hidden_states,
None, # aux_hidden_states None, # aux_hidden_states
num_sampled, num_sampled,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
@dataclass
class SamplingMetadata:
idx_mapping: torch.Tensor
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
min_p: torch.Tensor | None
# For penalties
repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p = None
top_k = None
min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device)
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
idx_mapping=idx_mapping,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
class PenaltiesState:
    """Per-request penalty coefficients plus the token statistics they use.

    Coefficients are staged through the `.np` view of each `UvaBackedTensor`
    and published with `copy_to_uva()`; the bin tensors are allocated directly
    on `device`.
    """

    def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size
        self.device = device

        def _per_request_f32() -> UvaBackedTensor:
            return UvaBackedTensor(max_num_reqs, dtype=torch.float32)

        self.repetition_penalty = _per_request_f32()
        self.frequency_penalty = _per_request_f32()
        self.presence_penalty = _per_request_f32()

        # 0 is an invalid repetition penalty, so seed every slot with 1.0.
        self.repetition_penalty.np.fill(1.0)
        self.repetition_penalty.copy_to_uva()

        # Statistics for penalties.
        # presumably one bit per vocab token packed into int32 words — confirm
        # against the bincount kernel.
        self.prompt_bin_mask = torch.zeros(
            (max_num_reqs, cdiv(vocab_size, 32)),
            dtype=torch.int32,
            device=device,
        )
        # TODO(woosuk): This tensor is rarely used but can be very large, taking up
        # GBs of GPU memory. Optimize the memory usage.
        self.output_bin_counts = torch.zeros(
            (max_num_reqs, vocab_size),
            dtype=torch.int32,
            device=device,
        )

        # Slots whose bin statistics still have to be computed.
        self._penalties_reqs: list[int] = []

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the penalty coefficients of a request in slot `req_idx`.

        The slot is additionally queued for bin counting when `use_penalty`
        reports that the request uses a penalty.
        """
        params = sampling_params
        self.repetition_penalty.np[req_idx] = params.repetition_penalty
        self.frequency_penalty.np[req_idx] = params.frequency_penalty
        self.presence_penalty.np[req_idx] = params.presence_penalty
        if not use_penalty(params):
            return
        self._penalties_reqs.append(req_idx)

    def apply_staged_writes(
        self,
        prefill_token_ids: torch.Tensor,
        prefill_lens: np.ndarray,
        prompt_lens: np.ndarray,
    ) -> None:
        """Run `bincount` for queued slots, then publish staged coefficients.

        Args:
            prefill_token_ids: Indexed by request slot to obtain that request's
                prefill tokens.
            prefill_lens: Per-slot prefill length.
            prompt_lens: Per-slot prompt length.
        """
        pending = self._penalties_reqs
        # TODO(woosuk): Optimize this.
        for slot in pending:
            bincount(
                prefill_token_ids[slot],
                int(prefill_lens[slot]),
                int(prompt_lens[slot]),
                self.prompt_bin_mask[slot],
                self.output_bin_counts[slot],
            )
        pending.clear()

        staged_coefficients = (
            self.repetition_penalty,
            self.frequency_penalty,
            self.presence_penalty,
        )
        for staged in staged_coefficients:
            staged.copy_to_uva()

    def apply_penalties_and_temperature(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        temperature: torch.Tensor,
    ) -> None:
        """Call the module-level `apply_penalties_and_temperature` with this state.

        Forwards `logits`, `idx_mapping` and `temperature` together with the
        `.gpu` views of the three coefficients and both bin tensors.
        """
        coefficients = (
            self.repetition_penalty.gpu,
            self.frequency_penalty.gpu,
            self.presence_penalty.gpu,
        )
        # Inside a method this name resolves to the module-level function, not
        # to this method.
        apply_penalties_and_temperature(
            logits,
            idx_mapping,
            temperature,
            *coefficients,
            self.prompt_bin_mask,
            self.output_bin_counts,
        )
@triton.jit @triton.jit
...@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel( ...@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel(
def apply_penalties_and_temperature( def apply_penalties_and_temperature(
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, idx_mapping: torch.Tensor,
temperature: torch.Tensor,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
) -> None: ) -> None:
num_reqs, vocab_size = logits.shape num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192 BLOCK_SIZE = 8192
...@@ -92,15 +176,15 @@ def apply_penalties_and_temperature( ...@@ -92,15 +176,15 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel[(num_reqs, num_blocks)]( _penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits, logits,
logits.stride(0), logits.stride(0),
sampling_metadata.idx_mapping, idx_mapping,
sampling_metadata.repetition_penalty, repetition_penalty,
sampling_metadata.frequency_penalty, frequency_penalty,
sampling_metadata.presence_penalty, presence_penalty,
sampling_metadata.temperature, temperature,
sampling_metadata.prompt_bin_mask, prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0), prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts, output_bin_counts,
sampling_metadata.output_bin_counts.stride(0), output_bin_counts.stride(0),
vocab_size, vocab_size,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
) )
...@@ -153,3 +237,11 @@ def bincount( ...@@ -153,3 +237,11 @@ def bincount(
output_bin_counts, output_bin_counts,
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
) )
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Return True when any penalty differs from its neutral value.

    The neutral values are 1.0 for `repetition_penalty` and 0.0 for both
    `frequency_penalty` and `presence_penalty`.
    """
    all_neutral = (
        sampling_params.repetition_penalty == 1.0
        and sampling_params.frequency_penalty == 0.0
        and sampling_params.presence_penalty == 0.0
    )
    return not all_neutral
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config.model import LogprobsMode from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
class Sampler: class Sampler:
def __init__( def __init__(
self, self,
max_num_reqs: int,
vocab_size: int,
device: torch.device,
logprobs_mode: LogprobsMode = "raw_logprobs", logprobs_mode: LogprobsMode = "raw_logprobs",
): ):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]: if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
...@@ -25,26 +31,54 @@ class Sampler: ...@@ -25,26 +31,54 @@ class Sampler:
self.logprobs_mode = logprobs_mode self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default. self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
def add_request(
self,
req_idx: int,
prompt_len: int,
sampling_params: SamplingParams,
) -> None:
self.sampling_states.add_request(req_idx, sampling_params)
self.penalties_state.add_request(req_idx, sampling_params)
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
def apply_staged_writes(
self,
prefill_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes(
prefill_token_ids, prefill_lens, prompt_lens
)
self.logit_bias_state.apply_staged_writes()
def __call__( def __call__(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplerOutput: ) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear # NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature. # that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(logits, sampling_metadata) sampled, processed_logits = self.sample(
if sampling_metadata.max_num_logprobs is not None: logits, idx_mapping, idx_mapping_np, pos
)
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
logits = ( logits = (
processed_logits processed_logits
if self.logprobs_mode == "processed_logprobs" if self.logprobs_mode == "processed_logprobs"
else logits else logits
) )
logprobs_tensors = compute_topk_logprobs( logprobs_tensors = compute_topk_logprobs(logits, max_num_logprobs, sampled)
logits,
sampling_metadata.max_num_logprobs,
sampled,
)
else: else:
logprobs_tensors = None logprobs_tensors = None
...@@ -62,27 +96,41 @@ class Sampler: ...@@ -62,27 +96,41 @@ class Sampler:
def sample( def sample(
self, self,
logits: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor. # Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits) logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
# Apply penalties and temperature in place. # Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata) self.penalties_state.apply_penalties_and_temperature(
# Apply min_p in place. logits, idx_mapping, self.sampling_states.temperature.gpu
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
) )
# Apply min_p in place if any request has a non-zero min_p.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
if do_min_p:
apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
# Apply top_k and/or top_p. This might return a new tensor.
do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
if do_top_k or do_top_p:
logits = apply_top_k_top_p(logits, top_k, top_p)
# Sample the next token.
sampled = gumbel_sample( sampled = gumbel_sample(
logits, logits,
sampling_metadata.idx_mapping, idx_mapping,
sampling_metadata.temperature, self.sampling_states.temperature.gpu,
sampling_metadata.seeds, self.sampling_states.seeds.gpu,
sampling_metadata.pos, pos,
apply_temperature=False, apply_temperature=False,
) )
return sampled, logits return sampled, logits
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
# Sentinel stored in `SamplingStates.num_logprobs` when no logprobs are requested.
NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max


class SamplingStates:
    """Per-request sampling parameters, indexed by request slot.

    Temperature, top-k, top-p, min-p and seeds live in `UvaBackedTensor`s:
    `add_request` writes to their `.np` views and `apply_staged_writes`
    publishes them with `copy_to_uva()`. The requested logprob counts are kept
    in a plain NumPy array.
    """

    def __init__(self, max_num_reqs: int, vocab_size: int):
        self.max_num_reqs = max_num_reqs
        self.vocab_size = vocab_size

        self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
        self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
        self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)

        # Initialize top_k and top_p manually because 0 is an invalid value for
        # them. `vocab_size` and 1.0 are the values the `do_top_k` / `do_top_p`
        # checks treat as "disabled".
        self.top_k.np.fill(self.vocab_size)
        self.top_k.copy_to_uva()
        self.top_p.np.fill(1.0)
        self.top_p.copy_to_uva()

        self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
        # -1 means no logprobs are requested.
        self.num_logprobs.fill(NO_LOGPROBS)

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        """Stage the sampling parameters of a request in slot `req_idx`.

        A `top_k` outside the open interval (0, vocab_size) is stored as
        `vocab_size`. A missing seed is replaced by a random int64, and missing
        `logprobs` is stored as `NO_LOGPROBS`.
        """
        self.temperature.np[req_idx] = sampling_params.temperature
        self.top_p.np[req_idx] = sampling_params.top_p
        if 0 < sampling_params.top_k < self.vocab_size:
            top_k = sampling_params.top_k
        else:
            top_k = self.vocab_size
        self.top_k.np[req_idx] = top_k
        self.min_p.np[req_idx] = sampling_params.min_p

        if sampling_params.seed is not None:
            seed = sampling_params.seed
        else:
            # Request int64 explicitly: the default dtype of `randint` is the
            # platform integer, which cannot hold these bounds where it is
            # 32-bit.
            seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX, dtype=np.int64)
        self.seeds.np[req_idx] = seed

        # NOTE(review): if `sampling_params.logprobs` can itself be -1, it is
        # indistinguishable from NO_LOGPROBS here — confirm with the caller.
        if sampling_params.logprobs is not None:
            num_logprobs = sampling_params.logprobs
        else:
            num_logprobs = NO_LOGPROBS
        self.num_logprobs[req_idx] = num_logprobs

    def apply_staged_writes(self) -> None:
        """Publish every staged `.np` write via `copy_to_uva()`."""
        self.temperature.copy_to_uva()
        self.top_p.copy_to_uva()
        self.top_k.copy_to_uva()
        self.min_p.copy_to_uva()
        self.seeds.copy_to_uva()

    def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected request has a non-zero min_p."""
        # bool() converts np.bool_ to the annotated builtin type.
        return bool(np.any(self.min_p.np[idx_mapping_np] != 0.0))

    def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected request has top_k below `vocab_size`."""
        return bool(np.any(self.top_k.np[idx_mapping_np] != self.vocab_size))

    def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
        """Return True if any selected request has top_p different from 1.0."""
        return bool(np.any(self.top_p.np[idx_mapping_np] != 1.0))

    def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
        """Return the largest logprob count among the selected requests.

        Returns `NO_LOGPROBS` when none of them requested logprobs, including
        when `idx_mapping_np` selects no requests at all.
        """
        selected = self.num_logprobs[idx_mapping_np]
        if selected.size == 0:
            # np.max raises on an empty array; treat it like the do_* checks,
            # which report False for an empty selection.
            return NO_LOGPROBS
        return int(np.max(selected))
...@@ -17,7 +17,6 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata ...@@ -17,7 +17,6 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -188,7 +187,6 @@ class EagleSpeculator: ...@@ -188,7 +187,6 @@ class EagleSpeculator:
def propose( def propose(
self, self,
input_batch: InputBatch, input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
# [num_tokens, hidden_size] # [num_tokens, hidden_size]
last_hidden_states: torch.Tensor, last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size] # num_layers x [num_tokens, hidden_size]
...@@ -201,6 +199,10 @@ class EagleSpeculator: ...@@ -201,6 +199,10 @@ class EagleSpeculator:
last_sampled: torch.Tensor, last_sampled: torch.Tensor,
# [num_reqs] # [num_reqs]
next_prefill_tokens: torch.Tensor, next_prefill_tokens: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the # NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and # number of rejected tokens, we maintain the size of eagle's input_ids and
...@@ -246,8 +248,8 @@ class EagleSpeculator: ...@@ -246,8 +248,8 @@ class EagleSpeculator:
# affect the output distribution after rejection sampling. # affect the output distribution after rejection sampling.
idx_mapping = self.idx_mapping[:num_reqs] idx_mapping = self.idx_mapping[:num_reqs]
idx_mapping.copy_(input_batch.idx_mapping) idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(sampling_metadata.temperature) self.temperature.copy_(temperature)
self.seeds.copy_(sampling_metadata.seeds) self.seeds.copy_(seeds)
# Gather the values and copy them to the pre-allocated buffers. # Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs] pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos) torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
......
...@@ -7,14 +7,9 @@ import torch ...@@ -7,14 +7,9 @@ import torch
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.v1.outputs import LogprobsTensors from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0 NO_LORA_ID = 0
...@@ -81,38 +76,8 @@ class RequestState: ...@@ -81,38 +76,8 @@ class RequestState:
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32) self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID) self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.repetition_penalty = UvaBackedTensor(
self.max_num_reqs, dtype=torch.float32
)
self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool) self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
self.max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self._penalties_reqs: list[int] = []
@property
def num_reqs(self) -> int:
    """Return the number of entries currently held in ``req_id_to_index``."""
    tracked = self.req_id_to_index
    return len(tracked)
...@@ -147,33 +112,6 @@ class RequestState: ...@@ -147,33 +112,6 @@ class RequestState:
else: else:
self.lora_ids[req_idx] = NO_LORA_ID self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
self._penalties_reqs.append(req_idx)
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only support prompt logprobs for the prompt tokens. # For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
...@@ -183,17 +121,6 @@ class RequestState: ...@@ -183,17 +121,6 @@ class RequestState:
self.prefill_token_ids.apply_write() self.prefill_token_ids.apply_write()
self.num_computed_tokens.apply_write() self.num_computed_tokens.apply_write()
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
self.prefill_token_ids.gpu[req_idx],
int(self.prefill_len.np[req_idx]),
int(self.prompt_len[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.clear()
def remove_request(self, req_id: str) -> None: def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None) self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None) req_idx = self.req_id_to_index.pop(req_id, None)
...@@ -203,53 +130,6 @@ class RequestState: ...@@ -203,53 +130,6 @@ class RequestState:
self.index_to_req_id.pop(req_idx, None) self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx) self.free_indices.append(req_idx)
def make_sampling_metadata(
    self,
    idx_mapping: torch.Tensor,
    idx_mapping_np: np.ndarray,
    pos: torch.Tensor,
) -> SamplingMetadata:
    """Build the SamplingMetadata for the requests selected by the index mapping.

    Args:
        idx_mapping: Index tensor used to gather the top-p / top-k values and
            forwarded as-is in the returned metadata.
            NOTE(review): presumably holds the same indices as
            ``idx_mapping_np`` in tensor form -- confirm against the caller.
        idx_mapping_np: NumPy indices used to look up the selected requests'
            parameters in the host-side arrays.
        pos: Forwarded unchanged into the returned metadata.

    Returns:
        A SamplingMetadata in which ``top_p``, ``top_k`` and ``min_p`` are
        ``None`` when every selected request uses the no-op value (1.0,
        ``vocab_size`` and 0.0 respectively), and ``max_num_logprobs`` is
        ``None`` when no selected request asks for logprobs.
    """
    temperature = self.temperature.copy_to_uva()

    # For each optional filter, inspect the host-side values first and only
    # publish the tensor when at least one selected request needs it.
    selected_top_p = self.top_p.np[idx_mapping_np]
    if np.all(selected_top_p == 1.0):
        top_p = None
    else:
        top_p = self.top_p.copy_to_uva()[idx_mapping]

    selected_top_k = self.top_k.np[idx_mapping_np]
    if np.all(selected_top_k == self.vocab_size):
        top_k = None
    else:
        top_k = self.top_k.copy_to_uva()[idx_mapping]

    # Unlike top-p / top-k, min-p is not gathered through idx_mapping here;
    # the mapping itself is part of the returned metadata.
    selected_min_p = self.min_p.np[idx_mapping_np]
    if np.all(selected_min_p == 0.0):
        min_p = None
    else:
        min_p = self.min_p.copy_to_uva()

    rep_penalty = self.repetition_penalty.copy_to_uva()
    freq_penalty = self.frequency_penalty.copy_to_uva()
    pres_penalty = self.presence_penalty.copy_to_uva()

    seeds = self.seeds.copy_to_uva()

    # -1 marks "no logprobs requested". `initial=-1` keeps the reduction
    # defined when the selection is empty (a bare np.max raises ValueError on
    # a zero-size array) and cannot change the result otherwise, because -1
    # is the smallest value stored in num_logprobs.
    num_logprobs = self.num_logprobs[idx_mapping_np]
    max_num_logprobs: int | None = int(np.max(num_logprobs, initial=-1))
    if max_num_logprobs == -1:
        max_num_logprobs = None

    return SamplingMetadata(
        idx_mapping=idx_mapping,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        min_p=min_p,
        repetition_penalty=rep_penalty,
        frequency_penalty=freq_penalty,
        presence_penalty=pres_penalty,
        prompt_bin_mask=self.prompt_bin_mask,
        output_bin_counts=self.output_bin_counts,
        seeds=seeds,
        pos=pos,
        max_num_logprobs=max_num_logprobs,
    )
def make_lora_inputs( def make_lora_inputs(
self, self,
req_ids: list[str], req_ids: list[str],
...@@ -272,11 +152,3 @@ class RequestState: ...@@ -272,11 +152,3 @@ class RequestState:
class ExtraData: class ExtraData:
lora_request: LoRARequest | None lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list) in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
def use_penalty(sampling_params: SamplingParams) -> bool:
    """Tell whether the given sampling parameters enable any penalty.

    A penalty counts as enabled when ``repetition_penalty`` is not 1.0, or
    when ``frequency_penalty`` or ``presence_penalty`` is not 0.0.
    """
    if sampling_params.repetition_penalty != 1.0:
        return True
    if sampling_params.frequency_penalty != 0.0:
        return True
    return sampling_params.presence_penalty != 0.0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment