Unverified commit 9e96f56e, authored by Shu Wang, committed by GitHub
Browse files

Allocate kv_cache with stride order (#16605)


Signed-off-by: shuw <shuw@nvidia.com>
parent b2789112
...@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -270,9 +270,10 @@ __global__ void reshape_and_cache_flash_kernel(
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
// head_size] // head_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens] const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, const int key_stride, const int value_stride, const int64_t block_stride, const int64_t page_stride,
const int num_heads, const int head_size, const int block_size, const int64_t head_stride, const int64_t key_stride,
const float* k_scale, const float* v_scale) { const int64_t value_stride, const int num_heads, const int head_size,
const int block_size, const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x; const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx]; const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded // NOTE: slot_idx can be -1 if the token is padded
...@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel( ...@@ -288,8 +289,8 @@ __global__ void reshape_and_cache_flash_kernel(
const int head_idx = i / head_size; const int head_idx = i / head_size;
const int head_offset = i % head_size; const int head_offset = i % head_size;
const int64_t tgt_key_value_idx = block_idx * block_stride + const int64_t tgt_key_value_idx = block_idx * block_stride +
block_offset * num_heads * head_size + block_offset * page_stride +
head_idx * head_size + head_offset; head_idx * head_stride + head_offset;
scalar_t tgt_key = key[src_key_idx]; scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx]; scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
...@@ -396,16 +397,16 @@ void reshape_and_cache( ...@@ -396,16 +397,16 @@ void reshape_and_cache(
// KV_T is the data type of key and value tensors. // KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache. // CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ #define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \ vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \ <<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \ reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \ reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \ slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
value_stride, num_heads, head_size, block_size, \ head_stride, key_stride, value_stride, num_heads, head_size, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \ block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr())); reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_flash( void reshape_and_cache_flash(
...@@ -432,9 +433,11 @@ void reshape_and_cache_flash( ...@@ -432,9 +433,11 @@ void reshape_and_cache_flash(
int head_size = key.size(2); int head_size = key.size(2);
int block_size = key_cache.size(1); int block_size = key_cache.size(1);
int key_stride = key.stride(0); int64_t key_stride = key.stride(0);
int value_stride = value.stride(0); int64_t value_stride = value.stride(0);
int block_stride = key_cache.stride(0); int64_t block_stride = key_cache.stride(0);
int64_t page_stride = key_cache.stride(1);
int64_t head_stride = key_cache.stride(2);
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
dim3 grid(num_tokens); dim3 grid(num_tokens);
......
...@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing ...@@ -16,6 +16,7 @@ NUM_LAYERS = [1] # Arbitrary values for testing
NUM_HEADS = [8] # Arbitrary values for testing NUM_HEADS = [8] # Arbitrary values for testing
HEAD_SIZES = [64, 80, 120, 256] HEAD_SIZES = [64, 80, 120, 256]
BLOCK_SIZES = [8, 16, 32] BLOCK_SIZES = [8, 16, 32]
CACHE_LAYOUTS = ["NHD", "HND"]
# Parameters for MLA tests. # Parameters for MLA tests.
KV_LORA_RANKS = [512] KV_LORA_RANKS = [512]
...@@ -220,6 +221,7 @@ def test_reshape_and_cache( ...@@ -220,6 +221,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("kv_cache_layout", CACHE_LAYOUTS)
@torch.inference_mode() @torch.inference_mode()
def test_reshape_and_cache_flash( def test_reshape_and_cache_flash(
kv_cache_factory_flashinfer, kv_cache_factory_flashinfer,
...@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash( ...@@ -232,17 +234,21 @@ def test_reshape_and_cache_flash(
seed: int, seed: int,
device: str, device: str,
kv_cache_dtype: str, kv_cache_dtype: str,
kv_cache_layout: str,
) -> None: ) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device(device) torch.set_default_device(device)
# fp8 conversion requires a contiguous memory buffer. Reduce the number of
# blocks and tokens to consume less memory.
num_tokens = num_tokens // 2
num_blocks = num_blocks // 2
# Create a random slot mapping. # Create a random slot mapping.
num_slots = block_size * num_blocks num_slots = block_size * num_blocks
slot_mapping_lst = random.sample(range(num_slots), num_tokens) slot_mapping_lst = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping_lst, slot_mapping = torch.tensor(slot_mapping_lst,
dtype=torch.long, dtype=torch.long,
device=device) device=device)
qkv = torch.randn(num_tokens, qkv = torch.randn(num_tokens,
3, 3,
num_heads, num_heads,
...@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash( ...@@ -261,27 +267,35 @@ def test_reshape_and_cache_flash(
kv_cache_dtype, kv_cache_dtype,
dtype, dtype,
device=device, device=device,
cache_layout=kv_cache_layout,
) )
key_cache, value_cache = key_caches[0].contiguous( key_cache, value_cache = key_caches[0], value_caches[0]
), value_caches[0].contiguous()
del key_caches del key_caches
del value_caches del value_caches
k_scale = (key.amax() / 64.0).to(torch.float32) k_scale = (key.amax() / 64.0).to(torch.float32)
v_scale = (value.amax() / 64.0).to(torch.float32) v_scale = (value.amax() / 64.0).to(torch.float32)
def permute_and_compact(x):
    """Return ``x`` as a contiguous tensor in the compact comparison layout.

    Reads ``kv_cache_layout`` from the enclosing test scope. For "NHD" the
    tensor is taken as-is; for any other layout dims 1 and 2 are swapped
    first. In both cases the result is made contiguous.
    """
    if kv_cache_layout == "NHD":
        return x.contiguous()
    # Non-NHD layout: swap dims 1 and 2 before compacting.
    return x.permute(0, 2, 1, 3).contiguous()
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
cloned_key_cache = torch.empty_like(key_cache, dtype=torch.float16) cloned_key_cache = torch.empty_like(key_cache_compact,
ops.convert_fp8(cloned_key_cache, key_cache, k_scale.item(), dtype=torch.float16)
kv_cache_dtype) ops.convert_fp8(cloned_key_cache, key_cache_compact, k_scale.item(),
cloned_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache, v_scale.item(),
kv_cache_dtype) kv_cache_dtype)
cloned_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(cloned_value_cache, value_cache_compact,
v_scale.item(), kv_cache_dtype)
else: else:
cloned_key_cache = key_cache.clone() cloned_key_cache = key_cache_compact.clone()
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache_compact.clone()
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash, opcheck(torch.ops._C_cache_ops.reshape_and_cache_flash,
(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype,
...@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash( ...@@ -289,16 +303,20 @@ def test_reshape_and_cache_flash(
cond=(head_size == HEAD_SIZES[0])) cond=(head_size == HEAD_SIZES[0]))
ops.reshape_and_cache_flash(key, value, key_cache, value_cache, ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
slot_mapping, kv_cache_dtype, k_scale, v_scale) slot_mapping, kv_cache_dtype, k_scale, v_scale)
key_cache_compact = permute_and_compact(key_cache)
value_cache_compact = permute_and_compact(value_cache)
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) result_key_cache = torch.empty_like(key_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_key_cache, ops.convert_fp8(result_key_cache,
key_cache, key_cache_compact,
k_scale.item(), k_scale.item(),
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) result_value_cache = torch.empty_like(value_cache_compact,
dtype=torch.float16)
ops.convert_fp8(result_value_cache, ops.convert_fp8(result_value_cache,
value_cache, value_cache_compact,
v_scale.item(), v_scale.item(),
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
...@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash( ...@@ -310,8 +328,12 @@ def test_reshape_and_cache_flash(
for i in range(num_tokens): for i in range(num_tokens):
block_idx = block_indicies_lst[i] block_idx = block_indicies_lst[i]
block_offset = block_offsets_lst[i] block_offset = block_offsets_lst[i]
cloned_key_cache[block_idx, block_offset, :, :] = key[i] if kv_cache_layout == "NHD":
cloned_value_cache[block_idx, block_offset, :, :] = value[i] cloned_key_cache[block_idx, block_offset, :, :] = key[i]
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
else:
cloned_key_cache[block_idx, :, block_offset, :] = key[i]
cloned_value_cache[block_idx, :, block_offset, :] = value[i]
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
torch.testing.assert_close(result_key_cache, torch.testing.assert_close(result_key_cache,
...@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash( ...@@ -323,8 +345,8 @@ def test_reshape_and_cache_flash(
atol=0.001, atol=0.001,
rtol=0.1) rtol=0.1)
else: else:
torch.testing.assert_close(key_cache, cloned_key_cache) torch.testing.assert_close(key_cache_compact, cloned_key_cache)
torch.testing.assert_close(value_cache, cloned_value_cache) torch.testing.assert_close(value_cache_compact, cloned_value_cache)
@pytest.mark.parametrize("direction", COPYING_DIRECTION) @pytest.mark.parametrize("direction", COPYING_DIRECTION)
......
...@@ -77,6 +77,10 @@ class AttentionBackend(ABC): ...@@ -77,6 +77,10 @@ class AttentionBackend(ABC):
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
raise NotImplementedError raise NotImplementedError
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
    """Return the backend's KV-cache dimension ordering.

    The base implementation provides no ordering and always raises;
    backends that support a custom ordering override this method.

    Raises:
        NotImplementedError: always, in this base implementation.
    """
    raise NotImplementedError
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def swap_blocks( def swap_blocks(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import dataclasses import dataclasses
import os
from collections import defaultdict from collections import defaultdict
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
...@@ -48,6 +49,9 @@ if TYPE_CHECKING: ...@@ -48,6 +49,9 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
"NHD").upper()
class FlashInferBackend(AttentionBackend): class FlashInferBackend(AttentionBackend):
...@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend): ...@@ -80,6 +84,14 @@ class FlashInferBackend(AttentionBackend):
) -> Tuple[int, ...]: ) -> Tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size) return (num_blocks, 2, block_size, num_kv_heads, head_size)
@staticmethod
def get_kv_cache_stride_order() -> Tuple[int, ...]:
    """Return the dimension permutation for the configured cache layout.

    The indices refer to the generic shape returned by
    ``get_kv_cache_shape``:
    ``(num_blocks, 2, block_size, num_kv_heads, head_size)``.
    "NHD" keeps that order; "HND" swaps the ``block_size`` and
    ``num_kv_heads`` dimensions.

    Returns:
        A 5-tuple of dimension indices.

    Raises:
        ValueError: if ``FLASHINFER_KV_CACHE_LAYOUT`` is neither "NHD"
            nor "HND".
    """
    cache_layout = FLASHINFER_KV_CACHE_LAYOUT
    if cache_layout == "NHD":
        return (0, 1, 2, 3, 4)
    if cache_layout == "HND":
        return (0, 1, 3, 2, 4)
    # The layout comes from an environment variable, so validate it with an
    # explicit check: an ``assert`` is stripped under ``python -O`` and an
    # unknown value would then silently be treated as "HND".
    raise ValueError(
        "FLASHINFER_KV_CACHE_LAYOUT must be 'NHD' or 'HND', got "
        f"{cache_layout!r}")
@staticmethod @staticmethod
def swap_blocks( def swap_blocks(
src_kv_cache: torch.Tensor, src_kv_cache: torch.Tensor,
...@@ -188,6 +200,7 @@ class FlashInferState(AttentionState): ...@@ -188,6 +200,7 @@ class FlashInferState(AttentionState):
self.global_hyperparameters: Optional[PerLayerParameters] = None self.global_hyperparameters: Optional[PerLayerParameters] = None
self.vllm_config = self.runner.vllm_config self.vllm_config = self.runner.vllm_config
self._kv_cache_layout = None
def _get_workspace_buffer(self): def _get_workspace_buffer(self):
if self._workspace_buffer is None: if self._workspace_buffer is None:
...@@ -197,10 +210,15 @@ class FlashInferState(AttentionState): ...@@ -197,10 +210,15 @@ class FlashInferState(AttentionState):
device=self.runner.device) device=self.runner.device)
return self._workspace_buffer return self._workspace_buffer
def get_kv_cache_layout(self):
    """Return the KV-cache layout string, caching it on first use.

    On the first call the value is taken from the module-level
    ``FLASHINFER_KV_CACHE_LAYOUT`` and stored in ``self._kv_cache_layout``;
    later calls return the stored value.
    """
    layout = self._kv_cache_layout
    if layout is None:
        # First use: populate the per-instance cache from the module constant.
        layout = FLASHINFER_KV_CACHE_LAYOUT
        self._kv_cache_layout = layout
    return layout
def _get_prefill_wrapper(self): def _get_prefill_wrapper(self):
if self._prefill_wrapper is None: if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
self._get_workspace_buffer(), "NHD") self._get_workspace_buffer(), self.get_kv_cache_layout())
return self._prefill_wrapper return self._prefill_wrapper
def _get_decode_wrapper(self): def _get_decode_wrapper(self):
...@@ -213,7 +231,7 @@ class FlashInferState(AttentionState): ...@@ -213,7 +231,7 @@ class FlashInferState(AttentionState):
num_qo_heads // num_kv_heads > 4) num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(), self._get_workspace_buffer(),
"NHD", self.get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores) use_tensor_cores=use_tensor_cores)
return self._decode_wrapper return self._decode_wrapper
...@@ -274,7 +292,8 @@ class FlashInferState(AttentionState): ...@@ -274,7 +292,8 @@ class FlashInferState(AttentionState):
self._graph_decode_wrapper = \ self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper( CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD", self._graph_indices_buffer, _last_page_len_buffer,
self.get_kv_cache_layout(),
use_tensor_cores) use_tensor_cores)
if self.runner.kv_cache_dtype.startswith("fp8"): if self.runner.kv_cache_dtype.startswith("fp8"):
kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
...@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1005,6 +1024,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output: Optional[torch.Tensor] = None prefill_output: Optional[torch.Tensor] = None
decode_output: Optional[torch.Tensor] = None decode_output: Optional[torch.Tensor] = None
stride_order = FlashInferBackend.get_kv_cache_stride_order()
if prefill_meta := attn_metadata.prefill_metadata: if prefill_meta := attn_metadata.prefill_metadata:
# We will use flash attention for prefill # We will use flash attention for prefill
# when kv_cache is not provided. # when kv_cache is not provided.
...@@ -1036,7 +1056,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1036,7 +1056,7 @@ class FlashInferImpl(AttentionImpl):
prefill_output = prefill_meta.prefill_wrapper.run( prefill_output = prefill_meta.prefill_wrapper.run(
query, query,
kv_cache, kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float, k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float, v_scale=layer._v_scale_float,
) )
...@@ -1051,7 +1071,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -1051,7 +1071,7 @@ class FlashInferImpl(AttentionImpl):
decode_output = decode_meta.decode_wrapper.run( decode_output = decode_meta.decode_wrapper.run(
decode_query, decode_query,
kv_cache, kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float, k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float, v_scale=layer._v_scale_float,
) )
......
...@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash( ...@@ -765,21 +765,28 @@ def create_kv_caches_with_random_flash(
model_dtype: Optional[Union[str, torch.dtype]] = None, model_dtype: Optional[Union[str, torch.dtype]] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
device: Optional[str] = "cuda", device: Optional[str] = "cuda",
cache_layout: Optional[str] = "NHD",
) -> tuple[list[torch.Tensor], list[torch.Tensor]]: ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
from vllm.platforms import current_platform from vllm.platforms import current_platform
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) generic_kv_cache_shape = (num_blocks, 2, block_size, num_heads, head_size)
assert cache_layout in ("NHD", "HND")
stride_order = (0, 1, 2, 3, 4) if cache_layout == "NHD" else (0, 1, 3, 2,
4)
kv_cache_allocation_shape = tuple(generic_kv_cache_shape[i]
for i in stride_order)
scale = head_size**-0.5 scale = head_size**-0.5
key_caches: list[torch.Tensor] = [] key_caches: list[torch.Tensor] = []
value_caches: list[torch.Tensor] = [] value_caches: list[torch.Tensor] = []
for _ in range(num_layers): for _ in range(num_layers):
key_value_cache = torch.empty(size=key_value_cache_shape, key_value_cache = torch.empty(size=kv_cache_allocation_shape,
dtype=torch_dtype, dtype=torch_dtype,
device=device) device=device).permute(*stride_order)
if cache_dtype in ["auto", "half", "bfloat16", "float"]: if cache_dtype in ["auto", "half", "bfloat16", "float"]:
key_value_cache.uniform_(-scale, scale) key_value_cache.uniform_(-scale, scale)
elif cache_dtype == 'fp8': elif cache_dtype == 'fp8':
......
...@@ -71,19 +71,32 @@ class CacheEngine: ...@@ -71,19 +71,32 @@ class CacheEngine:
device: str, device: str,
) -> List[torch.Tensor]: ) -> List[torch.Tensor]:
"""Allocates KV cache on the specified device.""" """Allocates KV cache on the specified device."""
kv_cache_shape = self.attn_backend.get_kv_cache_shape( kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, self.block_size, self.num_kv_heads, self.head_size) num_blocks, self.block_size, self.num_kv_heads, self.head_size)
pin_memory = is_pin_memory_available() if device == "cpu" else False pin_memory = is_pin_memory_available() if device == "cpu" else False
kv_cache: List[torch.Tensor] = [] kv_cache: List[torch.Tensor] = []
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(kv_cache_generic_shape)))
# The allocation respects the backend-defined stride order to ensure
# the semantics remain consistent for each backend. We first obtain the
# generic kv cache shape and then permute it according to the stride
# order, which could result in a non-contiguous tensor.
kv_cache_allocation_shape = tuple(kv_cache_generic_shape[i]
for i in kv_cache_stride_order)
for _ in range(self.num_attention_layers): for _ in range(self.num_attention_layers):
# null block in CpuGpuBlockAllocator requires at least that # null block in CpuGpuBlockAllocator requires at least that
# block to be zeroed-out. # block to be zeroed-out.
# We zero-out everything for simplicity. # We zero-out everything for simplicity.
layer_kv_cache = torch.zeros(kv_cache_shape, layer_kv_cache = torch.zeros(
dtype=self.dtype, kv_cache_allocation_shape,
pin_memory=pin_memory, dtype=self.dtype,
device=device) pin_memory=pin_memory,
device=device).permute(*kv_cache_stride_order)
# view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases # view back to (TOTAL_PAGES, PAGE_SIZE, entry_shape...) for cases
# when entry_shape is higher than 1D # when entry_shape is higher than 1D
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment