Unverified commit 6cc7abdc, authored by Kfir Toledo, committed by GitHub
Browse files

[kv_offload+HMA] Fix num_blocks with different per-layer page sizes and...


[kv_offload+HMA] Fix num_blocks with different per-layer page sizes and improve assert message (#38554)
Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
parent d53cb9cb
...@@ -83,6 +83,8 @@ class OffloadingConnectorWorker: ...@@ -83,6 +83,8 @@ class OffloadingConnectorWorker:
if layer_name in layers if layer_name in layers
} }
num_blocks = self.spec.kv_cache_config.num_blocks
# layer_name -> list of matching KV cache tensors # layer_name -> list of matching KV cache tensors
# such that each tensor starts with the num_blocks dimension. # such that each tensor starts with the num_blocks dimension.
# FlashAttention layers which use the (2, num_blocks, ...) layout # FlashAttention layers which use the (2, num_blocks, ...) layout
...@@ -132,7 +134,6 @@ class OffloadingConnectorWorker: ...@@ -132,7 +134,6 @@ class OffloadingConnectorWorker:
num_blocks_logical_dim num_blocks_logical_dim
) )
if num_blocks_physical_dim == 0: if num_blocks_physical_dim == 0:
num_blocks = layer_kv_cache.shape[num_blocks_logical_dim]
storage = layer_kv_cache.untyped_storage() storage = layer_kv_cache.untyped_storage()
page = layer_kv_cache_spec.page_size_bytes page = layer_kv_cache_spec.page_size_bytes
tensors_per_block[layer_name] = ( tensors_per_block[layer_name] = (
...@@ -154,7 +155,6 @@ class OffloadingConnectorWorker: ...@@ -154,7 +155,6 @@ class OffloadingConnectorWorker:
assert num_blocks_physical_dim == 1 assert num_blocks_physical_dim == 1
# unbind the tensor to separate K and V tensors # unbind the tensor to separate K and V tensors
num_blocks = layer_kv_cache.shape[num_blocks_logical_dim]
half_page_size = layer_kv_cache_spec.page_size_bytes // 2 half_page_size = layer_kv_cache_spec.page_size_bytes // 2
storage = layer_kv_cache.untyped_storage() storage = layer_kv_cache.untyped_storage()
raw = ( raw = (
...@@ -181,7 +181,6 @@ class OffloadingConnectorWorker: ...@@ -181,7 +181,6 @@ class OffloadingConnectorWorker:
assert len(state_tensors) > 0 assert len(state_tensors) > 0
first_state_tensor = state_tensors[0] first_state_tensor = state_tensors[0]
assert first_state_tensor.storage_offset() == 0 assert first_state_tensor.storage_offset() == 0
num_blocks = first_state_tensor.shape[0]
tensor = ( tensor = (
torch.tensor( torch.tensor(
[], [],
......
...@@ -93,7 +93,12 @@ class OffloadingSpec(ABC): ...@@ -93,7 +93,12 @@ class OffloadingSpec(ABC):
) )
for block_size in self.gpu_block_size: for block_size in self.gpu_block_size:
assert block_size % self.hash_block_size == 0 assert block_size % self.hash_block_size == 0, (
f"gpu_block_size={block_size} not divisible by "
f"hash_block_size={self.hash_block_size}. "
f"Hybrid models (e.g. Mamba+Attention) need "
f"--enable-prefix-caching to align block sizes."
)
# offloaded_block_size / gpu_block_size # offloaded_block_size / gpu_block_size
self.block_size_factor: int = 1 self.block_size_factor: int = 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment