Unverified Commit cdfd6871 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Misaligned params in TreeAttentionImpl (#22226)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 4b3e4474
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import ast import ast
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional from typing import TYPE_CHECKING, Optional
import torch import torch
...@@ -313,15 +313,11 @@ class TreeAttentionImpl(AttentionImpl): ...@@ -313,15 +313,11 @@ class TreeAttentionImpl(AttentionImpl):
alibi_slopes: Optional[list[float]], alibi_slopes: Optional[list[float]],
sliding_window: Optional[int], sliding_window: Optional[int],
kv_cache_dtype: str, kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None, kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False, use_irope: bool = False,
) -> None: ) -> None:
if blocksparse_params is not None:
raise ValueError(
"TreeAttention does not support block-sparse attention.")
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.scale = float(scale) self.scale = float(scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment