"docs/vscode:/vscode.git/clone" did not exist on "b2605a8e64bdbc8ecd9933259caaff7b78307c7d"
Unverified Commit f2ad952f authored by Or Ozeri, committed by GitHub
Browse files

[BugFix][kv_offload]: Fix kernel block size detection (#35125)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 9e2cabdf
...@@ -259,16 +259,20 @@ class CpuGpuOffloadingHandlers: ...@@ -259,16 +259,20 @@ class CpuGpuOffloadingHandlers:
assert gpu_shape[0] == 2 assert gpu_shape[0] == 2
split_k_and_v = True split_k_and_v = True
try: if has_layers_dim:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( # in the cross layers case, the registered kv cache tensor
include_num_layers_dimension=has_layers_dim # shape matches the physical layout, whereas test_shape
) # is the logical layout.
assert len(kv_cache_stride_order) == len(gpu_shape) # To match them, we need to permute test_shape
except (AttributeError, NotImplementedError): try:
kv_cache_stride_order = tuple(range(len(gpu_shape))) kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=has_layers_dim
# permute test_shape according to stride_order )
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) assert len(kv_cache_stride_order) == len(gpu_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(gpu_shape)))
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
# find block_size (16) dimension index # find block_size (16) dimension index
block_size_idx = test_shape.index(16) block_size_idx = test_shape.index(16)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment