Unverified Commit 27feead2 authored by Woosuk Kwon, committed by GitHub

Refactor Worker & InputMetadata (#1843)

parent c7821956
@@ -161,6 +161,12 @@ class ModelConfig:
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+    def get_sliding_window(self) -> Optional[int]:
+        return getattr(self.hf_config, "sliding_window", None)
+
+    def get_vocab_size(self) -> int:
+        return self.hf_config.vocab_size
+
     def get_hidden_size(self) -> int:
         return self.hf_config.hidden_size
...
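The new accessors give callers a single, typed entry point into the underlying Hugging Face config. A minimal usage sketch (assuming an already-constructed ModelConfig named model_config; the surrounding engine code is not part of this diff):

    # Before this change, callers reached into the HF config directly:
    #     sliding_window = getattr(model_config.hf_config, "sliding_window", None)
    # After it, they go through the accessors:
    sliding_window = model_config.get_sliding_window()  # Optional[int]
    vocab_size = model_config.get_vocab_size()          # int
    hidden_size = model_config.get_hidden_size()        # int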
@@ -201,9 +201,10 @@ class EngineArgs:
                                    self.dtype, self.seed, self.revision,
                                    self.tokenizer_revision, self.max_model_len,
                                    self.quantization)
-        cache_config = CacheConfig(
-            self.block_size, self.gpu_memory_utilization, self.swap_space,
-            getattr(model_config.hf_config, 'sliding_window', None))
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space,
+                                   model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray,
...
@@ -88,8 +88,6 @@ class LLMEngine:
         self.model_config = model_config
         self.cache_config = cache_config
-        assert self.cache_config.sliding_window == getattr(
-            self.model_config.hf_config, "sliding_window", None)
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
...
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 
 __all__ = [
     "InputMetadata",
     "get_model",
+    "SamplingMetadata",
     "set_random_seed",
 ]
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional
 
 import torch
-from xformers.ops import AttentionBias
-
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData
 
 
 class InputMetadata:
-    """Metadata for input sequences. Used for PagedAttention.
+    """Metadata for input sequences. Used in PagedAttention.
 
     Args:
-        seq_groups: List of (seq_ids, sampling_params).
-        seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         slot_mapping: The address to write the new KV to of each token.
-        context_lens: the length of attention context for each generation token.
         max_context_len: The maximum context length.
+        context_lens: the length of attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
     """
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_context_len: int,
-        block_tables: torch.Tensor,
-        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
-        sliding_window: Optional[int] = None,
+        max_context_len: Optional[int],
+        context_lens: Optional[torch.Tensor],
+        block_tables: Optional[torch.Tensor],
     ) -> None:
-        self.seq_groups = seq_groups
-        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
-        self.max_context_len = max_context_len
         self.block_tables = block_tables
-        self.selected_token_indices = selected_token_indices
-        self.categorized_sample_indices = categorized_sample_indices
-
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
-        self.to_cache = None
-        if sliding_window is not None:
-            # We need to keep the positions of sliding windows within
-            # the key / value tables, this is helpful to know which
-            # elements we need to cache.
-            to_cache, start_idx = [], 0
-            for prompt_len in self.prompt_lens:
-                to_cache.extend(
-                    range(
-                        start_idx + max(0, prompt_len - sliding_window),
-                        start_idx + prompt_len,
-                    ))
-                start_idx += self.max_prompt_len
-            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
-            self.to_cache = torch.tensor(to_cache,
-                                         dtype=torch.int32,
-                                         device=self.slot_mapping.device)
-
-        self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
-        self.num_generation_tokens = context_lens.shape[0]
-        if block_tables.numel() > 0:
-            self.max_num_blocks_per_seq = block_tables.shape[1]
-        else:
-            self.max_num_blocks_per_seq = 0
-        assert block_tables.shape[0] == self.num_generation_tokens
+
+        self.is_prompt = len(prompt_lens) > 0
 
         # Set during the execution of the first attention op.
-        self.attn_bias: Optional[AttentionBias] = None
+        # FIXME(woosuk): This is a hack.
+        self.attn_bias = None
 
     def __repr__(self) -> str:
-        # Print only useful metadata.
-        return (
-            f'InputMetadata('
-            f'num_prompt_tokens={self.num_prompt_tokens}, '
-            f'num_prompts={self.num_prompts}, '
-            f'prompt_lens={self.prompt_lens}, '
-            f'num_generation_tokens={self.num_generation_tokens}, '
-            f'context_lens={self.context_lens}, '
-            f'max_context_len={self.max_context_len}), '
-            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-            f'block_tables={self.block_tables}, '
-            f'selected_token_indices={self.selected_token_indices}, '
-            f'categorized_sample_indices={self.categorized_sample_indices}, '
-            f'slot_mapping={self.slot_mapping})')
+        return ("InputMetadata("
+                f"prompt_lens={self.prompt_lens}, "
+                f"max_context_len={self.max_context_len}, "
+                f"slot_mapping={self.slot_mapping}, "
+                f"context_lens={self.context_lens}, "
+                f"block_tables={self.block_tables})")
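After this rewrite, InputMetadata carries only what the attention path needs; everything sampling-related moves to the new SamplingMetadata. A minimal construction sketch for a prompt (prefill) batch, with illustrative tensor values rather than ones taken from the commit:

    import torch

    from vllm.model_executor.input_metadata import InputMetadata

    # Two prompts of lengths 3 and 2; slot_mapping assigns each of the five
    # tokens a physical KV-cache slot.
    input_metadata = InputMetadata(
        prompt_lens=[3, 2],
        slot_mapping=torch.tensor([0, 1, 2, 16, 17], dtype=torch.long),
        max_context_len=None,   # only needed for decode (generation) steps
        context_lens=None,      # decode-only as well
        block_tables=None,      # decode-only as well
    )
    assert input_metadata.is_prompt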
@@ -101,23 +101,15 @@ class PagedAttention(nn.Module):
         # vectors will not be cached. This happens during the initial memory
         # profiling run.
         if key_cache is not None and value_cache is not None:
-            key_to_cache = key
-            value_to_cache = value
-            if input_metadata.to_cache is not None:
-                key_to_cache = key_to_cache[input_metadata.to_cache]
-                value_to_cache = value_to_cache[input_metadata.to_cache]
-                slot_mapping = slot_mapping[input_metadata.to_cache]
-
             cache_ops.reshape_and_cache(
-                key_to_cache,
-                value_to_cache,
+                key,
+                value,
                 key_cache,
                 value_cache,
                 slot_mapping,
             )
 
-        is_prompt = len(input_metadata.prompt_lens) > 0
-        if is_prompt:
+        if input_metadata.is_prompt:
             # Prompt run.
             if self.num_kv_heads != self.num_heads:
                 # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
...
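With the sliding-window filtering gone from this path (InputMetadata no longer carries to_cache), every token's key and value are written to the cache, and the prompt-vs-decode branch uses the precomputed is_prompt flag instead of re-deriving it from prompt_lens. As a rough mental model of the cache write, here is a pure-PyTorch stand-in for the cache_ops.reshape_and_cache kernel; it assumes flat [num_slots, num_heads, head_size] caches, whereas the real kernel scatters into vLLM's blocked cache layout:

    import torch


    def reshape_and_cache_ref(key: torch.Tensor, value: torch.Tensor,
                              key_cache: torch.Tensor,
                              value_cache: torch.Tensor,
                              slot_mapping: torch.Tensor) -> None:
        # key/value: [num_tokens, num_heads, head_size]
        # key_cache/value_cache: [num_slots, num_heads, head_size] (simplified)
        # slot_mapping: [num_tokens], the physical slot for each token's KV.
        key_cache[slot_mapping] = key
        value_cache[slot_mapping] = value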
@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
@@ -37,29 +37,30 @@ class Sampler(nn.Module):
         self,
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> SamplerOutput:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, input_metadata)
+        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
 
         # Get the logits for the next tokens.
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
 
         # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, input_metadata)
+        logits = _apply_logits_processors(logits, sampling_metadata)
         # Apply presence and frequency penalties.
         presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(input_metadata))
+            _get_penalties(sampling_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, input_metadata, presence_penalties,
-                                  frequency_penalties, repetition_penalties)
+        logits = _apply_penalties(logits, sampling_metadata,
+                                  presence_penalties, frequency_penalties,
+                                  repetition_penalties)
 
         # Apply temperature scaling.
-        temperatures = _get_temperatures(input_metadata)
+        temperatures = _get_temperatures(sampling_metadata)
         assert len(temperatures) == logits.shape[0]
         if any(t != 1.0 for t in temperatures):
             t = torch.tensor(temperatures,
@@ -70,7 +71,7 @@ class Sampler(nn.Module):
 
         # Apply top-p and top-k truncation.
         top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            input_metadata, self.vocab_size)
+            sampling_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
@@ -89,11 +90,11 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, input_metadata)
+        sample_results = _sample(probs, logprobs, sampling_metadata)
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, input_metadata, sample_results)
-        return _build_sampler_output(sample_results, input_metadata,
+            logprobs, sampling_metadata, sample_results)
+        return _build_sampler_output(sample_results, sampling_metadata,
                                      prompt_logprobs, sample_logprobs)
@@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
 
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, input_metadata.selected_token_indices)
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
 
 
 def _get_penalties(
-    input_metadata: InputMetadata
+    sampling_metadata: SamplingMetadata
 ) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
     repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
         r = sampling_params.repetition_penalty
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
             # prompt token positions where we don't sample new tokens.
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
             repetition_penalties += [1] * (prompt_len - 1)
@@ -145,21 +147,21 @@ def _get_penalties(
 
 
 def _get_prompt_and_output_tokens(
-    input_metadata: InputMetadata
+    sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:
     prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
-            seq_data = input_metadata.seq_data[seq_id]
+            seq_data = sampling_metadata.seq_data[seq_id]
             prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
     return prompt_tokens, output_tokens
@@ -191,17 +193,19 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
-def _apply_logits_processors(logits: torch.Tensor,
-                             input_metadata: InputMetadata) -> torch.Tensor:
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
     logits_row_idx = 0
     found_logits_processors = False
-    for seq_ids, sampling_params in input_metadata.seq_groups:
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
         logits_processors = sampling_params.logits_processors
         if logits_processors:
             found_logits_processors = True
             for seq_id in seq_ids:
                 logits_row = logits[logits_row_idx]
-                token_ids = input_metadata.seq_data[seq_id].output_token_ids
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                 for logits_processor in logits_processors:
                     logits_row = logits_processor(token_ids, logits_row)
                 logits[logits_row_idx] = logits_row
@@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor,
 
 def _apply_penalties(
     logits: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],
@@ -234,7 +238,7 @@ def _apply_penalties(
         return logits
 
     prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(input_metadata))
+        _get_prompt_and_output_tokens(sampling_metadata))
     assert len(prompt_tokens) == logits.shape[0]
     assert len(output_tokens) == logits.shape[0]
@@ -265,10 +269,10 @@ def _apply_penalties(
     return logits
 
 
-def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
+def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:
@@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             temperatures += [temperature] * (prompt_len - 1)
         temperatures += [temperature] * len(seq_ids)
     return temperatures
 
 
 def _get_top_p_top_k_min_p(
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     vocab_size: int,
 ) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
     min_ps: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         min_p = sampling_params.min_p
@@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p(
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
             min_ps += [min_p] * (prompt_len - 1)
@@ -471,11 +475,11 @@ def _beam_search_sample(
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = input_metadata.categorized_sample_indices
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    categorized_sample_indices = sampling_metadata.categorized_sample_indices
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)
@@ -483,8 +487,8 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
@@ -499,21 +503,22 @@ def _sample(
         elif sampling_type == SamplingType.BEAM:
             category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 input_metadata.seq_data,
+                                                 sampling_metadata.seq_data,
                                                  category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
 
         sample_results_dict.update(zip(seq_group_ids, sample_results))
 
     sample_results = [
-        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
+        sample_results_dict[i]
+        for i in range(len(sampling_metadata.seq_groups))
     ]
     return sample_results
 
 
 def _get_logprobs(
     logprobs: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
 ) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
         int, float]]]]:
@@ -523,16 +528,16 @@ def _get_logprobs(
     largest_num_logprobs = 0
     sample_idx = 0
     for i, (seq_group, sample_result) in enumerate(
-            zip(input_metadata.seq_groups, sample_results)):
+            zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
         num_parent_seqs = len(seq_ids)
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.prompt_logprobs)
-            prompt_len = input_metadata.prompt_lens[i]
-            prompt_tokens = input_metadata.seq_data[
+            prompt_len = sampling_metadata.prompt_lens[i]
+            prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             batched_logprobs_query_seq_indices.extend(
                 sample_idx + j for j in range(prompt_len - 1))
@@ -570,16 +575,16 @@ def _get_logprobs(
     sample_idx = 0
     query_result_idx = 0
     for i, (seq_group, sample_result) in enumerate(
-            zip(input_metadata.seq_groups, sample_results)):
+            zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
 
         # Prompt logprobs
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = input_metadata.prompt_lens[i]
-            prompt_tokens = input_metadata.seq_data[
+            prompt_len = sampling_metadata.prompt_lens[i]
+            prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
             for token_id in prompt_tokens[1:]:
@@ -625,13 +630,13 @@ def _get_logprobs(
 
 def _build_sampler_output(
     sample_results: List[Tuple[List[int], List[int]]],
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
 ) -> SamplerOutput:
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
-         group_sample_logprobs) in zip(input_metadata.seq_groups,
+         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                        sample_results, prompt_logprobs,
                                        sample_logprobs):
         seq_ids, _ = seq_group
...
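The sampler now reads per-request sampling state from a SamplingMetadata object (imported from vllm.model_executor.sampling_metadata) instead of from InputMetadata. That new module is not included in this excerpt; the skeleton below is inferred purely from the attributes the refactored sampler accesses above, so treat it as a sketch rather than the actual definition:

    # Sketch only: fields inferred from what the refactored sampler reads.
    from typing import Dict, List, Tuple

    import torch

    from vllm.sampling_params import SamplingParams, SamplingType
    from vllm.sequence import SequenceData


    class SamplingMetadata:

        def __init__(
            self,
            seq_groups: List[Tuple[List[int], SamplingParams]],
            seq_data: Dict[int, SequenceData],
            prompt_lens: List[int],
            selected_token_indices: torch.Tensor,
            categorized_sample_indices: Dict[SamplingType, torch.Tensor],
        ) -> None:
            self.seq_groups = seq_groups
            self.seq_data = seq_data
            self.prompt_lens = prompt_lens
            self.selected_token_indices = selected_token_indices
            self.categorized_sample_indices = categorized_sample_indices
            # The sampler treats the first num_prompts groups as prefill
            # requests; deriving it from prompt_lens mirrors the old
            # InputMetadata behavior.
            self.num_prompts = len(prompt_lens)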
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
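Every model class in this commit gets the same split: forward() now stops at the hidden states, and a new sample() method wraps the existing Sampler call with the new SamplingMetadata. A minimal sketch of the resulting two-step call pattern on the caller side (the Worker changes themselves are not part of this excerpt, so the exact call site is an assumption):

    # Step 1: run the model to get hidden states for all scheduled tokens.
    hidden_states = model(input_ids, positions, kv_caches, input_metadata,
                          cache_events)
    # Step 2: sample the next tokens from those hidden states using the
    # per-request sampling state, which now travels separately.
    output = model.sample(hidden_states, sampling_metadata)

The same pattern repeats below for BaiChuan, BLOOM, ChatGLM, Falcon, GPT-2, GPT-BigCode, GPT-J, GPT-NeoX, InternLM, LLaMA, Mistral, and MPT.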
@@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
             positions,
@@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
             input_metadata,
             cache_events,
         )
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata, self.lm_head.bias)
+                                   sampling_metadata, self.lm_head.bias)
         return next_tokens
 
     def load_weights(self,
...
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                       input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
        return next_tokens
 
     def load_weights(self,
...