Unverified commit 925f3332, authored by Woosuk Kwon, committed by GitHub
Browse files

[Core] Refactor Attention Take 2 (#3462)

parent b0dfa91d
......@@ -431,7 +431,7 @@ class SequenceGroup:
class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`.
"""Metadata for a sequence group. Used to create `AttentionMetadata`.
Args:
request_id: The ID of the request.
......
"""CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple
from typing import Dict, List
import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class CacheEngine:
"""Manages the KV cache.
......@@ -43,95 +42,43 @@ class CacheEngine:
else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype)
# Initialize the cache.
self.gpu_cache = self.allocate_gpu_cache()
self.cpu_cache = self.allocate_cpu_cache()
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
    """Return the shape of a single key-cache block.

    The head dimension is split into chunks of ``x`` elements, where ``x``
    is the number of ``self.dtype`` elements that fit into 16 bytes.

    Returns:
        ``(num_heads, head_size // x, block_size, x)``.
    """
    # The empty tensor exists only to query the byte width of the dtype.
    element_size = torch.tensor([], dtype=self.dtype).element_size()
    # Elements per 16-byte chunk.
    # NOTE(review): presumably matches the access width the cache kernels
    # expect -- confirm against the kernel implementation.
    x = 16 // element_size
    return (
        self.num_heads,
        self.head_size // x,
        self.block_size,
        x,
    )
def get_value_block_shape(self) -> Tuple[int, int, int]:
    """Return the shape of a single value-cache block.

    Returns:
        ``(num_heads, head_size, block_size)``.
    """
    return (
        self.num_heads,
        self.head_size,
        self.block_size,
    )
def allocate_gpu_cache(self) -> List[KVCache]:
    """Allocate uninitialized key and value block tensors on the GPU.

    Returns:
        A list with one ``(key_blocks, value_blocks)`` pair per layer. Each
        tensor has ``self.num_gpu_blocks`` as its leading dimension,
        followed by the corresponding per-block shape.
    """
    # The per-block shapes are the same for every layer; compute them once.
    key_block_shape = self.get_key_block_shape()
    value_block_shape = self.get_value_block_shape()

    gpu_cache: List[KVCache] = []
    for _ in range(self.num_layers):
        key_blocks = torch.empty(
            size=(self.num_gpu_blocks, *key_block_shape),
            dtype=self.dtype,
            device="cuda",
        )
        value_blocks = torch.empty(
            size=(self.num_gpu_blocks, *value_block_shape),
            dtype=self.dtype,
            device="cuda",
        )
        gpu_cache.append((key_blocks, value_blocks))
    return gpu_cache
def allocate_cpu_cache(self) -> List[KVCache]:
    """Allocate uninitialized key and value block tensors on the CPU.

    Pinned memory is requested whenever ``is_pin_memory_available()``
    reports that it can be used.

    Returns:
        A list with one ``(key_blocks, value_blocks)`` pair per layer. Each
        tensor has ``self.num_cpu_blocks`` as its leading dimension,
        followed by the corresponding per-block shape.
    """
    # The per-block shapes are the same for every layer; compute them once.
    key_block_shape = self.get_key_block_shape()
    value_block_shape = self.get_value_block_shape()
    pin_memory = is_pin_memory_available()

    cpu_cache: List[KVCache] = []
    for _ in range(self.num_layers):
        key_blocks = torch.empty(
            size=(self.num_cpu_blocks, *key_block_shape),
            dtype=self.dtype,
            pin_memory=pin_memory,
            device="cpu",
        )
        value_blocks = torch.empty(
            size=(self.num_cpu_blocks, *value_block_shape),
            dtype=self.dtype,
            pin_memory=pin_memory,
            device="cpu",
        )
        cpu_cache.append((key_blocks, value_blocks))
    return cpu_cache
# For every layer, passes the key cache pair and then the value cache pair of
# `src` and `dst` to `cache_ops.swap_blocks`, forwarding `src_to_dst` unchanged.
def _swap(
self,
src: List[KVCache],
dst: List[KVCache],
src_to_dst: Dict[int, int],
) -> None:
# Imported at call time rather than at module level.
from vllm._C import cache_ops
# NOTE(review): the two assignments below re-create both caches through
# `_allocate_kv_cache`. In this diff view they look like post-change cache
# initialization lines shown next to the pre-change `_swap` body, not part of
# the swap itself -- confirm against the actual file.
self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
for i in range(self.num_layers):
# Each per-layer entry is unpacked as a (key cache, value cache) pair.
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst)
def _allocate_kv_cache(
self,
num_blocks: int,
device: str,
) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
kv_cache.append(
torch.empty(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device))
return kv_cache
# Swaps blocks with the CPU cache as source and the GPU cache as destination;
# `src_to_dst` is forwarded unchanged to the underlying swap routine.
# NOTE(review): the body holds both a `_swap` call and a per-layer
# `attn_backend.swap_blocks` loop. In this diff view these appear to be the
# pre-change and post-change implementations shown together -- confirm which
# one the actual file contains.
def swap_in(self, src_to_dst: Dict[int, int]) -> None:
self._swap(self.cpu_cache, self.gpu_cache, src_to_dst)
for i in range(self.num_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst)
# Swaps blocks with the GPU cache as source and the CPU cache as destination;
# `src_to_dst` is forwarded unchanged to the underlying swap routine.
# NOTE(review): the body holds both a `_swap` call and a per-layer
# `attn_backend.swap_blocks` loop. In this diff view these appear to be the
# pre-change and post-change implementations shown together -- confirm which
# one the actual file contains.
def swap_out(self, src_to_dst: Dict[int, int]) -> None:
self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
for i in range(self.num_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)
# Copies blocks within the GPU cache; `src_to_dsts` is forwarded unchanged to
# the underlying copy routine.
# NOTE(review): the body holds both a `cache_ops.copy_blocks` path (which
# unpacks each GPU cache entry as a key/value pair) and an
# `attn_backend.copy_blocks` call. In this diff view these appear to be the
# pre-change and post-change implementations shown together -- confirm which
# one the actual file contains.
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
# Imported at call time rather than at module level.
from vllm._C import cache_ops
key_caches = [key_cache for key_cache, _ in self.gpu_cache]
value_caches = [value_cache for _, value_cache in self.gpu_cache]
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
@staticmethod
def get_cache_block_size(
......
......@@ -6,10 +6,11 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
SchedulerConfig)
from vllm.logger import init_logger
from vllm.model_executor import InputMetadata, SamplingMetadata
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import (
......@@ -28,7 +29,6 @@ from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8
......@@ -85,6 +85,9 @@ class ModelRunner:
self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
def load_model(self) -> None:
with CudaMemoryProfiler() as m:
self.model = get_model(self.model_config,
......@@ -127,8 +130,8 @@ class ModelRunner:
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
List[int], List[int], Set[LoRARequest]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
......@@ -216,7 +219,7 @@ class ModelRunner:
slot_mapping.append(slot)
max_subquery_len = max(subquery_lens)
max_seq_len = max(prompt_lens)
max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0
......@@ -270,7 +273,7 @@ class ModelRunner:
dtype=seq_start_loc.dtype,
out=seq_start_loc[1:])
input_metadata = InputMetadata(
attn_metadata = self.attn_backend.make_metadata(
is_prompt=True,
slot_mapping=slot_mapping,
prompt_lens=prompt_lens,
......@@ -279,7 +282,7 @@ class ModelRunner:
num_generation_tokens=0,
max_subquery_len=max_subquery_len,
max_context_len=None,
max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len,
subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor,
......@@ -287,15 +290,15 @@ class ModelRunner:
use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, input_metadata, prompt_lens,
return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests)
def _prepare_decode(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int],
Set[LoRARequest]]:
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
......@@ -401,7 +404,7 @@ class ModelRunner:
device=self.device,
)
input_metadata = InputMetadata(
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping,
prompt_lens=None,
......@@ -410,7 +413,7 @@ class ModelRunner:
num_generation_tokens=len(input_tokens),
max_subquery_len=None,
max_context_len=max_context_len,
max_seq_len=None,
max_prompt_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens,
......@@ -418,7 +421,7 @@ class ModelRunner:
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return (input_tokens, input_positions, input_metadata,
return (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
def _prepare_sample(
......@@ -522,7 +525,7 @@ class ModelRunner:
def prepare_input_tensors(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata,
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping]:
if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or
......@@ -530,11 +533,11 @@ class ModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors.
if is_prompt:
(input_tokens, input_positions, input_metadata, prompt_lens,
(input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else:
(input_tokens, input_positions, input_metadata,
(input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = []
......@@ -560,7 +563,7 @@ class ModelRunner:
"lora_requests": lora_requests,
"lora_mapping": lora_mapping,
}
metadata_dict.update(input_metadata.asdict_zerocopy())
metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0)
else:
metadata_dict = broadcast_tensor_dict(src=0)
......@@ -570,7 +573,7 @@ class ModelRunner:
"selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests")
input_metadata = InputMetadata(**metadata_dict)
attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
sampling_metadata = SamplingMetadata(
seq_groups=None,
seq_data=None,
......@@ -581,16 +584,16 @@ class ModelRunner:
perform_sampling=False,
)
return (input_tokens, input_positions, input_metadata,
return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping)
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_metadata, sampling_metadata,
(input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
......@@ -598,7 +601,7 @@ class ModelRunner:
self.set_active_loras(lora_requests, lora_mapping)
# Execute the model.
if input_metadata.use_cuda_graph:
if attn_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size]
else:
......@@ -607,7 +610,7 @@ class ModelRunner:
input_ids=input_tokens,
positions=input_positions,
kv_caches=kv_caches,
input_metadata=input_metadata,
attn_metadata=attn_metadata,
)
# Compute the logits.
......@@ -673,7 +676,7 @@ class ModelRunner:
# Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [(None, None)] * num_layers
kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches)
torch.cuda.synchronize()
return
......@@ -705,7 +708,7 @@ class ModelRunner:
return self.lora_manager.list_loras()
@torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None:
def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
"""Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number
......@@ -759,8 +762,8 @@ class ModelRunner:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata.
input_metadata = InputMetadata(
# Create dummy attn_metadata.
attn_metadata = self.attn_backend.make_metadata(
is_prompt=False,
slot_mapping=slot_mapping[:batch_size],
prompt_lens=None,
......@@ -769,7 +772,7 @@ class ModelRunner:
num_generation_tokens=batch_size,
max_subquery_len=None,
max_context_len=self.max_context_len_to_capture,
max_seq_len=None,
max_prompt_len=None,
subquery_start_loc=None,
seq_start_loc=None,
context_lens=context_lens[:batch_size],
......@@ -790,7 +793,7 @@ class ModelRunner:
input_tokens[:batch_size],
input_positions[:batch_size],
kv_caches,
input_metadata,
attn_metadata,
memory_pool=self.graph_memory_pool,
)
self.graph_memory_pool = graph_runner.graph.pool()
......@@ -826,8 +829,8 @@ class CUDAGraphRunner:
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
memory_pool,
) -> None:
assert self.graph is None
......@@ -839,7 +842,7 @@ class CUDAGraphRunner:
input_ids,
positions,
kv_caches,
input_metadata,
attn_metadata,
)
torch.cuda.synchronize()
......@@ -853,7 +856,7 @@ class CUDAGraphRunner:
input_ids,
positions,
kv_caches,
input_metadata,
attn_metadata,
)
torch.cuda.synchronize()
......@@ -862,9 +865,9 @@ class CUDAGraphRunner:
"input_ids": input_ids,
"positions": positions,
"kv_caches": kv_caches,
"slot_mapping": input_metadata.slot_mapping,
"context_lens": input_metadata.context_lens,
"block_tables": input_metadata.block_tables,
"slot_mapping": attn_metadata.slot_mapping,
"context_lens": attn_metadata.context_lens,
"block_tables": attn_metadata.block_tables,
}
self.output_buffers = {"hidden_states": hidden_states}
return
......@@ -873,8 +876,8 @@ class CUDAGraphRunner:
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
input_metadata: InputMetadata,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them.
del kv_caches
......@@ -882,11 +885,11 @@ class CUDAGraphRunner:
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
non_blocking=True)
# Run the graph.
self.graph.replay()
......
......@@ -128,6 +128,9 @@ class Worker:
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes(
block_size, cache_dtype)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment