"examples/basics/kubernetes/vscode:/vscode.git/clone" did not exist on "b6f31b4114275acbc32b82640183d783710986cf"
Unverified Commit f2ad952f authored by Or Ozeri, committed by GitHub
Browse files

[BugFix][kv_offload]: Fix kernel block size detection (#35125)


Signed-off-by: Or Ozeri <oro@il.ibm.com>
parent 9e2cabdf
...@@ -259,6 +259,11 @@ class CpuGpuOffloadingHandlers: ...@@ -259,6 +259,11 @@ class CpuGpuOffloadingHandlers:
assert gpu_shape[0] == 2 assert gpu_shape[0] == 2
split_k_and_v = True split_k_and_v = True
if has_layers_dim:
# in the cross layers case, the registered kv cache tensor
# shape matches the physical layout, whereas test_shape
# is the logical layout.
# To match them, we need to permute test_shape
try: try:
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order( kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=has_layers_dim include_num_layers_dimension=has_layers_dim
...@@ -267,7 +272,6 @@ class CpuGpuOffloadingHandlers: ...@@ -267,7 +272,6 @@ class CpuGpuOffloadingHandlers:
except (AttributeError, NotImplementedError): except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(gpu_shape))) kv_cache_stride_order = tuple(range(len(gpu_shape)))
# permute test_shape according to stride_order
test_shape = tuple(test_shape[i] for i in kv_cache_stride_order) test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)
# find block_size (16) dimension index # find block_size (16) dimension index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment