Commit 1441a435 authored by maxiao1

fix figure test bug.

parent 20eac49d
@@ -29,6 +29,9 @@ _is_hip = is_hip()
 if _is_cuda:
     from sgl_kernel.flash_attn import flash_attn_varlen_func
+
+if _is_hip:
+    from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func
 if _is_npu:
     import torch_npu
@@ -299,8 +302,6 @@ class VisionFlash3Attention(nn.Module):
     ):
         # if not _is_cuda:
         #     raise Exception("VisionFlash3Attention is only available for cuda")
-        if _is_hip:
-            from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func
         super().__init__()

     def forward(
...
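
For context, the change replaces a function-local import in `VisionFlash3Attention.__init__` with a module-level guarded import. In Python, an import inside `__init__` binds `flash_attn_varlen_func` only in that method's local scope, so any other function in the module (such as `forward`) that references the name would raise `NameError` on HIP at call time. Below is a minimal, self-contained sketch of the pattern; the `is_cuda()`/`is_hip()` stubs and the `NotImplementedError` fallback are illustrative placeholders, not sglang's actual platform-detection code:

```python
# Sketch of the platform-guarded import this commit introduces.
# is_cuda()/is_hip() are stubbed so the sketch runs anywhere; in sglang
# they are real detection utilities and the imports pull in real kernels.

def is_cuda() -> bool:
    return False  # stub: True on a CUDA build

def is_hip() -> bool:
    return False  # stub: True on a ROCm/HIP build

_is_cuda = is_cuda()
_is_hip = is_hip()

if _is_cuda:
    from sgl_kernel.flash_attn import flash_attn_varlen_func
elif _is_hip:
    # The fix: import at module scope, not inside __init__, so the name
    # is visible to every function in the module (including forward()).
    from sglang.srt.layers.attention.flashattention_interface import (
        flash_attn_varlen_func,
    )
else:
    def flash_attn_varlen_func(*args, **kwargs):
        raise NotImplementedError("no flash-attention backend available")

class VisionFlash3Attention:
    def __init__(self):
        # Before the fix, the HIP import lived here. A function-local
        # import binds the name only inside __init__, so forward() would
        # raise NameError on HIP when called.
        super().__init__()

    def forward(self, *args, **kwargs):
        # Resolves against the module-level name bound above.
        return flash_attn_varlen_func(*args, **kwargs)
```

Keeping the guarded import at module scope also means a missing backend fails fast at import time rather than midway through a forward pass.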