Unverified commit 74743104, authored by cong-or, committed by GitHub
Browse files

feat(attention): extract KV-cache update from FlexAttention backend (#36263)


Signed-off-by: cong-or <conchubhar.gannon@gmail.com>
parent d62856b9
......@@ -82,6 +82,8 @@ class FlexAttentionBackend(AttentionBackend):
]
# KV-cache dtype names listed as supported by this backend.
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "bfloat16"]
# NOTE(review): False presumably signals that forward() does not write the
# KV cache itself and that do_kv_cache_update() must be invoked separately
# -- confirm against the code that reads this flag.
forward_includes_kv_cache_update: bool = False
@staticmethod
def get_name() -> str:
return "FLEX_ATTENTION"
......@@ -827,6 +829,29 @@ class FlexAttentionImpl(AttentionImpl):
assert tensor.ndim == 3
return tensor[None, :, :, :]
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Pass ``key`` and ``value`` to the cache-write op for this layer.

    Returns immediately, leaving ``kv_cache`` untouched, when
    ``self.attn_type`` is ``AttentionType.ENCODER_ONLY``. Otherwise
    ``kv_cache`` is split along dim 0 into a key cache and a value cache,
    and both are handed to the ``reshape_and_cache_flash`` custom op.

    Args:
        layer: Module supplying the ``_k_scale`` and ``_v_scale`` attributes
            forwarded to the op.
        key: Key tensor forwarded to the op.
        value: Value tensor forwarded to the op.
        kv_cache: Tensor whose dim 0 unbinds into exactly two tensors, the
            key cache followed by the value cache.
        slot_mapping: Forwarded unchanged to the op.
            NOTE(review): presumably the destination cache slot for each
            token -- confirm against the op's definition.
    """
    # Encoder-only attention skips the cache write entirely.
    if self.attn_type == AttentionType.ENCODER_ONLY:
        return

    k_cache, v_cache = kv_cache.unbind(0)
    write_to_cache = torch.ops._C_cache_ops.reshape_and_cache_flash
    write_to_cache(
        key,
        value,
        k_cache,
        v_cache,
        slot_mapping,
        self.kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
def forward(
self,
layer: torch.nn.Module,
......@@ -908,17 +933,6 @@ class FlexAttentionImpl(AttentionImpl):
assert self.attn_type == AttentionType.DECODER
key_cache, value_cache = kv_cache.unbind(0)
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
# View out the block_size dim
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment