Unverified commit 2410132b, authored by TJian and committed by GitHub
Browse files

[ROCm] [Bugfix] Fix torch sdpa hallucination (#30789)


Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
parent 0a1ab1e5
...@@ -16,6 +16,7 @@ import einops ...@@ -16,6 +16,7 @@ import einops
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -89,6 +90,13 @@ def torch_sdpa_wrapper( ...@@ -89,6 +90,13 @@ def torch_sdpa_wrapper(
v: torch.Tensor, v: torch.Tensor,
cu_seqlens: torch.Tensor, cu_seqlens: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend
if current_platform.is_rocm():
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
outputs = [] outputs = []
lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() lens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment