Unverified Commit 15ae8e07 authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[Bugfix][CI/Test][Spec Decode] Fix illegal memory access in...


[Bugfix][CI/Test][Spec Decode] Fix illegal memory access in offline_inference/spec_decode.py (Issue  27619) (#28432)
Signed-off-by: default avatarRandall Smith <ransmith@amd.com>
Co-authored-by: default avatarRandall Smith <ransmith@amd.com>
Co-authored-by: default avatarTJian <tunjian.tan@embeddedllm.com>
parent 0b254989
...@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash( ...@@ -97,7 +97,6 @@ def triton_reshape_and_cache_flash(
k_scale: torch.Tensor, # float32 k_scale: torch.Tensor, # float32
v_scale: torch.Tensor, # float32 v_scale: torch.Tensor, # float32
): ):
num_tokens = key.shape[0]
num_heads = key.shape[1] num_heads = key.shape[1]
head_size = key.shape[2] head_size = key.shape[2]
block_size = key_cache.shape[1] block_size = key_cache.shape[1]
...@@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash( ...@@ -155,7 +154,10 @@ def triton_reshape_and_cache_flash(
# TODO(ngl): maybe replace with static launch grid to avoid overhead if # TODO(ngl): maybe replace with static launch grid to avoid overhead if
# using cudagraphs # using cudagraphs
grid = lambda meta: (int(num_tokens), triton.cdiv(n, meta["TILE_SIZE"])) grid = lambda meta: (
slot_mapping.shape[0],
triton.cdiv(n, meta["TILE_SIZE"]),
)
reshape_and_cache_kernel_flash[grid]( reshape_and_cache_kernel_flash[grid](
key_ptr=key, key_ptr=key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment