Unverified Commit 90c08369 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Refactor Sampler (#32245)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 8ef50d9a
......@@ -49,7 +49,6 @@ from vllm.v1.worker.gpu.input_batch import (
)
from vllm.v1.worker.gpu.mm.mrope_utils import MRopeState
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.spec_decode import init_speculator
......@@ -139,7 +138,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
dtype=self.dtype,
device=self.device,
)
self.sampler = Sampler(logprobs_mode=self.model_config.logprobs_mode)
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
logprobs_mode=self.model_config.logprobs_mode,
)
# CUDA graphs.
self.cudagraph_manager = CudaGraphManager(
......@@ -310,12 +314,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
hidden_states: torch.Tensor,
) -> None:
num_reqs = hidden_states.shape[0]
sampling_metadata = SamplingMetadata.make_dummy(
num_reqs=num_reqs,
device=self.device,
)
logits = self.model.compute_logits(hidden_states)
self.sampler(logits, sampling_metadata)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(logits, idx_mapping, idx_mapping_np, pos)
@torch.inference_mode()
def profile_run(self) -> None:
......@@ -401,9 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert new_req_data.prefill_token_ids is not None
assert new_req_data.sampling_params is not None
req_id = new_req_data.req_id
prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request(
req_id=req_id,
prompt_len=len(new_req_data.prompt_token_ids),
prompt_len=prompt_len,
prefill_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
sampling_params=new_req_data.sampling_params,
......@@ -423,6 +430,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
# Add new blocks for the existing requests.
cached_reqs = scheduler_output.scheduled_cached_reqs
......@@ -436,6 +446,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.req_states.apply_staged_writes()
self.block_tables.apply_staged_writes()
self.sampler.apply_staged_writes(
self.req_states.prefill_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len,
)
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
......@@ -612,10 +627,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
sample_pos = input_batch.positions[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
......@@ -627,7 +642,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(logits, sampling_metadata)
sampler_output = self.sampler(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
sample_pos,
)
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
......@@ -766,7 +786,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens,
self.req_states.output_bin_counts,
self.sampler.penalties_state.output_bin_counts,
sampled_tokens,
num_sampled,
num_rejected,
......@@ -786,7 +806,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def propose_draft(
self,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
last_hidden_states: torch.Tensor,
aux_hidden_states: list[torch.Tensor] | None,
num_sampled: torch.Tensor,
......@@ -801,13 +820,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
]
draft_tokens = self.speculator.propose(
input_batch,
sampling_metadata,
last_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
last_sampled_tokens,
next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
)
return draft_tokens
......@@ -893,12 +913,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduler_output,
num_tokens_after_padding,
)
pos = input_batch.positions[input_batch.logits_indices]
sampling_metadata = self.req_states.make_sampling_metadata(
input_batch.expanded_idx_mapping, input_batch.idx_mapping_np, pos
)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.req_states.make_lora_inputs(
......@@ -917,7 +931,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
device=self.device,
)
self.prepare_dummy_attn_metadata(input_batch)
sampling_metadata = None
# Run model.
if cudagraph_mode == CUDAGraphMode.FULL:
......@@ -946,7 +959,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
positions=positions,
)
self.execute_model_state = hidden_states, input_batch, sampling_metadata
self.execute_model_state = hidden_states, input_batch
return None
@torch.inference_mode()
......@@ -955,12 +968,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output: GrammarOutput | None,
) -> AsyncOutput | ModelRunnerOutput:
assert self.execute_model_state is not None
hidden_states, input_batch, sampling_metadata = self.execute_model_state
hidden_states, input_batch = self.execute_model_state
self.execute_model_state = None # type: ignore
assert sampling_metadata is not None
sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, sampling_metadata, grammar_output
hidden_states, input_batch, grammar_output
)
prompt_logprobs_dict = self.compute_prompt_logprobs(hidden_states, input_batch)
......@@ -992,7 +1004,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
if self.do_spec_decode:
draft_tokens = self.propose_draft(
input_batch,
sampling_metadata,
hidden_states,
None, # aux_hidden_states
num_sampled,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
@dataclass
class SamplingMetadata:
idx_mapping: torch.Tensor
temperature: torch.Tensor
top_p: torch.Tensor | None
top_k: torch.Tensor | None
min_p: torch.Tensor | None
# For penalties
repetition_penalty: torch.Tensor
frequency_penalty: torch.Tensor
presence_penalty: torch.Tensor
prompt_bin_mask: torch.Tensor
output_bin_counts: torch.Tensor
seeds: torch.Tensor
pos: torch.Tensor
# None means no logprobs, 0 means sampled token logprobs only
max_num_logprobs: int | None
@classmethod
def make_dummy(
cls,
num_reqs: int,
device: torch.device,
) -> "SamplingMetadata":
assert num_reqs > 0
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=device)
temperature = torch.zeros(num_reqs, dtype=torch.float32, device=device)
temperature[0] = 0.5
# TODO(woosuk): Use top-p and top-k for dummy sampler.
# Currently, they are disabled because of memory usage.
# top_p = torch.full((num_reqs,), 0.95, dtype=torch.float32, device=device)
# top_k = torch.full((num_reqs,), 20, dtype=torch.int32, device=device)
top_p = None
top_k = None
min_p = torch.zeros(num_reqs, dtype=torch.float32, device=device)
# NOTE(woosuk): We must set penalties to their default values to make sure
# the penalties kernel does not touch the placeholder bin_counts tensors.
repetition_penalty = torch.ones(num_reqs, dtype=torch.float32, device=device)
frequency_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
presence_penalty = torch.zeros(num_reqs, dtype=torch.float32, device=device)
# NOTE(woosuk): These are placeholder tensors to avoid None checks in the
# penalties kernel. We use 2 instead of 1 as vocab_size to avoid Triton
# specialization and re-compilation at runtime.
prompt_bin_mask = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
output_bin_counts = torch.zeros(num_reqs, 2, dtype=torch.int32, device=device)
seeds = torch.zeros(num_reqs, dtype=torch.int64, device=device)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=device)
max_num_logprobs = 20
return cls(
idx_mapping=idx_mapping,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=repetition_penalty,
frequency_penalty=frequency_penalty,
presence_penalty=presence_penalty,
prompt_bin_mask=prompt_bin_mask,
output_bin_counts=output_bin_counts,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.utils.math_utils import cdiv
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
class PenaltiesState:
def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
self.max_num_reqs = max_num_reqs
self.vocab_size = vocab_size
self.device = device
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
# Initialize repetition penalty manually because 0 is an invalid value for it.
self.repetition_penalty.np.fill(1.0)
self.repetition_penalty.copy_to_uva()
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
self.max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
# GBs of GPU memory. Optimize the memory usage.
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self._penalties_reqs: list[int] = []
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
self._penalties_reqs.append(req_idx)
def apply_staged_writes(
self,
prefill_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
prefill_token_ids[req_idx],
int(prefill_lens[req_idx]),
int(prompt_lens[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.clear()
self.repetition_penalty.copy_to_uva()
self.frequency_penalty.copy_to_uva()
self.presence_penalty.copy_to_uva()
def apply_penalties_and_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
temperature: torch.Tensor,
) -> None:
apply_penalties_and_temperature(
logits,
idx_mapping,
temperature,
self.repetition_penalty.gpu,
self.frequency_penalty.gpu,
self.presence_penalty.gpu,
self.prompt_bin_mask,
self.output_bin_counts,
)
@triton.jit
......@@ -84,7 +162,13 @@ def _penalties_and_temperature_kernel(
def apply_penalties_and_temperature(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
idx_mapping: torch.Tensor,
temperature: torch.Tensor,
repetition_penalty: torch.Tensor,
frequency_penalty: torch.Tensor,
presence_penalty: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
) -> None:
num_reqs, vocab_size = logits.shape
BLOCK_SIZE = 8192
......@@ -92,15 +176,15 @@ def apply_penalties_and_temperature(
_penalties_and_temperature_kernel[(num_reqs, num_blocks)](
logits,
logits.stride(0),
sampling_metadata.idx_mapping,
sampling_metadata.repetition_penalty,
sampling_metadata.frequency_penalty,
sampling_metadata.presence_penalty,
sampling_metadata.temperature,
sampling_metadata.prompt_bin_mask,
sampling_metadata.prompt_bin_mask.stride(0),
sampling_metadata.output_bin_counts,
sampling_metadata.output_bin_counts.stride(0),
idx_mapping,
repetition_penalty,
frequency_penalty,
presence_penalty,
temperature,
prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts,
output_bin_counts.stride(0),
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
......@@ -153,3 +237,11 @@ def bincount(
output_bin_counts,
BLOCK_SIZE=BLOCK_SIZE,
)
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import apply_penalties_and_temperature
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
class Sampler:
def __init__(
self,
max_num_reqs: int,
vocab_size: int,
device: torch.device,
logprobs_mode: LogprobsMode = "raw_logprobs",
):
if logprobs_mode not in ["processed_logprobs", "raw_logprobs"]:
......@@ -25,26 +31,54 @@ class Sampler:
self.logprobs_mode = logprobs_mode
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
def add_request(
self,
req_idx: int,
prompt_len: int,
sampling_params: SamplingParams,
) -> None:
self.sampling_states.add_request(req_idx, sampling_params)
self.penalties_state.add_request(req_idx, sampling_params)
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
def apply_staged_writes(
self,
prefill_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes(
prefill_token_ids, prefill_lens, prompt_lens
)
self.logit_bias_state.apply_staged_writes()
def __call__(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplerOutput:
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
sampled, processed_logits = self.sample(logits, sampling_metadata)
if sampling_metadata.max_num_logprobs is not None:
sampled, processed_logits = self.sample(
logits, idx_mapping, idx_mapping_np, pos
)
max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
if max_num_logprobs != NO_LOGPROBS:
logits = (
processed_logits
if self.logprobs_mode == "processed_logprobs"
else logits
)
logprobs_tensors = compute_topk_logprobs(
logits,
sampling_metadata.max_num_logprobs,
sampled,
)
logprobs_tensors = compute_topk_logprobs(logits, max_num_logprobs, sampled)
else:
logprobs_tensors = None
......@@ -62,27 +96,41 @@ class Sampler:
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
# Apply penalties and temperature in place.
apply_penalties_and_temperature(logits, sampling_metadata)
# Apply min_p in place.
if sampling_metadata.min_p is not None:
apply_min_p(logits, sampling_metadata.idx_mapping, sampling_metadata.min_p)
# Apply top_k and/or top_p. This might return a new tensor.
logits = apply_top_k_top_p(
logits, sampling_metadata.top_k, sampling_metadata.top_p
self.penalties_state.apply_penalties_and_temperature(
logits, idx_mapping, self.sampling_states.temperature.gpu
)
# Apply min_p in place if any request has a non-zero min_p.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
if do_min_p:
apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
# Apply top_k and/or top_p. This might return a new tensor.
do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
if do_top_k or do_top_p:
logits = apply_top_k_top_p(logits, top_k, top_p)
# Sample the next token.
sampled = gumbel_sample(
logits,
sampling_metadata.idx_mapping,
sampling_metadata.temperature,
sampling_metadata.seeds,
sampling_metadata.pos,
idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
pos,
apply_temperature=False,
)
return sampled, logits
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
class SamplingStates:
def __init__(self, max_num_reqs: int, vocab_size: int):
self.max_num_reqs = max_num_reqs
self.vocab_size = vocab_size
self.temperature = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
self.top_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.min_p = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(max_num_reqs, dtype=torch.int64)
# Initialize top_k and top_p manually because 0 is an invalid value for them.
self.top_k.np.fill(self.vocab_size)
self.top_k.copy_to_uva()
self.top_p.np.fill(1.0)
self.top_p.copy_to_uva()
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(NO_LOGPROBS)
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = NO_LOGPROBS
self.num_logprobs[req_idx] = num_logprobs
def apply_staged_writes(self) -> None:
self.temperature.copy_to_uva()
self.top_p.copy_to_uva()
self.top_k.copy_to_uva()
self.min_p.copy_to_uva()
self.seeds.copy_to_uva()
def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(self.min_p.np[idx_mapping_np] != 0.0)
def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(self.top_p.np[idx_mapping_np] != 1.0)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
return int(np.max(self.num_logprobs[idx_mapping_np]))
......@@ -17,7 +17,6 @@ from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
from vllm.v1.worker.gpu.block_table import BlockTables
from vllm.v1.worker.gpu.input_batch import InputBatch, InputBuffers
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.spec_decode.eagle_cudagraph import EagleCudaGraphManager
logger = init_logger(__name__)
......@@ -188,7 +187,6 @@ class EagleSpeculator:
def propose(
self,
input_batch: InputBatch,
sampling_metadata: SamplingMetadata,
# [num_tokens, hidden_size]
last_hidden_states: torch.Tensor,
# num_layers x [num_tokens, hidden_size]
......@@ -201,6 +199,10 @@ class EagleSpeculator:
last_sampled: torch.Tensor,
# [num_reqs]
next_prefill_tokens: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
) -> torch.Tensor:
# NOTE(woosuk): To avoid CPU-GPU synchronization without CPU knowing the
# number of rejected tokens, we maintain the size of eagle's input_ids and
......@@ -246,8 +248,8 @@ class EagleSpeculator:
# affect the output distribution after rejection sampling.
idx_mapping = self.idx_mapping[:num_reqs]
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(sampling_metadata.temperature)
self.seeds.copy_(sampling_metadata.seeds)
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)
# Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
......
......@@ -7,14 +7,9 @@ import torch
from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.utils.math_utils import cdiv
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu.sample.penalties import bincount
_NP_INT64_MIN = np.iinfo(np.int64).min
_NP_INT64_MAX = np.iinfo(np.int64).max
NO_LORA_ID = 0
......@@ -81,38 +76,8 @@ class RequestState:
self.lora_ids = np.zeros(self.max_num_reqs, dtype=np.int32)
self.lora_ids.fill(NO_LORA_ID)
# Sampling parameters.
self.temperature = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.top_k = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.min_p = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.repetition_penalty = UvaBackedTensor(
self.max_num_reqs, dtype=torch.float32
)
self.frequency_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(self.max_num_reqs, dtype=torch.float32)
self.seeds = UvaBackedTensor(self.max_num_reqs, dtype=torch.int64)
self.num_logprobs = np.empty(self.max_num_reqs, dtype=np.int32)
# -1 means no logprobs are requested.
self.num_logprobs.fill(-1)
self.needs_prompt_logprobs = np.zeros(self.max_num_reqs, dtype=bool)
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
self.max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
)
# TODO(woosuk): This tensor is rarely used but can be extremely large.
# Optimize the memory usage.
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self._penalties_reqs: list[int] = []
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
......@@ -147,33 +112,6 @@ class RequestState:
else:
self.lora_ids[req_idx] = NO_LORA_ID
self.temperature.np[req_idx] = sampling_params.temperature
self.top_p.np[req_idx] = sampling_params.top_p
if 0 < sampling_params.top_k < self.vocab_size:
top_k = sampling_params.top_k
else:
top_k = self.vocab_size
self.top_k.np[req_idx] = top_k
self.min_p.np[req_idx] = sampling_params.min_p
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
self._penalties_reqs.append(req_idx)
if sampling_params.seed is not None:
seed = sampling_params.seed
else:
seed = np.random.randint(_NP_INT64_MIN, _NP_INT64_MAX)
self.seeds.np[req_idx] = seed
if sampling_params.logprobs is not None:
num_logprobs = sampling_params.logprobs
else:
num_logprobs = -1
self.num_logprobs[req_idx] = num_logprobs
# For now, only support prompt logprobs for the prompt tokens.
needs_prompt_logprobs = sampling_params.prompt_logprobs is not None
self.needs_prompt_logprobs[req_idx] = needs_prompt_logprobs
......@@ -183,17 +121,6 @@ class RequestState:
self.prefill_token_ids.apply_write()
self.num_computed_tokens.apply_write()
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
self.prefill_token_ids.gpu[req_idx],
int(self.prefill_len.np[req_idx]),
int(self.prompt_len[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
)
self._penalties_reqs.clear()
def remove_request(self, req_id: str) -> None:
self.extra_data.pop(req_id, None)
req_idx = self.req_id_to_index.pop(req_id, None)
......@@ -203,53 +130,6 @@ class RequestState:
self.index_to_req_id.pop(req_idx, None)
self.free_indices.append(req_idx)
def make_sampling_metadata(
self,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> SamplingMetadata:
temperature = self.temperature.copy_to_uva()
top_p = self.top_p.np[idx_mapping_np]
no_top_p = np.all(top_p == 1.0)
top_p = self.top_p.copy_to_uva()[idx_mapping] if not no_top_p else None
top_k = self.top_k.np[idx_mapping_np]
no_top_k = np.all(top_k == self.vocab_size)
top_k = self.top_k.copy_to_uva()[idx_mapping] if not no_top_k else None
min_p = self.min_p.np[idx_mapping_np]
no_min_p = np.all(min_p == 0.0)
min_p = self.min_p.copy_to_uva() if not no_min_p else None
rep_penalty = self.repetition_penalty.copy_to_uva()
freq_penalty = self.frequency_penalty.copy_to_uva()
pres_penalty = self.presence_penalty.copy_to_uva()
seeds = self.seeds.copy_to_uva()
num_logprobs = self.num_logprobs[idx_mapping_np]
max_num_logprobs: int | None = int(np.max(num_logprobs))
if max_num_logprobs == -1:
max_num_logprobs = None
return SamplingMetadata(
idx_mapping=idx_mapping,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
repetition_penalty=rep_penalty,
frequency_penalty=freq_penalty,
presence_penalty=pres_penalty,
prompt_bin_mask=self.prompt_bin_mask,
output_bin_counts=self.output_bin_counts,
seeds=seeds,
pos=pos,
max_num_logprobs=max_num_logprobs,
)
def make_lora_inputs(
self,
req_ids: list[str],
......@@ -272,11 +152,3 @@ class RequestState:
class ExtraData:
lora_request: LoRARequest | None
in_progress_prompt_logprobs: list[LogprobsTensors] = field(default_factory=list)
def use_penalty(sampling_params: SamplingParams) -> bool:
return (
sampling_params.repetition_penalty != 1.0
or sampling_params.frequency_penalty != 0.0
or sampling_params.presence_penalty != 0.0
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment