Unverified commit 27feead2, authored by Woosuk Kwon, committed by GitHub

Refactor Worker & InputMetadata (#1843)

parent c7821956
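
In outline, this refactor splits the old all-purpose `InputMetadata` into a slimmer, attention-only `InputMetadata` plus a new `SamplingMetadata`, and splits each model's `forward` (which now returns raw hidden states) from a new `sample` method. The sketch below (not part of the original commit message) shows the call pattern implied by the per-model changes; `model`, the input tensors, and both metadata objects are placeholders assumed to be prepared by the worker / model runner, which is not shown in this excerpt:

```python
# Sketch only: the call pattern after this refactor.
# `model`, `input_ids`, `positions`, `kv_caches`, `input_metadata`, and
# `sampling_metadata` are assumed to be built elsewhere (e.g. by the worker).
hidden_states = model(
    input_ids,        # token ids for the whole batch
    positions,        # position ids
    kv_caches,        # per-layer (key_cache, value_cache) pairs
    input_metadata,   # attention-side metadata: slot mapping, block tables, ...
    cache_events,     # optional CUDA events, may be None
)

# Sampling is now a separate step driven by SamplingMetadata.
sampler_output = model.sample(hidden_states, sampling_metadata)
```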
......@@ -161,6 +161,12 @@ class ModelConfig:
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
def get_sliding_window(self) -> Optional[int]:
return getattr(self.hf_config, "sliding_window", None)
def get_vocab_size(self) -> int:
return self.hf_config.vocab_size
def get_hidden_size(self) -> int:
return self.hf_config.hidden_size
......
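The new accessors above give callers a single place to read these values instead of reaching into `hf_config` directly. A small illustrative usage, assuming an already constructed `ModelConfig` (it mirrors the `EngineArgs` change in the next hunk):

```python
# Illustrative only; `model_config` is an existing vllm.config.ModelConfig.
sliding_window = model_config.get_sliding_window()  # None if the model has no sliding window
vocab_size = model_config.get_vocab_size()
hidden_size = model_config.get_hidden_size()
```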
......@@ -201,9 +201,10 @@ class EngineArgs:
self.dtype, self.seed, self.revision,
self.tokenizer_revision, self.max_model_len,
self.quantization)
-        cache_config = CacheConfig(
-            self.block_size, self.gpu_memory_utilization, self.swap_space,
-            getattr(model_config.hf_config, 'sliding_window', None))
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space,
+                                   model_config.get_sliding_window())
parallel_config = ParallelConfig(self.pipeline_parallel_size,
self.tensor_parallel_size,
self.worker_use_ray,
......
......@@ -88,8 +88,6 @@ class LLMEngine:
self.model_config = model_config
self.cache_config = cache_config
-        assert self.cache_config.sliding_window == getattr(
-            self.model_config.hf_config, "sliding_window", None)
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
......
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
__all__ = [
"InputMetadata",
"get_model",
"SamplingMetadata",
"set_random_seed",
]
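
With `SamplingMetadata` added to the package exports above, both metadata types can be imported from the same place; a trivial usage sketch:

```python
# The public surface of vllm.model_executor after this change.
from vllm.model_executor import (InputMetadata, SamplingMetadata, get_model,
                                 set_random_seed)
```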
......
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional

import torch
-from xformers.ops import AttentionBias
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData


class InputMetadata:
-    """Metadata for input sequences. Used for PagedAttention.
+    """Metadata for input sequences. Used in PagedAttention.

    Args:
-        seq_groups: List of (seq_ids, sampling_params).
-        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        slot_mapping: The address to write the new KV to of each token.
-        context_lens: the length of attention context for each generation token.
        max_context_len: The maximum context length.
+        context_lens: the length of attention context for each sequence.
        block_tables: The block tables. (Seq id -> list of physical block)
    """

    def __init__(
        self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        slot_mapping: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_context_len: int,
-        block_tables: torch.Tensor,
-        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
-        sliding_window: Optional[int] = None,
+        max_context_len: Optional[int],
+        context_lens: Optional[torch.Tensor],
+        block_tables: Optional[torch.Tensor],
    ) -> None:
-        self.seq_groups = seq_groups
-        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
+        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
-        self.max_context_len = max_context_len
        self.block_tables = block_tables
-        self.selected_token_indices = selected_token_indices
-        self.categorized_sample_indices = categorized_sample_indices
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
-        self.to_cache = None
-        if sliding_window is not None:
-            # We need to keep the positions of sliding windows within
-            # the key / value tables, this is helpful to know which
-            # elements we need to cache.
-            to_cache, start_idx = [], 0
-            for prompt_len in self.prompt_lens:
-                to_cache.extend(
-                    range(
-                        start_idx + max(0, prompt_len - sliding_window),
-                        start_idx + prompt_len,
-                    ))
-                start_idx += self.max_prompt_len
-            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
-            self.to_cache = torch.tensor(to_cache,
-                                         dtype=torch.int32,
-                                         device=self.slot_mapping.device)
-        self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
-        self.num_generation_tokens = context_lens.shape[0]
-        if block_tables.numel() > 0:
-            self.max_num_blocks_per_seq = block_tables.shape[1]
-        else:
-            self.max_num_blocks_per_seq = 0
-        assert block_tables.shape[0] == self.num_generation_tokens

+        self.is_prompt = len(prompt_lens) > 0
        # Set during the execution of the first attention op.
-        self.attn_bias: Optional[AttentionBias] = None
+        # FIXME(woosuk): This is a hack.
+        self.attn_bias = None

    def __repr__(self) -> str:
-        # Print only useful metadata.
-        return (
-            f'InputMetadata('
-            f'num_prompt_tokens={self.num_prompt_tokens}, '
-            f'num_prompts={self.num_prompts}, '
-            f'prompt_lens={self.prompt_lens}, '
-            f'num_generation_tokens={self.num_generation_tokens}, '
-            f'context_lens={self.context_lens}, '
-            f'max_context_len={self.max_context_len}), '
-            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-            f'block_tables={self.block_tables}, '
-            f'selected_token_indices={self.selected_token_indices}, '
-            f'categorized_sample_indices={self.categorized_sample_indices}, '
-            f'slot_mapping={self.slot_mapping})')
+        return ("InputMetadata("
+                f"prompt_lens={self.prompt_lens}, "
+                f"max_context_len={self.max_context_len}, "
+                f"slot_mapping={self.slot_mapping}, "
+                f"context_lens={self.context_lens}, "
+                f"block_tables={self.block_tables})")
......@@ -101,23 +101,15 @@ class PagedAttention(nn.Module):
# vectors will not be cached. This happens during the initial memory
# profiling run.
        if key_cache is not None and value_cache is not None:
-            key_to_cache = key
-            value_to_cache = value
-            if input_metadata.to_cache is not None:
-                key_to_cache = key_to_cache[input_metadata.to_cache]
-                value_to_cache = value_to_cache[input_metadata.to_cache]
-                slot_mapping = slot_mapping[input_metadata.to_cache]
            cache_ops.reshape_and_cache(
-                key_to_cache,
-                value_to_cache,
+                key,
+                value,
                key_cache,
                value_cache,
                slot_mapping,
            )

-        is_prompt = len(input_metadata.prompt_lens) > 0
-        if is_prompt:
+        if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
......
......@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutput, SequenceOutput)
......@@ -37,29 +37,30 @@ class Sampler(nn.Module):
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> SamplerOutput:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, input_metadata)
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
logits = _apply_logits_processors(logits, sampling_metadata)
# Apply presence and frequency penalties.
presence_penalties, frequency_penalties, repetition_penalties = (
_get_penalties(input_metadata))
_get_penalties(sampling_metadata))
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(logits, input_metadata, presence_penalties,
frequency_penalties, repetition_penalties)
logits = _apply_penalties(logits, sampling_metadata,
presence_penalties, frequency_penalties,
repetition_penalties)
# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
temperatures = _get_temperatures(sampling_metadata)
assert len(temperatures) == logits.shape[0]
if any(t != 1.0 for t in temperatures):
t = torch.tensor(temperatures,
......@@ -70,7 +71,7 @@ class Sampler(nn.Module):
# Apply top-p and top-k truncation.
top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
input_metadata, self.vocab_size)
sampling_metadata, self.vocab_size)
assert len(top_ps) == len(top_ks) == logits.shape[0]
do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
do_top_k = any(k != self.vocab_size for k in top_ks)
......@@ -89,11 +90,11 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
sample_results = _sample(probs, logprobs, input_metadata)
sample_results = _sample(probs, logprobs, sampling_metadata)
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, input_metadata, sample_results)
return _build_sampler_output(sample_results, input_metadata,
logprobs, sampling_metadata, sample_results)
return _build_sampler_output(sample_results, sampling_metadata,
prompt_logprobs, sample_logprobs)
......@@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
def _prune_hidden_states(
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
return hidden_states.index_select(0, input_metadata.selected_token_indices)
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _get_penalties(
input_metadata: InputMetadata
sampling_metadata: SamplingMetadata
) -> Tuple[List[float], List[float], List[float]]:
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
r = sampling_params.repetition_penalty
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: We do not apply presence and frequency penalties for the
# prompt token positions where we don't sample new tokens.
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
presence_penalties += [0] * (prompt_len - 1)
frequency_penalties += [0] * (prompt_len - 1)
repetition_penalties += [1] * (prompt_len - 1)
......@@ -145,21 +147,21 @@ def _get_penalties(
def _get_prompt_and_output_tokens(
input_metadata: InputMetadata
sampling_metadata: SamplingMetadata,
) -> Tuple[List[List[int]], List[List[int]]]:
prompt_tokens: List[List[int]] = []
output_tokens: List[List[int]] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# NOTE: prompt token positions do not need output tokens to
# compute penalties.
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens.extend([] for _ in range(prompt_len - 1))
output_tokens.extend([] for _ in range(prompt_len - 1))
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
seq_data = sampling_metadata.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
return prompt_tokens, output_tokens
......@@ -191,17 +193,19 @@ def _get_bin_counts_and_mask(
return bin_counts, mask
def _apply_logits_processors(logits: torch.Tensor,
input_metadata: InputMetadata) -> torch.Tensor:
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in input_metadata.seq_groups:
for seq_ids, sampling_params in sampling_metadata.seq_groups:
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
logits_row = logits[logits_row_idx]
token_ids = input_metadata.seq_data[seq_id].output_token_ids
token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
......@@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor,
def _apply_penalties(
logits: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
......@@ -234,7 +238,7 @@ def _apply_penalties(
return logits
prompt_tokens, output_tokens = (
_get_prompt_and_output_tokens(input_metadata))
_get_prompt_and_output_tokens(sampling_metadata))
assert len(prompt_tokens) == logits.shape[0]
assert len(output_tokens) == logits.shape[0]
......@@ -265,10 +269,10 @@ def _apply_penalties(
return logits
def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
# Collect the temperatures for the logits.
temperatures: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
if temperature < _SAMPLING_EPS:
......@@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
# (i.e., greedy sampling or beam search).
# Set the temperature to 1 to avoid division by zero.
temperature = 1.0
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
temperatures += [temperature] * len(seq_ids)
return temperatures
def _get_top_p_top_k_min_p(
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
vocab_size: int,
) -> Tuple[List[float], List[int], List[float]]:
top_ps: List[float] = []
top_ks: List[int] = []
min_ps: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
top_p = sampling_params.top_p
min_p = sampling_params.min_p
......@@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p(
top_k = min(sampling_params.top_k, vocab_size)
# k=-1 means no truncation.
top_k = vocab_size if top_k == -1 else top_k
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
prompt_len = input_metadata.prompt_lens[i]
prompt_len = sampling_metadata.prompt_lens[i]
top_ps += [top_p] * (prompt_len - 1)
top_ks += [top_k] * (prompt_len - 1)
min_ps += [min_p] * (prompt_len - 1)
......@@ -471,11 +475,11 @@ def _beam_search_sample(
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = input_metadata.categorized_sample_indices
for i, seq_group in enumerate(input_metadata.seq_groups):
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
......@@ -483,8 +487,8 @@ def _sample(
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
for sampling_type in SamplingType:
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
......@@ -499,21 +503,22 @@ def _sample(
elif sampling_type == SamplingType.BEAM:
category_logprobs = logprobs[sample_indices]
sample_results = _beam_search_sample(seq_groups, is_prompts,
input_metadata.seq_data,
sampling_metadata.seq_data,
category_logprobs)
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results = [
sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
def _get_logprobs(
logprobs: torch.Tensor,
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
sample_results: List[Tuple[List[int], List[int]]],
) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
int, float]]]]:
......@@ -523,16 +528,16 @@ def _get_logprobs(
largest_num_logprobs = 0
sample_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
num_parent_seqs = len(seq_ids)
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
largest_num_logprobs = max(largest_num_logprobs,
sampling_params.prompt_logprobs)
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
batched_logprobs_query_seq_indices.extend(
sample_idx + j for j in range(prompt_len - 1))
......@@ -570,16 +575,16 @@ def _get_logprobs(
sample_idx = 0
query_result_idx = 0
for i, (seq_group, sample_result) in enumerate(
zip(input_metadata.seq_groups, sample_results)):
zip(sampling_metadata.seq_groups, sample_results)):
seq_ids, sampling_params = seq_group
next_token_ids, parent_ids = sample_result
# Prompt logprobs
if (i < input_metadata.num_prompts
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
num_logprobs = sampling_params.prompt_logprobs
prompt_len = input_metadata.prompt_lens[i]
prompt_tokens = input_metadata.seq_data[
prompt_len = sampling_metadata.prompt_lens[i]
prompt_tokens = sampling_metadata.seq_data[
seq_ids[0]].prompt_token_ids
group_prompt_logprobs: PromptLogprobs = [None]
for token_id in prompt_tokens[1:]:
......@@ -625,13 +630,13 @@ def _get_logprobs(
def _build_sampler_output(
sample_results: List[Tuple[List[int], List[int]]],
input_metadata: InputMetadata,
sampling_metadata: SamplingMetadata,
prompt_logprobs: List[Optional[PromptLogprobs]],
sample_logprobs: List[SampleLogprobs],
) -> SamplerOutput:
sampler_output = []
for (seq_group, sample_result, group_prompt_logprobs,
group_sample_logprobs) in zip(input_metadata.seq_groups,
group_sample_logprobs) in zip(sampling_metadata.seq_groups,
sample_results, prompt_logprobs,
sample_logprobs):
seq_ids, _ = seq_group
......
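The sampler changes above only ever read a handful of attributes from `sampling_metadata`. As a rough guide to what moved out of `InputMetadata`, here is a stand-in listing just those fields; the real class lives in vllm/model_executor/sampling_metadata.py and may carry more, and the name `MinimalSamplingMetadata` is ours:

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData


@dataclass
class MinimalSamplingMetadata:
    """Only the attributes the sampler code in this diff actually reads."""
    seq_groups: List[Tuple[List[int], SamplingParams]]  # (seq_ids, sampling_params) per group
    seq_data: Dict[int, SequenceData]                    # seq_id -> SequenceData
    prompt_lens: List[int]
    selected_token_indices: torch.Tensor
    categorized_sample_indices: Dict[SamplingType, torch.Tensor]
    num_prompts: int
```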
......@@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(
input_ids,
positions,
......@@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
input_metadata,
cache_events,
)
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata, self.lm_head.bias)
sampling_metadata, self.lm_head.bias)
return next_tokens
def load_weights(self,
......
......@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.embed_out.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......