Unverified commit 9a31a817, authored by Woosuk Kwon and committed by GitHub
Browse files

[Bugfix] Fix FP8 KV cache support (#4869)

parent 2060e936
...@@ -200,15 +200,15 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -200,15 +200,15 @@ class FlashAttentionImpl(AttentionImpl):
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None, num_kv_heads: int,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]],
sliding_window: Optional[int] = None, sliding_window: Optional[int],
kv_cache_dtype: str = "auto", kv_cache_dtype: str,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
......
...@@ -164,15 +164,15 @@ class FlashInferImpl(AttentionImpl): ...@@ -164,15 +164,15 @@ class FlashInferImpl(AttentionImpl):
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None, num_kv_heads: int,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]],
sliding_window: Optional[int] = None, sliding_window: Optional[int],
kv_cache_dtype: str = "auto", kv_cache_dtype: str,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
......
...@@ -197,15 +197,15 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -197,15 +197,15 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None, num_kv_heads: int,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]],
sliding_window: Optional[int] = None, sliding_window: Optional[int],
kv_cache_dtype: str = "auto", kv_cache_dtype: str,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
......
...@@ -96,15 +96,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -96,15 +96,15 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None, num_kv_heads: int,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]],
sliding_window: Optional[int] = None, sliding_window: Optional[int],
kv_cache_dtype: str = "auto", kv_cache_dtype: str,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
......
...@@ -208,15 +208,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): ...@@ -208,15 +208,15 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
num_heads: int, num_heads: int,
head_size: int, head_size: int,
scale: float, scale: float,
num_kv_heads: Optional[int] = None, num_kv_heads: int,
alibi_slopes: Optional[List[float]] = None, alibi_slopes: Optional[List[float]],
sliding_window: Optional[int] = None, sliding_window: Optional[int],
kv_cache_dtype: str = "auto", kv_cache_dtype: str,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads self.num_kv_heads = num_kv_heads
if alibi_slopes is not None: if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes self.alibi_slopes = alibi_slopes
......
...@@ -48,7 +48,7 @@ class Attention(nn.Module): ...@@ -48,7 +48,7 @@ class Attention(nn.Module):
block_size) block_size)
impl_cls = attn_backend.get_impl_cls() impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window) alibi_slopes, sliding_window, kv_cache_dtype)
def forward( def forward(
self, self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment