"tests/vscode:/vscode.git/clone" did not exist on "c29fb540ff90da720490daae58bb4bfe31a91125"
Unverified commit 98217b09, authored by ElizaWszola and committed by GitHub
Browse files

[Performance] Extract KV cache update op from flashinfer forward (#35422)


Signed-off-by: ElizaWszola <ewszola@redhat.com>
parent 967572dd
......@@ -381,6 +381,8 @@ class FlashInferBackend(AttentionBackend):
return "HND"
return None
forward_includes_kv_cache_update: bool = False
@dataclass
class FIPrefill:
......@@ -1330,32 +1332,15 @@ class FlashInferImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_sharing_target_layer_name is None and self.kv_cache_dtype.startswith(
"fp8"
):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype
)
kv_cache = kv_cache.view(torch_dtype)
kv_cache = kv_cache.view(torch_dtype)
# Inputs and outputs may be padded for CUDA graphs
query = query[:num_actual_tokens]
......@@ -1599,6 +1584,33 @@ class FlashInferImpl(AttentionImpl):
)
return output_padded
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Store the new keys and values into the KV cache.

    This is a no-op when ``self.kv_sharing_target_layer_name`` is set,
    i.e. when this layer shares its KV cache with an earlier attention
    layer.

    Args:
        layer: Module supplying the ``_k_scale`` and ``_v_scale`` values
            that are forwarded to the cache op.
        key: Keys to store; may be padded (see the NOTE below).
        value: Values to store; may be padded (see the NOTE below).
        kv_cache: Cache tensor; ``kv_cache[:, 0]`` and ``kv_cache[:, 1]``
            are handed to the op as the two cache destinations
            (presumably the key and value planes -- confirm against the
            op's signature).
        slot_mapping: Unpadded slot mapping; its shape tells the op how
            many tokens are actually written.
    """
    # Skip the update if sharing KV cache with an earlier attention layer.
    if self.kv_sharing_target_layer_name is not None:
        return

    # Reshape the input keys and values and store them in the cache.
    # NOTE(woosuk): Here, key and value are padded while slot_mapping is
    # not padded. However, we don't need to do key[:num_actual_tokens]
    # and value[:num_actual_tokens] because the reshape_and_cache_flash
    # op uses the slot_mapping's shape to determine the number of
    # actual tokens.
    torch.ops._C_cache_ops.reshape_and_cache_flash(
        key,
        value,
        kv_cache[:, 0],
        kv_cache[:, 1],
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
def fast_plan_decode(
self, # decode wrapper
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment