from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData


class SamplingMetadata:
    """Per-batch metadata about the input sequences, consumed by the sampler.

    Args:
        seq_groups: One ``(seq_ids, sampling_params)`` pair per sequence
            group.
        seq_data: Mapping from a sequence id to its ``SequenceData``.
        prompt_lens: Length of each prompt.
        selected_token_indices: Indices of the tokens selected for sampling.
        categorized_sample_indices: Mapping from a ``SamplingType`` to the
            token indices to sample with that type.

    Besides the arguments above, the instance exposes ``num_prompts``, which
    is ``len(prompt_lens)`` captured at construction time.
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
    ) -> None:
        # Sequence-level bookkeeping.
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens

        # Index tensors handed over as-is; no copy or validation is done.
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices

        # Cached so callers do not have to recompute len(prompt_lens).
        self.num_prompts = len(prompt_lens)

    def __repr__(self) -> str:
        """Return a string listing the constructor arguments.

        ``num_prompts`` is omitted because it is derived from ``prompt_lens``.
        """
        rendered_fields = (
            f"seq_groups={self.seq_groups}",
            f"seq_data={self.seq_data}",
            f"prompt_lens={self.prompt_lens}",
            f"selected_token_indices={self.selected_token_indices}",
            f"categorized_sample_indices={self.categorized_sample_indices}",
        )
        return "SamplingMetadata(" + ", ".join(rendered_fields) + ")"