Unverified Commit 27feead2 authored by Woosuk Kwon, committed by GitHub

Refactor Worker & InputMetadata (#1843)

parent c7821956
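At a high level, this refactor splits each model's single forward-and-sample pass into two calls: forward() now returns raw hidden states, and a new sample() method consumes a SamplingMetadata object built by the new ModelRunner, which also takes over input preparation from Worker._prepare_inputs. The toy sketch below illustrates that two-step interface; ToyModel and ToySamplingMetadata are hypothetical stand-ins, not vLLM classes, and a greedy argmax stands in for the real Sampler.

# Toy sketch of the two-step interface introduced by this refactor.
# ToyModel / ToySamplingMetadata are illustrative stand-ins, not vLLM code.
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class ToySamplingMetadata:
    # Indices (into the flattened, padded batch) of the hidden states that
    # are actually sampled, mirroring selected_token_indices below.
    selected_token_indices: torch.Tensor


class ToyModel(nn.Module):

    def __init__(self, vocab_size: int = 32, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Step 1: the model only returns hidden states; no sampling here.
        return self.embed(input_ids)

    def sample(self, hidden_states: torch.Tensor,
               sampling_metadata: ToySamplingMetadata) -> torch.Tensor:
        # Step 2: sampling is a separate call driven by the metadata object.
        flat = hidden_states.view(-1, hidden_states.shape[-1])
        logits = self.lm_head(flat[sampling_metadata.selected_token_indices])
        return torch.argmax(logits, dim=-1)


model = ToyModel()
input_ids = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])  # two padded prompts
hidden_states = model(input_ids)
metadata = ToySamplingMetadata(
    selected_token_indices=torch.tensor([2, 5]))  # last real token of each
next_tokens = model.sample(hidden_states, metadata)
print(next_tokens)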
@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -308,11 +309,18 @@ class OPTForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
@@ -54,6 +54,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -210,28 +211,6 @@ class PhiLayer(nn.Module):
return hidden_states
class PhiCausalLMHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
):
hidden_states = self.ln(hidden_states)
next_tokens = self.sampler(self.linear.weight, hidden_states,
input_metadata, self.linear.bias)
return next_tokens
class PhiModel(nn.Module):
def __init__(self,
@@ -253,7 +232,7 @@ class PhiModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
for i in range(self.config.num_hidden_layers):
cache_event = None if cache_events is None else cache_events[i]
@@ -268,6 +247,17 @@ class PhiModel(nn.Module):
return hidden_states
class PhiCausalLMHead(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
class PhiForCausalLM(nn.Module):
def __init__(self,
@@ -279,6 +269,7 @@ class PhiForCausalLM(nn.Module):
self.transformer = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
@@ -287,11 +278,21 @@ class PhiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
lm_logits = self.lm_head(hidden_states, input_metadata)
return lm_logits
hidden_states = self.lm_head.ln(hidden_states)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
head = self.lm_head.linear
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens
def load_weights(self,
model_name_or_path: str,
......
@@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -246,11 +247,18 @@ class QWenLMHeadModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
@@ -284,11 +285,18 @@ class YiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata, cache_events)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> SamplerOutput:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
input_metadata)
sampling_metadata)
return next_tokens
def load_weights(self,
......
from typing import Dict, List, Tuple
import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
class SamplingMetadata:
"""Metadata for input sequences. Used in sampler.
Args:
seq_groups: List of (seq_ids, sampling_params).
seq_data: Seq_id -> SequenceData.
prompt_lens: Lengths of prompts.
selected_token_indices: Token indices selected for sampling.
categorized_sample_indices: SamplingType -> token indices to sample.
"""
def __init__(
self,
seq_groups: List[Tuple[List[int], SamplingParams]],
seq_data: Dict[int, SequenceData],
prompt_lens: List[int],
selected_token_indices: torch.Tensor,
categorized_sample_indices: Dict[SamplingType, torch.Tensor],
) -> None:
self.seq_groups = seq_groups
self.seq_data = seq_data
self.prompt_lens = prompt_lens
self.selected_token_indices = selected_token_indices
self.categorized_sample_indices = categorized_sample_indices
self.num_prompts = len(prompt_lens)
def __repr__(self) -> str:
return (
"SamplingMetadata("
f"seq_groups={self.seq_groups}, "
f"seq_data={self.seq_data}, "
f"prompt_lens={self.prompt_lens}, "
f"selected_token_indices={self.selected_token_indices}, "
f"categorized_sample_indices={self.categorized_sample_indices})")
from typing import Dict, List, Optional, Tuple
import torch
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
logger = init_logger(__name__)
_PAD_SLOT_ID = -1
class ModelRunner:
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
):
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.sliding_window = model_config.get_sliding_window()
self.model = None
self.block_size = None # Set after initial profiling.
def load_model(self) -> None:
self.model = get_model(self.model_config)
def set_block_size(self, block_size: int) -> None:
self.block_size = block_size
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
assert seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
assert len(seq_ids) == 1
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(prompt_len)))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([_PAD_SLOT_ID] * prompt_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx = 0
if self.sliding_window is not None:
start_idx = max(0, prompt_len - self.sliding_window)
for i in range(prompt_len):
if i < start_idx:
slot_mapping[-1].append(_PAD_SLOT_ID)
continue
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
max_prompt_len = max(prompt_lens)
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
input_metadata = InputMetadata(
prompt_lens=prompt_lens,
slot_mapping=slot_mapping,
max_context_len=None,
context_lens=None,
block_tables=None,
)
return input_tokens, input_positions, input_metadata
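The sliding-window slot-mapping example quoted in the comment inside _prepare_prompt can be reproduced with a few lines of standalone arithmetic. The block table [0, 1, 0] is an assumed illustration of block reuse under the sliding window, not a value taken from the diff.

# Reproduce the slot-mapping example from the comment above:
# prompt_len=10, sliding_window=8, block_size=4.
PAD_SLOT_ID = -1
prompt_len, sliding_window, block_size = 10, 8, 4
block_table = [0, 1, 0]  # assumed block reuse under the sliding window

start_idx = max(0, prompt_len - sliding_window)
slot_mapping = []
for i in range(prompt_len):
    if i < start_idx:
        slot_mapping.append(PAD_SLOT_ID)
        continue
    block_number = block_table[i // block_size]
    slot_mapping.append(block_number * block_size + i % block_size)

print(slot_mapping)  # [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]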
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []
block_tables: List[List[int]] = []
for seq_group_metadata in seq_group_metadata_list:
assert not seq_group_metadata.is_prompt
seq_ids = list(seq_group_metadata.seq_data.keys())
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
context_len = seq_data.get_len()
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
context_lens.append(context_len)
position = context_len - 1
input_positions.append([position])
block_table = seq_group_metadata.block_tables[seq_id]
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
block_tables.append(block_table)
input_tokens = _make_tensor_with_pad(input_tokens,
max_len=1,
pad=0,
dtype=torch.long)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long)
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
max_block_table_len = max([len(t) for t in block_tables])
block_tables = _make_tensor_with_pad(block_tables,
max_len=max_block_table_len,
pad=0,
dtype=torch.int)
input_metadata = InputMetadata(
prompt_lens=[],
slot_mapping=slot_mapping,
max_context_len=max_context_len,
context_lens=context_lens,
block_tables=block_tables,
)
return input_tokens, input_positions, input_metadata
def _prepare_sample(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> SamplingMetadata:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
max_prompt_len = max(prompt_lens) if prompt_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
if seq_group_metadata.is_prompt:
assert len(seq_ids) == 1
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need to be sampled; skip them.
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_prompt_len
else:
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[
sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device="cuda")
categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
for t, seq_ids in categorized_sample_indices.items()
}
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
sampling_metadata = SamplingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
)
return sampling_metadata
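To make the index bookkeeping in _prepare_sample concrete, the standalone sketch below re-derives selected_token_indices and the per-type sample indices for a prompt-only batch of lengths 3 and 5 with no prompt_logprobs (plain Python, not part of ModelRunner):

# Illustrative re-derivation of the prompt-branch index math above.
prompt_lens = [3, 5]
max_prompt_len = max(prompt_lens)   # prompts are padded to this length

selected_token_indices = []         # rows of the flattened, padded hidden states
categorized_sample_indices = []     # positions among the selected rows
selected_token_start_idx = 0
categorized_start_idx = 0
for prompt_len in prompt_lens:
    # Only the last prompt token of each sequence is sampled.
    selected_token_indices.append(selected_token_start_idx + prompt_len - 1)
    selected_token_start_idx += max_prompt_len
    categorized_sample_indices.append(categorized_start_idx)
    categorized_start_idx += 1

print(selected_token_indices)       # [2, 9]
print(categorized_sample_indices)   # [0, 1]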
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
cache_events: Optional[List[torch.cuda.Event]] = None,
) -> SamplerOutput:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# Prepare input tensors.
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
else:
inputs = self._prepare_decode(seq_group_metadata_list)
input_tokens, input_positions, input_metadata = inputs
sampling_metadata = self._prepare_sample(seq_group_metadata_list,
input_metadata.prompt_lens)
# Execute the model.
hidden_states = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
cache_events=cache_events,
)
# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
return output
@torch.inference_mode()
def profile_run(self) -> None:
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model_config.get_vocab_size()
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs: List[SequenceGroupMetadata] = []
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
)
seqs.append(seq)
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [(None, None)] * num_layers
self.execute_model(seqs, kv_caches)
return
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
assert len(x) <= max_len
return x + [pad] * (max_len - len(x))
def _make_tensor_with_pad(
x: List[List[int]],
max_len: int,
pad: int,
dtype: torch.dtype,
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x, dtype=dtype, device="cuda")
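As a quick check of the padding helpers above, the CPU-only equivalent below shows the resulting tensor layout; the only assumed difference from _make_tensor_with_pad is that the tensor stays on the CPU so the snippet runs without a GPU.

# CPU-only equivalent of the padding helpers, for illustration.
from typing import List

import torch


def pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
    assert len(x) <= max_len
    return x + [pad] * (max_len - len(x))


rows = [[7, 8, 9], [5]]
padded = torch.tensor([pad_to_max(r, max_len=3, pad=0) for r in rows],
                      dtype=torch.long)
print(padded)
# tensor([[7, 8, 9],
#         [5, 0, 0]])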
"""A GPU worker class."""
import os
from typing import Dict, List, Tuple, Optional
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from vllm.model_executor import get_model, InputMetadata, set_random_seed
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
from vllm.utils import get_gpu_memory
@@ -38,11 +38,11 @@ class Worker:
self.rank = rank
self.distributed_init_method = distributed_init_method
self.model_runner = ModelRunner(model_config, parallel_config,
scheduler_config)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.block_size = None
self.sliding_window = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
@@ -69,7 +69,7 @@ class Worker:
set_random_seed(self.model_config.seed)
def load_model(self):
self.model = get_model(self.model_config)
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
@@ -83,40 +83,9 @@ class Worker:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size = self.model.config.vocab_size
sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
max_num_seqs = self.scheduler_config.max_num_seqs
seqs = []
for group_id in range(max_num_seqs):
seq_len = (max_num_batched_tokens // max_num_seqs +
(group_id < max_num_batched_tokens % max_num_seqs))
seq_data = SequenceData([0] * seq_len)
seq = SequenceGroupMetadata(
request_id=str(group_id),
is_prompt=True,
seq_data={group_id: seq_data},
sampling_params=sampling_params,
block_tables=None,
)
seqs.append(seq)
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seqs)
# Execute the model.
num_layers = self.model_config.get_num_layers(self.parallel_config)
self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=[(None, None)] * num_layers,
input_metadata=input_metadata,
cache_events=None,
)
# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@@ -140,197 +109,11 @@ class Worker:
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.block_size = cache_config.block_size
self.sliding_window = cache_config.sliding_window
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.cache_events = self.cache_engine.events
self.gpu_cache = self.cache_engine.gpu_cache
def _prepare_inputs(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
seq_groups: List[Tuple[List[int], SamplingParams]] = []
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
slot_mapping: List[List[int]] = []
selected_token_indices: List[int] = []
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
# Add prompt tokens.
prompt_lens: List[int] = []
for seq_group_metadata in seq_group_metadata_list:
if not seq_group_metadata.is_prompt:
continue
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
# Use any sequence in the group.
seq_id = seq_ids[0]
seq_data = seq_group_metadata.seq_data[seq_id]
prompt_tokens = seq_data.get_token_ids()
prompt_len = len(prompt_tokens)
prompt_lens.append(prompt_len)
if sampling_params.prompt_logprobs is not None:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx += prompt_len - 1
categorized_sample_indices[sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
categorized_sample_indices_start_idx += 1
input_tokens.append(prompt_tokens)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions.append(list(range(prompt_len)))
if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping.append([0] * prompt_len)
continue
# Compute the slot mapping.
slot_mapping.append([])
block_table = seq_group_metadata.block_tables[seq_id]
for i in range(prompt_len):
block_number = block_table[i // self.block_size]
block_offset = i % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping[-1].append(slot)
# Add generation tokens.
max_context_len = 0
max_num_blocks_per_seq = 0
context_lens: List[int] = []
generation_block_tables: List[List[int]] = []
max_seq_len = max(prompt_lens) if prompt_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
if seq_group_metadata.is_prompt:
# We need to do this in this loop as we need to know max_seq_len
assert len(
seq_ids) == 1, "Prompt input should have only one seq."
sampling_params = seq_group_metadata.sampling_params
assert len(prompt_lens) == len(seq_group_metadata_list)
prompt_len = prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + prompt_len - 1))
selected_token_indices.append(selected_token_start_idx +
prompt_len - 1)
selected_token_start_idx += max_seq_len
continue
seq_ids = list(seq_group_metadata.seq_data.keys())
sampling_params = seq_group_metadata.sampling_params
seq_groups.append((seq_ids, sampling_params))
num_seqs = len(seq_ids)
selected_token_indices.extend(
range(selected_token_start_idx,
selected_token_start_idx + num_seqs))
selected_token_start_idx += num_seqs
categorized_sample_indices[sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
categorized_sample_indices_start_idx += num_seqs
for seq_id in seq_ids:
seq_data = seq_group_metadata.seq_data[seq_id]
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])
context_len = seq_data.get_len()
position = context_len - 1
if self.sliding_window is not None:
context_len = min(context_len, self.sliding_window)
input_positions.append([position])
block_table = seq_group_metadata.block_tables[seq_id]
max_context_len = max(max_context_len, context_len)
max_num_blocks_per_seq = max(max_num_blocks_per_seq,
len(block_table))
context_lens.append(context_len)
block_number = block_table[position // self.block_size]
block_offset = position % self.block_size
slot = block_number * self.block_size + block_offset
slot_mapping.append([slot])
if self.sliding_window is not None:
sliding_window_blocks = (self.sliding_window //
self.block_size)
block_table = block_table[-sliding_window_blocks:]
generation_block_tables.append(block_table)
padded_input_tokens = [
_pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
]
padded_input_positions = [
_pad_to_max(positions, max_seq_len, pad=0)
for positions in input_positions
]
padded_slot_mapping = [
_pad_to_max(mapping, max_seq_len, pad=-1)
for mapping in slot_mapping
]
padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
for block_table in generation_block_tables
]
# Convert to tensors.
tokens_tensor = torch.tensor(padded_input_tokens,
dtype=torch.long,
device="cuda")
positions_tensor = torch.tensor(padded_input_positions,
dtype=torch.long,
device="cuda")
slot_mapping_tensor = torch.tensor(padded_slot_mapping,
dtype=torch.long,
device="cuda")
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
selected_token_indices = torch.tensor(selected_token_indices,
dtype=torch.long,
device="cuda")
categorized_sample_indices = {
t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
for t, seq_ids in categorized_sample_indices.items()
}
block_tables_tensor = torch.tensor(padded_block_tables,
dtype=torch.int,
device="cuda")
seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)
input_metadata = InputMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
slot_mapping=slot_mapping_tensor,
context_lens=context_lens_tensor,
max_context_len=max_context_len,
block_tables=block_tables_tensor,
selected_token_indices=selected_token_indices,
categorized_sample_indices=categorized_sample_indices,
sliding_window=self.sliding_window,
)
return tokens_tensor, positions_tensor, input_metadata
self.model_runner.set_block_size(self.cache_engine.block_size)
@torch.inference_mode()
def execute_model(
@@ -361,18 +144,8 @@ class Worker:
event.wait()
return {}
# Prepare input tensors.
input_tokens, input_positions, input_metadata = self._prepare_inputs(
seq_group_metadata_list)
# Execute the model.
output = self.model(
input_ids=input_tokens,
positions=input_positions,
kv_caches=self.gpu_cache,
input_metadata=input_metadata,
cache_events=cache_events,
)
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache, cache_events)
return output
@@ -407,14 +180,6 @@ def _init_distributed_environment(
parallel_config.pipeline_parallel_size)
def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
return x + [pad] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
return x + [pad] * (max_len - len(x))
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16:
......