Unverified commit 421012b6, authored by Or Ozeri, committed by GitHub
Browse files

OffloadingConnector: Support kernel_block_size != block_size (#30692)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 841d53aa
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include <vector> #include <vector>
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping); const torch::Tensor& block_mapping);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
......
...@@ -25,6 +25,7 @@ typedef __hip_bfloat16 __nv_bfloat16; ...@@ -25,6 +25,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
torch::Device src_device = src.device(); torch::Device src_device = src.device();
torch::Device dst_device = dst.device(); torch::Device dst_device = dst.device();
...@@ -49,10 +50,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst, ...@@ -49,10 +50,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
char* src_ptr = static_cast<char*>(src.data_ptr()); char* src_ptr = static_cast<char*>(src.data_ptr());
char* dst_ptr = static_cast<char*>(dst.data_ptr()); char* dst_ptr = static_cast<char*>(dst.data_ptr());
// We use the stride instead of numel in case the cache is padded for memory
// alignment reasons, we assume the blocks data (inclusive of any padding)
// is contiguous in memory
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
const at::cuda::OptionalCUDAGuard device_guard( const at::cuda::OptionalCUDAGuard device_guard(
src_device.is_cuda() ? src_device : dst_device); src_device.is_cuda() ? src_device : dst_device);
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
......
...@@ -692,7 +692,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -692,7 +692,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
// Cache ops // Cache ops
// Swap in (out) the cache blocks from src to dst. // Swap in (out) the cache blocks from src to dst.
cache_ops.def( cache_ops.def(
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); "swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks); cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Reshape the key and value tensors and cache them. // Reshape the key and value tensors and cache them.
......
...@@ -405,19 +405,41 @@ def test_swap_blocks( ...@@ -405,19 +405,41 @@ def test_swap_blocks(
# Call the swap_blocks kernel. # Call the swap_blocks kernel.
do_opcheck = head_size == HEAD_SIZES[0] do_opcheck = head_size == HEAD_SIZES[0]
src_cache = src_key_caches[0]
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
opcheck( opcheck(
torch.ops._C_cache_ops.swap_blocks, torch.ops._C_cache_ops.swap_blocks,
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor), (
src_key_caches[0],
dist_key_caches[0],
block_size_in_bytes,
block_mapping_tensor,
),
cond=do_opcheck, cond=do_opcheck,
) )
opcheck( opcheck(
torch.ops._C_cache_ops.swap_blocks, torch.ops._C_cache_ops.swap_blocks,
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor), (
src_value_caches[0],
dist_value_caches[0],
block_size_in_bytes,
block_mapping_tensor,
),
cond=do_opcheck, cond=do_opcheck,
) )
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor) ops.swap_blocks(
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor) src_key_caches[0],
dist_key_caches[0],
block_size_in_bytes,
block_mapping_tensor,
)
ops.swap_blocks(
src_value_caches[0],
dist_value_caches[0],
block_size_in_bytes,
block_mapping_tensor,
)
for src, dst in block_mapping: for src, dst in block_mapping:
torch.testing.assert_close( torch.testing.assert_close(
...@@ -723,13 +745,14 @@ def test_swap_blocks_mla( ...@@ -723,13 +745,14 @@ def test_swap_blocks_mla(
block_mapping, dtype=torch.int64, device="cpu" block_mapping, dtype=torch.int64, device="cpu"
).view(-1, 2) ).view(-1, 2)
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
opcheck( opcheck(
torch.ops._C_cache_ops.swap_blocks, torch.ops._C_cache_ops.swap_blocks,
(src_cache, dst_cache, block_mapping_tensor), (src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
test_utils=DEFAULT_OPCHECK_TEST_UTILS, test_utils=DEFAULT_OPCHECK_TEST_UTILS,
) )
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor) ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
for src, dst in block_mapping: for src, dst in block_mapping:
torch.testing.assert_close( torch.testing.assert_close(
......
...@@ -25,8 +25,9 @@ if not current_platform.is_rocm(): ...@@ -25,8 +25,9 @@ if not current_platform.is_rocm():
NUM_GPU_BLOCKS = [64] NUM_GPU_BLOCKS = [64]
NUM_CPU_BLOCKS = [256] NUM_CPU_BLOCKS = [256]
GPU_BLOCK_SIZES = [16] KERNEL_BLOCK_SIZES = [16]
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3] LOGICAL_BLOCK_SIZES = [16, 32]
LOGICAL_BLOCKS_PER_CPU_BLOCK = [1, 3]
HEAD_SIZES = [64] HEAD_SIZES = [64]
NUM_HEADS = [8] NUM_HEADS = [8]
NUM_LAYERS = [4] NUM_LAYERS = [4]
...@@ -40,8 +41,9 @@ NUM_MAPPINGS = [3] ...@@ -40,8 +41,9 @@ NUM_MAPPINGS = [3]
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES) @pytest.mark.parametrize("kernel_block_size", KERNEL_BLOCK_SIZES)
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK) @pytest.mark.parametrize("logical_block_size", LOGICAL_BLOCK_SIZES)
@pytest.mark.parametrize("logical_blocks_per_cpu_block", LOGICAL_BLOCKS_PER_CPU_BLOCK)
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS) @pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS) @pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
@pytest.mark.parametrize("num_layers", NUM_LAYERS) @pytest.mark.parametrize("num_layers", NUM_LAYERS)
...@@ -55,8 +57,9 @@ def test_transfer( ...@@ -55,8 +57,9 @@ def test_transfer(
num_mappings: int, num_mappings: int,
head_size: int, head_size: int,
num_heads: int, num_heads: int,
gpu_block_size: int, kernel_block_size: int,
gpu_blocks_per_cpu_block: int, logical_block_size: int,
logical_blocks_per_cpu_block: int,
num_gpu_blocks: int, num_gpu_blocks: int,
num_cpu_blocks: int, num_cpu_blocks: int,
num_layers: int, num_layers: int,
...@@ -69,6 +72,10 @@ def test_transfer( ...@@ -69,6 +72,10 @@ def test_transfer(
# create per-layer GPU KV caches based on available attn_backends # create per-layer GPU KV caches based on available attn_backends
attn_backends_list = BACKENDS_TO_TEST attn_backends_list = BACKENDS_TO_TEST
assert logical_block_size % kernel_block_size == 0
kernel_blocks_per_gpu_block = logical_block_size // kernel_block_size
num_gpu_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
gpu_caches = {} gpu_caches = {}
attn_backends = {} attn_backends = {}
for i in range(num_layers): for i in range(num_layers):
...@@ -78,15 +85,16 @@ def test_transfer( ...@@ -78,15 +85,16 @@ def test_transfer(
attn_backends[layer_name] = attn_backend attn_backends[layer_name] = attn_backend
gpu_cache_shape = attn_backend.get_kv_cache_shape( gpu_cache_shape = attn_backend.get_kv_cache_shape(
num_gpu_blocks, gpu_block_size, num_heads, head_size num_gpu_kernel_blocks, kernel_block_size, num_heads, head_size
) )
gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device) gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device)
# create handler # create handler
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size cpu_block_size = logical_blocks_per_cpu_block * logical_block_size
kernel_blocks_per_cpu_block = cpu_block_size // kernel_block_size
handlers = CpuGpuOffloadingHandlers( handlers = CpuGpuOffloadingHandlers(
attn_backends=attn_backends, attn_backends=attn_backends,
gpu_block_size=gpu_block_size, gpu_block_size=logical_block_size,
cpu_block_size=cpu_block_size, cpu_block_size=cpu_block_size,
num_cpu_blocks=num_cpu_blocks, num_cpu_blocks=num_cpu_blocks,
gpu_caches=gpu_caches, gpu_caches=gpu_caches,
...@@ -94,22 +102,34 @@ def test_transfer( ...@@ -94,22 +102,34 @@ def test_transfer(
# select block mappings # select block mappings
gpu_blocks = random.sample( gpu_blocks = random.sample(
range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block range(num_gpu_blocks), num_mappings * logical_blocks_per_cpu_block
) )
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings) cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)
# convert gpu blocks to kernel block size
gpu_blocks_in_kernel_block_size = []
for gpu_block in gpu_blocks:
base_block_id = gpu_block * kernel_blocks_per_gpu_block
for i in range(kernel_blocks_per_gpu_block):
gpu_blocks_in_kernel_block_size.append(i + base_block_id)
# convert cpu blocks to gpu block size # convert cpu blocks to gpu block size
cpu_blocks_in_gpu_block_size = [] cpu_blocks_in_kernel_block_size = []
for cpu_block in cpu_blocks: for cpu_block in cpu_blocks:
base_block_id = cpu_block * gpu_blocks_per_cpu_block base_block_id = cpu_block * kernel_blocks_per_cpu_block
for i in range(gpu_blocks_per_cpu_block): for i in range(kernel_blocks_per_cpu_block):
cpu_blocks_in_gpu_block_size.append(i + base_block_id) cpu_blocks_in_kernel_block_size.append(i + base_block_id)
# maybe skip a GPU block to test reading from the middle of a CPU block # maybe skip some GPU block to test reading from the middle of a CPU block
if not gpu_to_cpu: if not gpu_to_cpu:
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :] gpu_blocks_to_skip = logical_blocks_per_cpu_block - 1
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[ gpu_blocks = gpu_blocks[gpu_blocks_to_skip:]
gpu_blocks_per_cpu_block - 1 : kernel_blocks_to_skip = gpu_blocks_to_skip * kernel_blocks_per_gpu_block
gpu_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size[
kernel_blocks_to_skip:
]
cpu_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size[
kernel_blocks_to_skip:
] ]
# set transfer direction # set transfer direction
...@@ -119,23 +139,23 @@ def test_transfer( ...@@ -119,23 +139,23 @@ def test_transfer(
dst_spec_class = CPULoadStoreSpec dst_spec_class = CPULoadStoreSpec
src_blocks = gpu_blocks src_blocks = gpu_blocks
dst_blocks = cpu_blocks dst_blocks = cpu_blocks
src_blocks_in_gpu_block_size = gpu_blocks src_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size dst_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block dst_size_in_kernel_blocks = num_cpu_blocks * kernel_blocks_per_cpu_block
else: else:
handler = handlers.cpu_to_gpu_handler handler = handlers.cpu_to_gpu_handler
src_spec_class = CPULoadStoreSpec src_spec_class = CPULoadStoreSpec
dst_spec_class = GPULoadStoreSpec dst_spec_class = GPULoadStoreSpec
src_blocks = cpu_blocks src_blocks = cpu_blocks
dst_blocks = gpu_blocks dst_blocks = gpu_blocks
src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size src_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
dst_blocks_in_gpu_block_size = gpu_blocks dst_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
dst_size_in_gpu_blocks = num_gpu_blocks dst_size_in_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
# build dst -> src mapping # build dst -> src mapping
dst_to_src = {} dst_to_src = {}
for src_block, dst_block in zip( for src_block, dst_block in zip(
src_blocks_in_gpu_block_size, dst_blocks_in_gpu_block_size src_blocks_in_kernel_block_size, dst_blocks_in_kernel_block_size
): ):
dst_to_src[dst_block] = src_block dst_to_src[dst_block] = src_block
...@@ -165,29 +185,15 @@ def test_transfer( ...@@ -165,29 +185,15 @@ def test_transfer(
assert torch.equal(orig_tensor, tensor) assert torch.equal(orig_tensor, tensor)
# verify dst tensors # verify dst tensors
for dst_block in range(dst_size_in_gpu_blocks): for dst_block in range(dst_size_in_kernel_blocks):
src_block_candidate = dst_to_src.get(dst_block) src_block_candidate = dst_to_src.get(dst_block)
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip( for src_cache, dst_cache, orig_dst_cache in zip(
handler.src_tensors, handler.src_tensors,
handler.dst_tensors, handler.dst_tensors,
orig_dst_caches, orig_dst_caches,
handler.kv_dim_before_num_blocks,
): ):
if kv_dim:
# iterate over key, value
for i in range(2):
if src_block_candidate is not None:
expected_value = src_cache[i][src_block_candidate]
else:
expected_value = orig_dst_cache[i][dst_block]
torch.testing.assert_close(
dst_cache[i][dst_block].cpu(), expected_value.cpu()
)
else:
if src_block_candidate is not None: if src_block_candidate is not None:
expected_value = src_cache[src_block_candidate] expected_value = src_cache[src_block_candidate]
else: else:
expected_value = orig_dst_cache[dst_block] expected_value = orig_dst_cache[dst_block]
torch.testing.assert_close( torch.testing.assert_close(dst_cache[dst_block].cpu(), expected_value.cpu())
dst_cache[dst_block].cpu(), expected_value.cpu()
)
...@@ -2455,9 +2455,32 @@ def concat_and_cache_mla_rope_fused( ...@@ -2455,9 +2455,32 @@ def concat_and_cache_mla_rope_fused(
def swap_blocks( def swap_blocks(
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor src: torch.Tensor,
dst: torch.Tensor,
block_size_in_bytes: int,
block_mapping: torch.Tensor,
) -> None: ) -> None:
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping) """
Copy specific blocks from one tensor to another.
This method assumes each of the two input tensors is composed of
consecutive contiguous blocks, of size block_size_in_bytes.
i.e. the memory layout for each tensor is:
[block0] [block1] ... [block N]
block_mapping determines the subset of blocks to copy of the source tensor,
and their matching destination block number on the destination tensor.
block_mapping is expected to be a tensor of shape (num_blocks_to_copy, 2)
where each block_mapping[i] represents a single copy operation, copying
block #block_mapping[i][0] from the source tensor
to block #block_mapping[i][1] on the destination tensor.
block_mapping should have dtype int64.
The source and the destination tensors can be either on cpu or gpu,
but not both on cpu.
The block mapping tensor must be on cpu.
"""
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
def convert_fp8( def convert_fp8(
......
...@@ -65,7 +65,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -65,7 +65,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
self, self,
src_tensors: list[torch.Tensor], src_tensors: list[torch.Tensor],
dst_tensors: list[torch.Tensor], dst_tensors: list[torch.Tensor],
kv_dim_before_num_blocks: list[bool],
src_block_size_factor: int, src_block_size_factor: int,
dst_block_size_factor: int, dst_block_size_factor: int,
): ):
...@@ -76,22 +75,23 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -76,22 +75,23 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
src_tensors: list of KV cache tensors to copy from. src_tensors: list of KV cache tensors to copy from.
dst_tensors: list of KV cache tensors to copy to. dst_tensors: list of KV cache tensors to copy to.
Order should match src_tensors. Order should match src_tensors.
kv_dim_before_num_blocks: list of bools, indicating
whether the respective KV cache tensor has a KV
dimension before its num_blocks dimension.
e.g. (2, num_blocks, ...)
src_block_size_factor: The number of kernel blocks src_block_size_factor: The number of kernel blocks
per KV block in a source tensor. per KV block in a source tensor.
dst_block_size_factor: The number of kernel blocks dst_block_size_factor: The number of kernel blocks
per KV block in a destination tensor. per KV block in a destination tensor.
""" """
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks) assert len(src_tensors) == len(dst_tensors)
self.src_tensors: list[torch.Tensor] = src_tensors self.src_tensors: list[torch.Tensor] = src_tensors
self.dst_tensors: list[torch.Tensor] = dst_tensors self.dst_tensors: list[torch.Tensor] = dst_tensors
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks min_block_size_factor = min(src_block_size_factor, dst_block_size_factor)
self.src_block_size_factor: int = src_block_size_factor self.src_block_size_factor: int = src_block_size_factor // min_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor self.dst_block_size_factor: int = dst_block_size_factor // min_block_size_factor
self.block_size_in_bytes = [
tensor.element_size() * tensor.stride(0) * min_block_size_factor
for tensor in src_tensors
]
assert len(src_tensors) > 0 assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
...@@ -142,16 +142,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler): ...@@ -142,16 +142,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
# assure job will start only after the previous one completes # assure job will start only after the previous one completes
stream.wait_event(last_event) stream.wait_event(last_event)
with torch.cuda.stream(stream): with torch.cuda.stream(stream):
for src_tensor, dst_tensor, kv_dim in zip( for src_tensor, dst_tensor, block_size_in_bytes in zip(
self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks self.src_tensors,
self.dst_tensors,
self.block_size_in_bytes,
): ):
if kv_dim: ops.swap_blocks(
src_key_cache, src_value_cache = src_tensor src_tensor,
dst_key_cache, dst_value_cache = dst_tensor dst_tensor,
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor) block_size_in_bytes,
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor) src_to_dst_tensor,
else: )
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
event.record(stream) event.record(stream)
self._transfer_events[job_id] = event self._transfer_events[job_id] = event
...@@ -188,19 +189,12 @@ class CpuGpuOffloadingHandlers: ...@@ -188,19 +189,12 @@ class CpuGpuOffloadingHandlers:
): ):
assert gpu_caches assert gpu_caches
assert cpu_block_size % gpu_block_size == 0 assert cpu_block_size % gpu_block_size == 0
block_size_factor = cpu_block_size // gpu_block_size
pin_memory = is_pin_memory_available() # find kernel block size and determine layout per each gpu tensor
# allocate cpu tensors
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
gpu_tensors: list[torch.Tensor] = []
cpu_tensors: list[torch.Tensor] = []
kv_dim_before_num_blocks: list[bool] = []
kernel_block_size: int | None = None kernel_block_size: int | None = None
# list of (gpu_tensor, split_k_and_v)
parsed_gpu_tensors: list[tuple[torch.Tensor, bool]] = []
for layer_name, gpu_tensor in gpu_caches.items(): for layer_name, gpu_tensor in gpu_caches.items():
gpu_tensors.append(gpu_tensor)
gpu_shape = gpu_tensor.shape gpu_shape = gpu_tensor.shape
attn_backend = attn_backends[layer_name] attn_backend = attn_backends[layer_name]
test_shape = attn_backend.get_kv_cache_shape( test_shape = attn_backend.get_kv_cache_shape(
...@@ -208,28 +202,20 @@ class CpuGpuOffloadingHandlers: ...@@ -208,28 +202,20 @@ class CpuGpuOffloadingHandlers:
) )
has_layers_dim = False has_layers_dim = False
split_k_and_v = False
if len(gpu_shape) != len(test_shape): if len(gpu_shape) != len(test_shape):
# cross-layers tensor # cross-layers tensor
# shape is (num_blocks, ...) # shape is (num_blocks, ...)
assert len(gpu_shape) == len(test_shape) + 1 assert len(gpu_shape) == len(test_shape) + 1
num_blocks_idx = 0
has_layers_dim = True has_layers_dim = True
kv_dim_before_num_blocks.append(False)
# prepend a dummy num_layers=80 to test_shape # prepend a dummy num_layers=80 to test_shape
test_shape = (80,) + test_shape test_shape = (80,) + test_shape
elif test_shape[0] == 1234: elif test_shape[0] != 1234:
# shape is (num_blocks, ...)
num_blocks_idx = 0
kv_dim_before_num_blocks.append(False)
else:
# shape should be (2, num_blocks, ...) # shape should be (2, num_blocks, ...)
assert test_shape[0] == 2 assert test_shape[0] == 2
assert test_shape[1] == 1234 assert test_shape[1] == 1234
assert gpu_shape[0] == 2 assert gpu_shape[0] == 2
split_k_and_v = True
num_blocks_idx = 1
kv_dim_before_num_blocks.append(True)
try: try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
...@@ -250,30 +236,36 @@ class CpuGpuOffloadingHandlers: ...@@ -250,30 +236,36 @@ class CpuGpuOffloadingHandlers:
kernel_block_size = gpu_shape[block_size_idx] kernel_block_size = gpu_shape[block_size_idx]
assert gpu_block_size % kernel_block_size == 0 assert gpu_block_size % kernel_block_size == 0
cpu_shape = list(gpu_shape) parsed_gpu_tensors.append((gpu_tensor, split_k_and_v))
cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor
assert kernel_block_size is not None
cpu_block_size_factor = cpu_block_size // kernel_block_size
gpu_block_size_factor = gpu_block_size // kernel_block_size
num_cpu_kernel_blocks = num_cpu_blocks * cpu_block_size_factor
# allocate cpu tensors
pin_memory = is_pin_memory_available()
logger.info("Allocating %d CPU tensors...", len(parsed_gpu_tensors))
gpu_tensors: list[torch.Tensor] = []
cpu_tensors: list[torch.Tensor] = []
for gpu_tensor, split_k_and_v in parsed_gpu_tensors:
cpu_shape = list(gpu_tensor.shape)
cpu_shape[1 if split_k_and_v else 0] = num_cpu_kernel_blocks
logger.debug("Allocating CPU tensor of shape %r", cpu_shape) logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
cpu_tensors.append( cpu_tensor = torch.zeros(
torch.zeros(
cpu_shape, cpu_shape,
dtype=gpu_tensor.dtype, dtype=gpu_tensor.dtype,
device="cpu", device="cpu",
pin_memory=pin_memory, pin_memory=pin_memory,
) )
)
assert kernel_block_size is not None
gpu_block_size_factor = gpu_block_size // kernel_block_size
cpu_block_size_factor = cpu_block_size // kernel_block_size
# TODO (orozery): adapt swap_blocks to support gpu_block_size_factor gpu_tensors.extend(gpu_tensor.unbind(0) if split_k_and_v else [gpu_tensor])
assert gpu_block_size_factor == 1 cpu_tensors.extend(cpu_tensor.unbind(0) if split_k_and_v else [cpu_tensor])
self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler( self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler(
src_tensors=gpu_tensors, src_tensors=gpu_tensors,
dst_tensors=cpu_tensors, dst_tensors=cpu_tensors,
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=gpu_block_size_factor, src_block_size_factor=gpu_block_size_factor,
dst_block_size_factor=cpu_block_size_factor, dst_block_size_factor=cpu_block_size_factor,
) )
...@@ -281,7 +273,6 @@ class CpuGpuOffloadingHandlers: ...@@ -281,7 +273,6 @@ class CpuGpuOffloadingHandlers:
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler( self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
src_tensors=cpu_tensors, src_tensors=cpu_tensors,
dst_tensors=gpu_tensors, dst_tensors=gpu_tensors,
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=cpu_block_size_factor, src_block_size_factor=cpu_block_size_factor,
dst_block_size_factor=gpu_block_size_factor, dst_block_size_factor=gpu_block_size_factor,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment