Unverified commit cdfd6871, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Misaligned params in TreeAttentionImpl (#22226)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 4b3e4474
......@@ -4,7 +4,7 @@
import ast
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Optional
import torch
......@@ -313,15 +313,11 @@ class TreeAttentionImpl(AttentionImpl):
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
) -> None:
if blocksparse_params is not None:
raise ValueError(
"TreeAttention does not support block-sparse attention.")
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment