"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c8e3b2a5925e7b7ed21662e86a7e9553170a5633"
Unverified Commit 5a9170d9 authored by YAMY's avatar YAMY Committed by GitHub
Browse files

Optimize copy_kv_cache for spec decoding (#11126)


Co-authored-by: default avatarXinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent c4d77774
...@@ -415,6 +415,7 @@ class MHATokenToKVPool(KVCache): ...@@ -415,6 +415,7 @@ class MHATokenToKVPool(KVCache):
enable_memory_saver: bool, enable_memory_saver: bool,
start_layer: Optional[int] = None, start_layer: Optional[int] = None,
end_layer: Optional[int] = None, end_layer: Optional[int] = None,
enable_kv_cache_copy: bool = False,
): ):
super().__init__( super().__init__(
size, size,
...@@ -446,8 +447,57 @@ class MHATokenToKVPool(KVCache): ...@@ -446,8 +447,57 @@ class MHATokenToKVPool(KVCache):
self.device_module = torch.get_device_module(self.device) self.device_module = torch.get_device_module(self.device)
self.alt_stream = self.device_module.Stream() if _is_cuda else None self.alt_stream = self.device_module.Stream() if _is_cuda else None
if enable_kv_cache_copy:
self._init_kv_copy_and_warmup()
else:
self._kv_copy_config = None
self._finalize_allocation_log(size) self._finalize_allocation_log(size)
def _init_kv_copy_and_warmup(self):
    """Pick tiling parameters for the tiled KV-copy kernel and warm it up.

    The tile size (in bytes) is chosen from the per-layer stride of the KV
    buffers, and the warp count from the tile size.  The kernel is then
    launched once on a dummy location so that Triton JIT compilation happens
    here, not on the first real `move_kv_cache` call.
    """
    # (minimum stride in bytes, tile size in bytes), checked largest first.
    tile_by_stride = ((8192, 512), (4096, 256), (0, 128))

    stride_bytes = int(self.data_strides[0].item())
    tile_bytes = next(t for lo, t in tile_by_stride if stride_bytes >= lo)

    self._kv_copy_config = {
        "bytes_per_tile": tile_bytes,
        # Ceil division: number of tiles needed to cover one row.
        "byte_tiles": -(-stride_bytes // tile_bytes),
        # Small tiles (<= 256 B) get 4 warps, large tiles get 8.
        "num_warps": 4 if tile_bytes <= 256 else 8,
    }

    # Warm-up launch with the final constexpr config to trigger compilation.
    warmup_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
    launch_grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
    copy_all_layer_kv_cache_tiled[launch_grid](
        self.data_ptrs,
        self.data_strides,
        warmup_loc,
        warmup_loc,
        1,
        1,
        BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
        num_warps=self._kv_copy_config["num_warps"],
        num_stages=2,
    )
def _create_buffers(self): def _create_buffers(self):
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with ( with (
...@@ -642,13 +692,28 @@ class MHATokenToKVPool(KVCache): ...@@ -642,13 +692,28 @@ class MHATokenToKVPool(KVCache):
self.v_buffer[layer_id - self.start_layer][loc] = cache_v self.v_buffer[layer_id - self.start_layer][loc] = cache_v
def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
    """Copy KV-cache entries of all layers from `src_loc` to `tgt_loc`.

    Launches the tiled copy kernel with the configuration computed in
    `_init_kv_copy_and_warmup`.  Requires the pool to have been built with
    ``enable_kv_cache_copy=True``; a no-op when there is nothing to copy.

    Args:
        tgt_loc: destination token locations (int tensor on `self.device`).
        src_loc: source token locations, same length as `tgt_loc`.
    """
    N = tgt_loc.numel()
    # Early-out on empty input: launching a kernel with zero locations
    # would be wasted work.
    if N == 0:
        return

    # NOTE(review): assert (not raise) matches the original contract; it is
    # stripped under `python -O`.
    assert (
        self._kv_copy_config is not None
    ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"

    cfg = self._kv_copy_config
    # The kernel indexes locations with tl.arange, which needs a
    # power-of-two upper bound as a constexpr.
    N_upper = next_power_of_2(N)
    # One program per (layer buffer, byte tile).
    grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
    copy_all_layer_kv_cache_tiled[grid](
        self.data_ptrs,
        self.data_strides,
        tgt_loc,
        src_loc,
        N,
        N_upper,
        BYTES_PER_TILE=cfg["bytes_per_tile"],
        num_warps=cfg["num_warps"],
        num_stages=2,
    )
...@@ -1588,38 +1653,36 @@ class DoubleSparseTokenToKVPool(KVCache): ...@@ -1588,38 +1653,36 @@ class DoubleSparseTokenToKVPool(KVCache):
@triton.jit
def copy_all_layer_kv_cache_tiled(
    data_ptrs,
    strides,
    tgt_loc_ptr,
    src_loc_ptr,
    num_locs,
    num_locs_upper: tl.constexpr,
    BYTES_PER_TILE: tl.constexpr,
):
    """2D tiled kernel. Safe for in-place copy.

    Grid axis 0 selects the per-layer buffer, axis 1 selects a byte tile
    within a row.  All locations are handled inside one program (loads
    complete before stores), so overlapping src/tgt locations are safe.
    """
    bid = tl.program_id(0)
    tid = tl.program_id(1)
    stride = tl.load(strides + bid)

    # Reinterpret the raw buffer address as a byte pointer so the copy is
    # dtype-agnostic.
    base_ptr = tl.load(data_ptrs + bid)
    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))

    # Byte offsets covered by this tile; mask off the tail past the row end.
    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
    mask_byte = byte_off < stride
    # Alignment hint for vectorized loads/stores.
    tl.multiple_of(byte_off, 16)

    loc_idx = tl.arange(0, num_locs_upper)
    mask_loc = loc_idx < num_locs

    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)

    # (num_locs_upper, BYTES_PER_TILE) addressing: row = location, col = byte.
    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]

    mask = mask_loc[:, None] & mask_byte[None, :]
    vals = tl.load(src_ptr, mask=mask)
    tl.store(tgt_ptr, vals, mask=mask)
...@@ -1672,6 +1672,9 @@ class ModelRunner: ...@@ -1672,6 +1672,9 @@ class ModelRunner:
enable_memory_saver=self.server_args.enable_memory_saver, enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer, start_layer=self.start_layer,
end_layer=self.end_layer, end_layer=self.end_layer,
enable_kv_cache_copy=(
self.server_args.speculative_algorithm is not None
),
) )
# Initialize token_to_kv_pool_allocator # Initialize token_to_kv_pool_allocator
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment