Unverified commit 925f3332, authored by Woosuk Kwon, committed by GitHub
Browse files

[Core] Refactor Attention Take 2 (#3462)

parent b0dfa91d
...@@ -431,7 +431,7 @@ class SequenceGroup: ...@@ -431,7 +431,7 @@ class SequenceGroup:
class SequenceGroupMetadata: class SequenceGroupMetadata:
"""Metadata for a sequence group. Used to create `InputMetadata`. """Metadata for a sequence group. Used to create `AttentionMetadata`.
Args: Args:
request_id: The ID of the request. request_id: The ID of the request.
......
"""CacheEngine class for managing the KV cache.""" """CacheEngine class for managing the KV cache."""
from typing import Dict, List, Tuple from typing import Dict, List
import torch import torch
from vllm.attention import get_attn_backend
from vllm.config import CacheConfig, ModelConfig, ParallelConfig from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE from vllm.utils import is_pin_memory_available, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__) logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
class CacheEngine: class CacheEngine:
"""Manages the KV cache. """Manages the KV cache.
...@@ -43,95 +42,43 @@ class CacheEngine: ...@@ -43,95 +42,43 @@ class CacheEngine:
else: else:
self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
# Get attention backend.
self.attn_backend = get_attn_backend(model_config.dtype)
# Initialize the cache. # Initialize the cache.
self.gpu_cache = self.allocate_gpu_cache() self.gpu_cache = self._allocate_kv_cache(self.num_gpu_blocks, "cuda")
self.cpu_cache = self.allocate_cpu_cache() self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks, "cpu")
def get_key_block_shape(self) -> Tuple[int, int, int, int]:
element_size = torch.tensor([], dtype=self.dtype).element_size()
x = 16 // element_size
return (
self.num_heads,
self.head_size // x,
self.block_size,
x,
)
def get_value_block_shape(self) -> Tuple[int, int, int]:
return (
self.num_heads,
self.head_size,
self.block_size,
)
def allocate_gpu_cache(self) -> List[KVCache]:
gpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_gpu_blocks, *key_block_shape),
dtype=self.dtype,
device="cuda",
)
value_blocks = torch.empty(
size=(self.num_gpu_blocks, *value_block_shape),
dtype=self.dtype,
device="cuda",
)
gpu_cache.append((key_blocks, value_blocks))
return gpu_cache
def allocate_cpu_cache(self) -> List[KVCache]:
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
pin_memory = is_pin_memory_available()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
cpu_cache.append((key_blocks, value_blocks))
return cpu_cache
def _swap(
self,
src: List[KVCache],
dst: List[KVCache],
src_to_dst: Dict[int, int],
) -> None:
from vllm._C import cache_ops
for i in range(self.num_layers): def _allocate_kv_cache(
src_key_cache, src_value_cache = src[i] self,
dst_key_cache, dst_value_cache = dst[i] num_blocks: int,
# Copy the key blocks. device: str,
cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) ) -> List[torch.Tensor]:
# Copy the value blocks. """Allocates KV cache on the specified device."""
cache_ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
kv_cache.append(
torch.empty(kv_cache_shape,
dtype=self.dtype,
pin_memory=pin_memory,
device=device))
return kv_cache
def swap_in(self, src_to_dst: Dict[int, int]) -> None: def swap_in(self, src_to_dst: Dict[int, int]) -> None:
self._swap(self.cpu_cache, self.gpu_cache, src_to_dst) for i in range(self.num_layers):
self.attn_backend.swap_blocks(self.cpu_cache[i], self.gpu_cache[i],
src_to_dst)
def swap_out(self, src_to_dst: Dict[int, int]) -> None: def swap_out(self, src_to_dst: Dict[int, int]) -> None:
self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) for i in range(self.num_layers):
self.attn_backend.swap_blocks(self.gpu_cache[i], self.cpu_cache[i],
src_to_dst)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
from vllm._C import cache_ops self.attn_backend.copy_blocks(self.gpu_cache, src_to_dsts)
key_caches = [key_cache for key_cache, _ in self.gpu_cache]
value_caches = [value_cache for _, value_cache in self.gpu_cache]
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts)
@staticmethod @staticmethod
def get_cache_block_size( def get_cache_block_size(
......
...@@ -6,10 +6,11 @@ import numpy as np ...@@ -6,10 +6,11 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
SchedulerConfig) SchedulerConfig)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor import SamplingMetadata
from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader import get_model
from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils.communication_op import ( from vllm.model_executor.parallel_utils.communication_op import (
...@@ -28,7 +29,6 @@ from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler, ...@@ -28,7 +29,6 @@ from vllm.utils import (async_tensor_h2d, CudaMemoryProfiler,
logger = init_logger(__name__) logger = init_logger(__name__)
KVCache = Tuple[torch.Tensor, torch.Tensor]
_PAD_SLOT_ID = -1 _PAD_SLOT_ID = -1
LORA_WARMUP_RANK = 8 LORA_WARMUP_RANK = 8
_BATCH_SIZE_ALIGNMENT = 8 _BATCH_SIZE_ALIGNMENT = 8
...@@ -85,6 +85,9 @@ class ModelRunner: ...@@ -85,6 +85,9 @@ class ModelRunner:
self.pin_memory = is_pin_memory_available() self.pin_memory = is_pin_memory_available()
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.attn_backend = get_attn_backend(
self.model_config.dtype if model_config is not None else None)
def load_model(self) -> None: def load_model(self) -> None:
with CudaMemoryProfiler() as m: with CudaMemoryProfiler() as m:
self.model = get_model(self.model_config, self.model = get_model(self.model_config,
...@@ -127,8 +130,8 @@ class ModelRunner: ...@@ -127,8 +130,8 @@ class ModelRunner:
def _prepare_prompt( def _prepare_prompt(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
List[int], List[int], Set[LoRARequest]]: List[int], List[int], List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
...@@ -216,7 +219,7 @@ class ModelRunner: ...@@ -216,7 +219,7 @@ class ModelRunner:
slot_mapping.append(slot) slot_mapping.append(slot)
max_subquery_len = max(subquery_lens) max_subquery_len = max(subquery_lens)
max_seq_len = max(prompt_lens) max_prompt_len = max(prompt_lens)
num_prompt_tokens = len(input_tokens) num_prompt_tokens = len(input_tokens)
assert max_subquery_len > 0 assert max_subquery_len > 0
...@@ -270,7 +273,7 @@ class ModelRunner: ...@@ -270,7 +273,7 @@ class ModelRunner:
dtype=seq_start_loc.dtype, dtype=seq_start_loc.dtype,
out=seq_start_loc[1:]) out=seq_start_loc[1:])
input_metadata = InputMetadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=True, is_prompt=True,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=prompt_lens, prompt_lens=prompt_lens,
...@@ -279,7 +282,7 @@ class ModelRunner: ...@@ -279,7 +282,7 @@ class ModelRunner:
num_generation_tokens=0, num_generation_tokens=0,
max_subquery_len=max_subquery_len, max_subquery_len=max_subquery_len,
max_context_len=None, max_context_len=None,
max_seq_len=max_seq_len, max_prompt_len=max_prompt_len,
subquery_start_loc=subquery_start_loc, subquery_start_loc=subquery_start_loc,
seq_start_loc=seq_start_loc, seq_start_loc=seq_start_loc,
context_lens=context_lens_tensor, context_lens=context_lens_tensor,
...@@ -287,15 +290,15 @@ class ModelRunner: ...@@ -287,15 +290,15 @@ class ModelRunner:
use_cuda_graph=False, use_cuda_graph=False,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, input_metadata, prompt_lens, return (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) lora_requests)
def _prepare_decode( def _prepare_decode(
self, self,
seq_group_metadata_list: List[SequenceGroupMetadata], seq_group_metadata_list: List[SequenceGroupMetadata],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
Set[LoRARequest]]: List[int], Set[LoRARequest]]:
assert len(seq_group_metadata_list) > 0 assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = [] input_tokens: List[int] = []
input_positions: List[int] = [] input_positions: List[int] = []
...@@ -401,7 +404,7 @@ class ModelRunner: ...@@ -401,7 +404,7 @@ class ModelRunner:
device=self.device, device=self.device,
) )
input_metadata = InputMetadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
prompt_lens=None, prompt_lens=None,
...@@ -410,7 +413,7 @@ class ModelRunner: ...@@ -410,7 +413,7 @@ class ModelRunner:
num_generation_tokens=len(input_tokens), num_generation_tokens=len(input_tokens),
max_subquery_len=None, max_subquery_len=None,
max_context_len=max_context_len, max_context_len=max_context_len,
max_seq_len=None, max_prompt_len=None,
subquery_start_loc=None, subquery_start_loc=None,
seq_start_loc=None, seq_start_loc=None,
context_lens=context_lens, context_lens=context_lens,
...@@ -418,7 +421,7 @@ class ModelRunner: ...@@ -418,7 +421,7 @@ class ModelRunner:
use_cuda_graph=use_captured_graph, use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype, kv_cache_dtype=self.kv_cache_dtype,
) )
return (input_tokens, input_positions, input_metadata, return (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests) lora_index_mapping, lora_prompt_mapping, lora_requests)
def _prepare_sample( def _prepare_sample(
...@@ -522,7 +525,7 @@ class ModelRunner: ...@@ -522,7 +525,7 @@ class ModelRunner:
def prepare_input_tensors( def prepare_input_tensors(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
Set[int], LoRAMapping]: Set[int], LoRAMapping]:
if self.is_driver_worker: if self.is_driver_worker:
# NOTE: We assume that all sequences in the group are all prompts or # NOTE: We assume that all sequences in the group are all prompts or
...@@ -530,11 +533,11 @@ class ModelRunner: ...@@ -530,11 +533,11 @@ class ModelRunner:
is_prompt = seq_group_metadata_list[0].is_prompt is_prompt = seq_group_metadata_list[0].is_prompt
# Prepare input tensors. # Prepare input tensors.
if is_prompt: if is_prompt:
(input_tokens, input_positions, input_metadata, prompt_lens, (input_tokens, input_positions, attn_metadata, prompt_lens,
subquery_lens, lora_index_mapping, lora_prompt_mapping, subquery_lens, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_prompt(seq_group_metadata_list) lora_requests) = self._prepare_prompt(seq_group_metadata_list)
else: else:
(input_tokens, input_positions, input_metadata, (input_tokens, input_positions, attn_metadata,
lora_index_mapping, lora_prompt_mapping, lora_index_mapping, lora_prompt_mapping,
lora_requests) = self._prepare_decode(seq_group_metadata_list) lora_requests) = self._prepare_decode(seq_group_metadata_list)
prompt_lens = [] prompt_lens = []
...@@ -560,7 +563,7 @@ class ModelRunner: ...@@ -560,7 +563,7 @@ class ModelRunner:
"lora_requests": lora_requests, "lora_requests": lora_requests,
"lora_mapping": lora_mapping, "lora_mapping": lora_mapping,
} }
metadata_dict.update(input_metadata.asdict_zerocopy()) metadata_dict.update(attn_metadata.asdict_zerocopy())
broadcast_tensor_dict(metadata_dict, src=0) broadcast_tensor_dict(metadata_dict, src=0)
else: else:
metadata_dict = broadcast_tensor_dict(src=0) metadata_dict = broadcast_tensor_dict(src=0)
...@@ -570,7 +573,7 @@ class ModelRunner: ...@@ -570,7 +573,7 @@ class ModelRunner:
"selected_token_indices") "selected_token_indices")
lora_mapping = metadata_dict.pop("lora_mapping") lora_mapping = metadata_dict.pop("lora_mapping")
lora_requests = metadata_dict.pop("lora_requests") lora_requests = metadata_dict.pop("lora_requests")
input_metadata = InputMetadata(**metadata_dict) attn_metadata = self.attn_backend.make_metadata(**metadata_dict)
sampling_metadata = SamplingMetadata( sampling_metadata = SamplingMetadata(
seq_groups=None, seq_groups=None,
seq_data=None, seq_data=None,
...@@ -581,16 +584,16 @@ class ModelRunner: ...@@ -581,16 +584,16 @@ class ModelRunner:
perform_sampling=False, perform_sampling=False,
) )
return (input_tokens, input_positions, input_metadata, return (input_tokens, input_positions, attn_metadata,
sampling_metadata, lora_requests, lora_mapping) sampling_metadata, lora_requests, lora_mapping)
@torch.inference_mode() @torch.inference_mode()
def execute_model( def execute_model(
self, self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_metadata, sampling_metadata, (input_tokens, input_positions, attn_metadata, sampling_metadata,
lora_requests, lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
...@@ -598,7 +601,7 @@ class ModelRunner: ...@@ -598,7 +601,7 @@ class ModelRunner:
self.set_active_loras(lora_requests, lora_mapping) self.set_active_loras(lora_requests, lora_mapping)
# Execute the model. # Execute the model.
if input_metadata.use_cuda_graph: if attn_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0] graph_batch_size = input_tokens.shape[0]
model_executable = self.graph_runners[graph_batch_size] model_executable = self.graph_runners[graph_batch_size]
else: else:
...@@ -607,7 +610,7 @@ class ModelRunner: ...@@ -607,7 +610,7 @@ class ModelRunner:
input_ids=input_tokens, input_ids=input_tokens,
positions=input_positions, positions=input_positions,
kv_caches=kv_caches, kv_caches=kv_caches,
input_metadata=input_metadata, attn_metadata=attn_metadata,
) )
# Compute the logits. # Compute the logits.
...@@ -673,7 +676,7 @@ class ModelRunner: ...@@ -673,7 +676,7 @@ class ModelRunner:
# Run the model with the dummy inputs. # Run the model with the dummy inputs.
num_layers = self.model_config.get_num_layers(self.parallel_config) num_layers = self.model_config.get_num_layers(self.parallel_config)
kv_caches = [(None, None)] * num_layers kv_caches = [None] * num_layers
self.execute_model(seqs, kv_caches) self.execute_model(seqs, kv_caches)
torch.cuda.synchronize() torch.cuda.synchronize()
return return
...@@ -705,7 +708,7 @@ class ModelRunner: ...@@ -705,7 +708,7 @@ class ModelRunner:
return self.lora_manager.list_loras() return self.lora_manager.list_loras()
@torch.inference_mode() @torch.inference_mode()
def capture_model(self, kv_caches: List[KVCache]) -> None: def capture_model(self, kv_caches: List[torch.Tensor]) -> None:
"""Cuda graph capture a model. """Cuda graph capture a model.
Note that CUDA graph's performance gain is negligible if number Note that CUDA graph's performance gain is negligible if number
...@@ -759,8 +762,8 @@ class ModelRunner: ...@@ -759,8 +762,8 @@ class ModelRunner:
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for batch_size in reversed(batch_size_capture_list): for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata. # Create dummy attn_metadata.
input_metadata = InputMetadata( attn_metadata = self.attn_backend.make_metadata(
is_prompt=False, is_prompt=False,
slot_mapping=slot_mapping[:batch_size], slot_mapping=slot_mapping[:batch_size],
prompt_lens=None, prompt_lens=None,
...@@ -769,7 +772,7 @@ class ModelRunner: ...@@ -769,7 +772,7 @@ class ModelRunner:
num_generation_tokens=batch_size, num_generation_tokens=batch_size,
max_subquery_len=None, max_subquery_len=None,
max_context_len=self.max_context_len_to_capture, max_context_len=self.max_context_len_to_capture,
max_seq_len=None, max_prompt_len=None,
subquery_start_loc=None, subquery_start_loc=None,
seq_start_loc=None, seq_start_loc=None,
context_lens=context_lens[:batch_size], context_lens=context_lens[:batch_size],
...@@ -790,7 +793,7 @@ class ModelRunner: ...@@ -790,7 +793,7 @@ class ModelRunner:
input_tokens[:batch_size], input_tokens[:batch_size],
input_positions[:batch_size], input_positions[:batch_size],
kv_caches, kv_caches,
input_metadata, attn_metadata,
memory_pool=self.graph_memory_pool, memory_pool=self.graph_memory_pool,
) )
self.graph_memory_pool = graph_runner.graph.pool() self.graph_memory_pool = graph_runner.graph.pool()
...@@ -826,8 +829,8 @@ class CUDAGraphRunner: ...@@ -826,8 +829,8 @@ class CUDAGraphRunner:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[KVCache], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
memory_pool, memory_pool,
) -> None: ) -> None:
assert self.graph is None assert self.graph is None
...@@ -839,7 +842,7 @@ class CUDAGraphRunner: ...@@ -839,7 +842,7 @@ class CUDAGraphRunner:
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
input_metadata, attn_metadata,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -853,7 +856,7 @@ class CUDAGraphRunner: ...@@ -853,7 +856,7 @@ class CUDAGraphRunner:
input_ids, input_ids,
positions, positions,
kv_caches, kv_caches,
input_metadata, attn_metadata,
) )
torch.cuda.synchronize() torch.cuda.synchronize()
...@@ -862,9 +865,9 @@ class CUDAGraphRunner: ...@@ -862,9 +865,9 @@ class CUDAGraphRunner:
"input_ids": input_ids, "input_ids": input_ids,
"positions": positions, "positions": positions,
"kv_caches": kv_caches, "kv_caches": kv_caches,
"slot_mapping": input_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"context_lens": input_metadata.context_lens, "context_lens": attn_metadata.context_lens,
"block_tables": input_metadata.block_tables, "block_tables": attn_metadata.block_tables,
} }
self.output_buffers = {"hidden_states": hidden_states} self.output_buffers = {"hidden_states": hidden_states}
return return
...@@ -873,8 +876,8 @@ class CUDAGraphRunner: ...@@ -873,8 +876,8 @@ class CUDAGraphRunner:
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[torch.Tensor],
input_metadata: InputMetadata, attn_metadata: AttentionMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
# KV caches are fixed tensors, so we don't need to copy them. # KV caches are fixed tensors, so we don't need to copy them.
del kv_caches del kv_caches
...@@ -882,11 +885,11 @@ class CUDAGraphRunner: ...@@ -882,11 +885,11 @@ class CUDAGraphRunner:
# Copy the input tensors to the input buffers. # Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True)
self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, self.input_buffers["slot_mapping"].copy_(attn_metadata.slot_mapping,
non_blocking=True) non_blocking=True)
self.input_buffers["context_lens"].copy_(input_metadata.context_lens, self.input_buffers["context_lens"].copy_(attn_metadata.context_lens,
non_blocking=True) non_blocking=True)
self.input_buffers["block_tables"].copy_(input_metadata.block_tables, self.input_buffers["block_tables"].copy_(attn_metadata.block_tables,
non_blocking=True) non_blocking=True)
# Run the graph. # Run the graph.
self.graph.replay() self.graph.replay()
......
...@@ -128,6 +128,9 @@ class Worker: ...@@ -128,6 +128,9 @@ class Worker:
# NOTE(woosuk): Here we assume that the other processes using the same # NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling. # GPU did not change their memory usage during the profiling.
peak_memory = self.init_gpu_memory - free_gpu_memory peak_memory = self.init_gpu_memory - free_gpu_memory
assert peak_memory > 0, (
"Error in memory profiling. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")
cache_block_size = self.get_cache_block_size_bytes( cache_block_size = self.get_cache_block_size_bytes(
block_size, cache_dtype) block_size, cache_dtype)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment