Unverified commit 9dbe1fe9, authored by Isotr0py, committed by GitHub
Browse files

[Bugfix] Fix missing scale passing for encoder Triton Attention implementation (#32149)


Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent a5f89ae2
...@@ -4,10 +4,7 @@ ...@@ -4,10 +4,7 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.backends.registry import AttentionBackendEnum
def parse_args(): def parse_args():
...@@ -23,11 +20,6 @@ def parse_args(): ...@@ -23,11 +20,6 @@ def parse_args():
def main(args: Namespace): def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts. # Sample prompts.
prompts = [ prompts = [
"Hello, my name is", "Hello, my name is",
......
...@@ -4,10 +4,7 @@ ...@@ -4,10 +4,7 @@
from argparse import Namespace from argparse import Namespace
from vllm import LLM, EngineArgs from vllm import LLM, EngineArgs
from vllm.config import AttentionConfig
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.attention.backends.registry import AttentionBackendEnum
def parse_args(): def parse_args():
...@@ -23,11 +20,6 @@ def parse_args(): ...@@ -23,11 +20,6 @@ def parse_args():
def main(args: Namespace): def main(args: Namespace):
if current_platform.is_rocm():
args.attention_config = AttentionConfig(
backend=AttentionBackendEnum.FLEX_ATTENTION
)
# Sample prompts. # Sample prompts.
text_1 = "What is the capital of France?" text_1 = "What is the capital of France?"
texts_2 = [ texts_2 = [
......
...@@ -573,6 +573,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -573,6 +573,7 @@ class TritonAttentionImpl(AttentionImpl):
b_seq_len=seq_lens, b_seq_len=seq_lens,
max_input_len=max_query_len, max_input_len=max_query_len,
is_causal=False, # Encoder attention is bidirectional is_causal=False, # Encoder attention is bidirectional
softmax_scale=self.scale,
sliding_window_q=self.sliding_window[0], sliding_window_q=self.sliding_window[0],
sliding_window_k=self.sliding_window[1], sliding_window_k=self.sliding_window[1],
) )
......
...@@ -211,16 +211,17 @@ def get_block_size(dtype: torch.dtype) -> int: ...@@ -211,16 +211,17 @@ def get_block_size(dtype: torch.dtype) -> int:
def context_attention_fwd( def context_attention_fwd(
q, q: torch.Tensor,
k, k: torch.Tensor,
v, v: torch.Tensor,
o, o: torch.Tensor,
b_start_loc, b_start_loc: torch.Tensor,
b_seq_len, b_seq_len: torch.Tensor,
max_input_len, max_input_len: int,
is_causal=True, is_causal: bool = True,
sliding_window_q=None, softmax_scale: float | None = None,
sliding_window_k=None, sliding_window_q: int | None = None,
sliding_window_k: int | None = None,
): ):
""" """
q, k, v: [b * s, head, head_dim] q, k, v: [b * s, head, head_dim]
...@@ -232,7 +233,7 @@ def context_attention_fwd( ...@@ -232,7 +233,7 @@ def context_attention_fwd(
Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] Lq, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]
sm_scale = 1.0 / (Lq**0.5) sm_scale = 1.0 / (Lq**0.5) if softmax_scale is None else softmax_scale
batch, head = b_seq_len.shape[0], q.shape[1] batch, head = b_seq_len.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k.shape[1] kv_group_num = q.shape[1] // k.shape[1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment