Unverified Commit 75082432 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Model Runner V2] Simplify BlockTables with UVA (#31965)


Signed-off-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
parent 83e1c76d
...@@ -6,8 +6,9 @@ import torch ...@@ -6,8 +6,9 @@ import torch
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_uva_available
from vllm.utils.torch_utils import get_cuda_view_from_cpu_tensor
from vllm.v1.attention.backends.utils import PAD_SLOT_ID from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer
class BlockTables: class BlockTables:
...@@ -18,51 +19,53 @@ class BlockTables: ...@@ -18,51 +19,53 @@ class BlockTables:
max_num_batched_tokens: int, max_num_batched_tokens: int,
max_model_len: int, max_model_len: int,
device: torch.device, device: torch.device,
pin_memory: bool,
): ):
self.block_sizes = block_sizes self.block_sizes = block_sizes
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_num_batched_tokens = max_num_batched_tokens self.max_num_batched_tokens = max_num_batched_tokens
self.max_model_len = max_model_len self.max_model_len = max_model_len
self.device = device self.device = device
self.pin_memory = pin_memory
if not is_uva_available():
raise RuntimeError("UVA is not available")
self.num_kv_cache_groups = len(self.block_sizes) self.num_kv_cache_groups = len(self.block_sizes)
# num_kv_cache_groups x [max_num_reqs, max_num_blocks] # num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.block_tables: list[torch.Tensor] = [] self.block_tables: list[UvaBuffer] = []
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
block_size = self.block_sizes[i] block_size = self.block_sizes[i]
max_num_blocks = cdiv(self.max_model_len, block_size) max_num_blocks = cdiv(self.max_model_len, block_size)
block_table = torch.zeros( block_table = UvaBuffer(
self.max_num_reqs, self.max_num_reqs,
max_num_blocks, max_num_blocks,
dtype=torch.int32, dtype=torch.int32,
device=self.device,
) )
self.block_tables.append(block_table) self.block_tables.append(block_table)
self.block_table_ptrs = self._make_ptr_tensor(self.block_tables) self.block_table_ptrs = self._make_ptr_tensor(
[b.gpu for b in self.block_tables]
# Block tables used for model's forward pass. )
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(block_table) for block_table in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.block_table_strides = torch.tensor( self.block_table_strides = torch.tensor(
[b.stride(0) for b in self.block_tables], [b.gpu.stride(0) for b in self.block_tables],
dtype=torch.int64, dtype=torch.int64,
device=self.device, device=self.device,
) )
self.block_sizes_tensor = torch.tensor( self.block_sizes_tensor = torch.tensor(
self.block_sizes, dtype=torch.int32, device=self.device self.block_sizes, dtype=torch.int32, device=self.device
) )
self.num_blocks = torch.zeros( self.num_blocks = UvaBuffer(
self.num_kv_cache_groups, self.num_kv_cache_groups,
self.max_num_reqs, self.max_num_reqs,
dtype=torch.int32, dtype=torch.int32,
device=self.device,
) )
# Block tables used for model's forward pass.
# num_kv_cache_groups x [max_num_reqs, max_num_blocks]
self.input_block_tables: list[torch.Tensor] = [
torch.zeros_like(b.gpu) for b in self.block_tables
]
self.input_block_table_ptrs = self._make_ptr_tensor(self.input_block_tables)
self.slot_mappings = torch.zeros( self.slot_mappings = torch.zeros(
self.num_kv_cache_groups, self.num_kv_cache_groups,
self.max_num_batched_tokens, self.max_num_batched_tokens,
...@@ -70,74 +73,36 @@ class BlockTables: ...@@ -70,74 +73,36 @@ class BlockTables:
device=self.device, device=self.device,
) )
# Misc buffers.
self.req_indices = self._make_buffer(self.max_num_reqs, dtype=torch.int32)
self.overwrite = self._make_buffer(self.max_num_reqs, dtype=torch.bool)
self.cu_num_new_blocks = self._make_buffer(
self.num_kv_cache_groups, self.max_num_reqs + 1, dtype=torch.int32
)
def _make_buffer(self, *args, dtype: torch.dtype) -> CpuGpuBuffer:
return CpuGpuBuffer(
*args, dtype=dtype, pin_memory=self.pin_memory, device=self.device
)
def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor: def _make_ptr_tensor(self, x: Iterable[torch.Tensor]) -> torch.Tensor:
# NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses. # NOTE(woosuk): Use uint64 instead of int64 to cover all possible addresses.
ptrs_tensor_cpu = torch.tensor( ptrs_tensor_cpu = torch.tensor(
[t.data_ptr() for t in x], [t.data_ptr() for t in x],
dtype=torch.uint64, dtype=torch.uint64,
device="cpu", device="cpu",
pin_memory=self.pin_memory, pin_memory=True,
) )
return ptrs_tensor_cpu.to(self.device, non_blocking=True) return ptrs_tensor_cpu.to(self.device, non_blocking=True)
def append_block_ids( def append_block_ids(
self, self,
# [num_reqs] req_index: int,
req_indices: list[int],
# [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks: tuple[list[int], ...],
# [num_kv_cache_groups, num_new_blocks]
new_block_ids: tuple[list[int], ...], new_block_ids: tuple[list[int], ...],
# [num_reqs] overwrite: bool,
overwrite: list[bool],
) -> None: ) -> None:
num_reqs = len(req_indices)
self.req_indices.np[:num_reqs] = req_indices
self.overwrite.np[:num_reqs] = overwrite
for i in range(self.num_kv_cache_groups):
self.cu_num_new_blocks.np[i, : num_reqs + 1] = cu_num_new_blocks[i]
# NOTE(woosuk): Here, we cannot use a fixed-size buffer because there's
# no clear upper bound to the number of new blocks in a single step.
# NOTE(woosuk): The buffer has to be cached, because otherwise we cannot
# guarantee that the buffer is not freed before the copy is completed.
self.new_block_ids_cpu = torch.empty(
self.num_kv_cache_groups,
max(len(x) for x in new_block_ids),
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
new_block_ids_np = self.new_block_ids_cpu.numpy()
for i in range(self.num_kv_cache_groups): for i in range(self.num_kv_cache_groups):
new_block_ids_np[i, : len(new_block_ids[i])] = new_block_ids[i] block_ids = new_block_ids[i]
new_block_ids_gpu = self.new_block_ids_cpu.to(self.device, non_blocking=True) num_new_blocks = len(block_ids)
if num_new_blocks == 0:
_append_block_ids_kernel[(self.num_kv_cache_groups, num_reqs)]( continue
self.req_indices.copy_to_gpu(num_reqs),
self.cu_num_new_blocks.copy_to_gpu(), # TODO(woosuk): Too many Numpy invocations. Optimize this.
self.cu_num_new_blocks.gpu.stride(0), start = self.num_blocks.np[i, req_index] if not overwrite else 0
new_block_ids_gpu, end = start + num_new_blocks
new_block_ids_gpu.stride(0), if num_new_blocks == 1:
self.overwrite.copy_to_gpu(num_reqs), self.block_tables[i].np[req_index, start] = block_ids[0]
self.block_table_strides, else:
self.block_table_ptrs, self.block_tables[i].np[req_index, start:end] = block_ids
self.num_blocks, self.num_blocks.np[i, req_index] = end
self.num_blocks.stride(0),
BLOCK_SIZE=1024, # type: ignore
)
def gather_block_tables( def gather_block_tables(
self, self,
...@@ -149,8 +114,8 @@ class BlockTables: ...@@ -149,8 +114,8 @@ class BlockTables:
self.block_table_ptrs, self.block_table_ptrs,
self.input_block_table_ptrs, self.input_block_table_ptrs,
self.block_table_strides, self.block_table_strides,
self.num_blocks, self.num_blocks.gpu,
self.num_blocks.stride(0), self.num_blocks.gpu.stride(0),
BLOCK_SIZE=1024, # type: ignore BLOCK_SIZE=1024, # type: ignore
) )
return tuple(block_table[:num_reqs] for block_table in self.input_block_tables) return tuple(block_table[:num_reqs] for block_table in self.input_block_tables)
...@@ -186,54 +151,6 @@ class BlockTables: ...@@ -186,54 +151,6 @@ class BlockTables:
return self.slot_mappings[:, :num_tokens] return self.slot_mappings[:, :num_tokens]
@triton.jit
def _append_block_ids_kernel(
# Inputs
req_indices, # [num_reqs]
cu_num_new_blocks_ptr, # [num_kv_cache_groups, num_reqs + 1]
cu_num_new_blocks_stride,
new_block_ids_ptr, # [num_kv_cache_groups, num_new_blocks]
new_block_ids_stride,
overwrite, # [num_reqs]
block_table_strides, # [num_kv_cache_groups]
# Outputs
block_table_ptrs, # [num_kv_cache_groups]
num_blocks_ptr, # [num_kv_cache_groups, max_num_reqs]
num_blocks_stride,
# Constants
BLOCK_SIZE: tl.constexpr,
):
group_id = tl.program_id(0)
batch_idx = tl.program_id(1)
req_idx = tl.load(req_indices + batch_idx)
do_overwrite = tl.load(overwrite + batch_idx)
group_new_blocks_ptr = cu_num_new_blocks_ptr + group_id * cu_num_new_blocks_stride
start_idx = tl.load(group_new_blocks_ptr + batch_idx)
end_idx = tl.load(group_new_blocks_ptr + batch_idx + 1)
num_new_blocks = end_idx - start_idx
group_num_blocks_ptr = num_blocks_ptr + group_id * num_blocks_stride
dst_start_idx = tl.load(group_num_blocks_ptr + req_idx) if not do_overwrite else 0
dst_end_idx = dst_start_idx + num_new_blocks
tl.store(group_num_blocks_ptr + req_idx, dst_end_idx)
# Destination
block_table_ptr = _load_ptr(block_table_ptrs + group_id, tl.int32)
block_table_stride = tl.load(block_table_strides + group_id)
row_ptr = block_table_ptr + req_idx * block_table_stride
group_new_block_ids_ptr = new_block_ids_ptr + group_id * new_block_ids_stride
for i in range(0, num_new_blocks, BLOCK_SIZE):
offset = i + tl.arange(0, BLOCK_SIZE)
block_ids = tl.load(
group_new_block_ids_ptr + start_idx + offset, mask=offset < num_new_blocks
)
tl.store(
row_ptr + dst_start_idx + offset, block_ids, mask=offset < num_new_blocks
)
@triton.jit @triton.jit
def _gather_block_tables_kernel( def _gather_block_tables_kernel(
batch_idx_to_req_idx, # [batch_size] batch_idx_to_req_idx, # [batch_size]
...@@ -312,3 +229,10 @@ def _load_ptr(ptr_to_ptr, elem_dtype): ...@@ -312,3 +229,10 @@ def _load_ptr(ptr_to_ptr, elem_dtype):
ptr = tl.load(ptr_to_ptr) ptr = tl.load(ptr_to_ptr)
ptr = tl.cast(ptr, tl.pointer_type(elem_dtype)) ptr = tl.cast(ptr, tl.pointer_type(elem_dtype))
return tl.multiple_of(ptr, 16) return tl.multiple_of(ptr, 16)
class UvaBuffer:
def __init__(self, *size, dtype: torch.dtype):
self.cpu = torch.zeros(*size, dtype=dtype, device="cpu", pin_memory=True)
self.np = self.cpu.numpy()
self.gpu = get_cuda_view_from_cpu_tensor(self.cpu)
...@@ -193,7 +193,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -193,7 +193,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_num_batched_tokens=self.max_num_tokens, max_num_batched_tokens=self.max_num_tokens,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
device=self.device, device=self.device,
pin_memory=self.pin_memory,
) )
self.attn_backends, self.attn_metadata_builders = init_attn_backend( self.attn_backends, self.attn_metadata_builders = init_attn_backend(
...@@ -382,16 +381,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -382,16 +381,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for req_id in scheduler_output.finished_req_ids: for req_id in scheduler_output.finished_req_ids:
self.req_states.remove_request(req_id) self.req_states.remove_request(req_id)
# TODO(woosuk): Change SchedulerOutput.
req_indices: list[int] = []
cu_num_new_blocks = tuple(
[0] for _ in range(self.block_tables.num_kv_cache_groups)
)
new_block_ids: tuple[list[int], ...] = tuple(
[] for _ in range(self.block_tables.num_kv_cache_groups)
)
overwrite: list[bool] = []
# Add new requests. # Add new requests.
for new_req_data in scheduler_output.scheduled_new_reqs: for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.prompt_token_ids is not None assert new_req_data.prompt_token_ids is not None
...@@ -408,12 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -408,12 +397,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
) )
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
req_indices.append(req_index) self.block_tables.append_block_ids(
for i, block_ids in enumerate(new_req_data.block_ids): req_index, new_req_data.block_ids, overwrite=True
x = cu_num_new_blocks[i][-1] )
cu_num_new_blocks[i].append(x + len(block_ids))
new_block_ids[i].extend(block_ids)
overwrite.append(True)
if scheduler_output.scheduled_new_reqs: if scheduler_output.scheduled_new_reqs:
self.req_states.prefill_len.copy_to_gpu() self.req_states.prefill_len.copy_to_gpu()
...@@ -421,22 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): ...@@ -421,22 +407,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cached_reqs = scheduler_output.scheduled_cached_reqs cached_reqs = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(cached_reqs.req_ids): for i, req_id in enumerate(cached_reqs.req_ids):
req_index = self.req_states.req_id_to_index[req_id] req_index = self.req_states.req_id_to_index[req_id]
req_new_block_ids = cached_reqs.new_block_ids[i] req_new_block_ids = cached_reqs.new_block_ids[i]
if req_new_block_ids is not None: if req_new_block_ids is not None:
req_indices.append(req_index)
for group_id, block_ids in enumerate(req_new_block_ids):
x = cu_num_new_blocks[group_id][-1]
cu_num_new_blocks[group_id].append(x + len(block_ids))
new_block_ids[group_id].extend(block_ids)
overwrite.append(False)
if req_indices:
self.block_tables.append_block_ids( self.block_tables.append_block_ids(
req_indices=req_indices, req_index, req_new_block_ids, overwrite=False
cu_num_new_blocks=cu_num_new_blocks,
new_block_ids=new_block_ids,
overwrite=overwrite,
) )
def prepare_inputs( def prepare_inputs(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment