"vscode:/vscode.git/clone" did not exist on "85526e34a18d806d9447f8f21142cab2e07a0229"
Unverified Commit 5a9170d9 authored by YAMY, committed by GitHub

Optimize copy_kv_cache for spec decoding (#11126)


Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
parent c4d77774
@@ -415,6 +415,7 @@ class MHATokenToKVPool(KVCache):
         enable_memory_saver: bool,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        enable_kv_cache_copy: bool = False,
     ):
         super().__init__(
             size,
@@ -446,8 +447,57 @@ class MHATokenToKVPool(KVCache):
         self.device_module = torch.get_device_module(self.device)
         self.alt_stream = self.device_module.Stream() if _is_cuda else None
+
+        if enable_kv_cache_copy:
+            self._init_kv_copy_and_warmup()
+        else:
+            self._kv_copy_config = None
+
         self._finalize_allocation_log(size)

+    def _init_kv_copy_and_warmup(self):
+        # Heuristics for KV copy tiling
+        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
+        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
+        _KV_COPY_TILE_SIZE_LARGE = 512
+        _KV_COPY_TILE_SIZE_MEDIUM = 256
+        _KV_COPY_TILE_SIZE_SMALL = 128
+        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
+        _KV_COPY_NUM_WARPS_SMALL_TILE = 4
+
+        stride_bytes = int(self.data_strides[0].item())
+
+        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
+        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
+        else:
+            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL
+
+        self._kv_copy_config = {
+            "bytes_per_tile": bytes_per_tile,
+            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
+            "num_warps": (
+                _KV_COPY_NUM_WARPS_SMALL_TILE
+                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
+                else _KV_COPY_NUM_WARPS_LARGE_TILE
+            ),
+        }
+
+        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
+        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])
+        copy_all_layer_kv_cache_tiled[grid](
+            self.data_ptrs,
+            self.data_strides,
+            dummy_loc,
+            dummy_loc,
+            1,
+            1,
+            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
+            num_warps=self._kv_copy_config["num_warps"],
+            num_stages=2,
+        )
+
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             with (
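
For intuition, the following standalone sketch (not part of the diff) replays the tile-size heuristic above for a hypothetical buffer layout. The per-token stride of one K or V buffer is head_num × head_dim × dtype-size bytes; the head count, head dim, and fp16 width below are illustrative assumptions.

```python
# Standalone sketch of the tiling heuristic in _init_kv_copy_and_warmup.
# All model numbers below are illustrative assumptions, not values from the PR.

def pick_kv_copy_config(stride_bytes: int) -> dict:
    # Same thresholds as the diff: larger rows get larger tiles and more warps.
    if stride_bytes >= 8192:
        bytes_per_tile = 512
    elif stride_bytes >= 4096:
        bytes_per_tile = 256
    else:
        bytes_per_tile = 128
    return {
        "bytes_per_tile": bytes_per_tile,
        "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
        "num_warps": 4 if bytes_per_tile <= 256 else 8,
    }

# Example: 8 KV heads x head_dim 128 x 2 bytes (fp16) = 2048 bytes per token row.
print(pick_kv_copy_config(8 * 128 * 2))
# -> {'bytes_per_tile': 128, 'byte_tiles': 16, 'num_warps': 4}
```

The dummy single-location launch at the end of `_init_kv_copy_and_warmup` exists so Triton JIT-compiles the kernel during server startup rather than on the first real `move_kv_cache` call.
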
@@ -642,13 +692,28 @@ class MHATokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v

     def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
-        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
+        N = tgt_loc.numel()
+        if N == 0:
+            return
+
+        assert (
+            self._kv_copy_config is not None
+        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"
+
+        cfg = self._kv_copy_config
+        N_upper = next_power_of_2(N)
+        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])
+
+        copy_all_layer_kv_cache_tiled[grid](
             self.data_ptrs,
             self.data_strides,
             tgt_loc,
             src_loc,
-            len(tgt_loc),
-            next_power_of_2(len(tgt_loc)),
+            N,
+            N_upper,
+            BYTES_PER_TILE=cfg["bytes_per_tile"],
+            num_warps=cfg["num_warps"],
+            num_stages=2,
         )
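
Semantically, `move_kv_cache` is a row-level gather/scatter across every K and V buffer. `N_upper = next_power_of_2(N)` is needed because `tl.arange` requires a power-of-two extent at compile time, and rounding up also caps how many kernel specializations Triton compiles. A rough pure-PyTorch reference, useful only as a mental model or in a unit test (the buffer shapes are made up):

```python
import torch

def move_kv_cache_reference(buffers, tgt_loc, src_loc):
    # Copy the rows at src_loc into the rows at tgt_loc of each layer buffer.
    for buf in buffers:
        # Advanced indexing materializes buf[src_loc] as a temporary copy, so
        # overlapping src/tgt location sets stay correct -- mirroring the
        # Triton kernel, which reads a whole tile before storing it.
        buf[tgt_loc] = buf[src_loc]

buffers = [torch.randn(16, 2, 4) for _ in range(3)]  # toy (tokens, heads, dim)
src = torch.tensor([0, 1, 2])
tgt = torch.tensor([1, 2, 3])  # overlaps src
move_kv_cache_reference(buffers, tgt, src)
```
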
@@ -1588,38 +1653,36 @@ class DoubleSparseTokenToKVPool(KVCache):
 @triton.jit
-def copy_all_layer_kv_cache(
+def copy_all_layer_kv_cache_tiled(
     data_ptrs,
     strides,
     tgt_loc_ptr,
     src_loc_ptr,
     num_locs,
     num_locs_upper: tl.constexpr,
+    BYTES_PER_TILE: tl.constexpr,
 ):
-    BLOCK_SIZE: tl.constexpr = 128
-
+    """2D tiled kernel. Safe for in-place copy."""
     bid = tl.program_id(0)
+    tid = tl.program_id(1)

     stride = tl.load(strides + bid)
-    data_ptr = tl.load(data_ptrs + bid)
-    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))
+    base_ptr = tl.load(data_ptrs + bid)
+    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))

+    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
+    mask_byte = byte_off < stride
+    tl.multiple_of(byte_off, 16)
+
-    num_locs_offset = tl.arange(0, num_locs_upper)
-    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
-    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
+    loc_idx = tl.arange(0, num_locs_upper)
+    mask_loc = loc_idx < num_locs

     # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
     # because this copy is an inplace operation.
+    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
+    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)

-    num_loop = tl.cdiv(stride, BLOCK_SIZE)
-    for i in range(num_loop):
-        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
-        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
-        value = tl.load(
-            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
-        )
-        tl.store(
-            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
-            value,
-            mask=mask,
-        )
+    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
+    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]
+
+    mask = mask_loc[:, None] & mask_byte[None, :]
+    vals = tl.load(src_ptr, mask=mask)
+    tl.store(tgt_ptr, vals, mask=mask)
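
The launch grid is two-dimensional: axis 0 enumerates the flattened per-layer K and V buffer pointers, axis 1 the byte tiles within one token row. Token locations stay inside a single program (per the NOTE above), since an in-place copy split across independent CUDA blocks could overwrite a row another block has not yet read; the tile axis instead recovers the parallelism the old kernel spent in its serial `BLOCK_SIZE` loop. Illustrative arithmetic with assumed model numbers:

```python
# Illustrative grid-size arithmetic for copy_all_layer_kv_cache_tiled.
# num_layers and the stride are assumptions, not values from the PR.
num_layers = 32
num_buffers = 2 * num_layers          # separate K and V buffers per layer
stride_bytes, bytes_per_tile = 2048, 128
byte_tiles = (stride_bytes + bytes_per_tile - 1) // bytes_per_tile   # 16
grid = (num_buffers, byte_tiles)      # (64, 16) -> 1024 independent programs
print(grid)
```
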
@@ -1672,6 +1672,9 @@ class ModelRunner:
             enable_memory_saver=self.server_args.enable_memory_saver,
             start_layer=self.start_layer,
             end_layer=self.end_layer,
+            enable_kv_cache_copy=(
+                self.server_args.speculative_algorithm is not None
+            ),
         )

         # Initialize token_to_kv_pool_allocator
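
The `ModelRunner` hunk wires the flag: the copy path, and its warmup compile, is enabled only when a speculative algorithm is configured, so non-speculative servers pay no extra startup cost. A trivial restatement of the predicate (the algorithm name is hypothetical):

```python
# Hedged restatement of the gating logic above; "EAGLE" is just an example value.
speculative_algorithm = "EAGLE"   # None when speculative decoding is disabled
enable_kv_cache_copy = speculative_algorithm is not None
assert enable_kv_cache_copy      # warmup + tiled copy path would be enabled
```
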