Commit 4fd5389b authored by zhuwenwen's avatar zhuwenwen
Browse files

update qwen2&2.5-vl prefill interface and fix pixtral run error

parent 504a12b8
...@@ -1045,7 +1045,7 @@ class PixtralHFAttention(nn.Module): ...@@ -1045,7 +1045,7 @@ class PixtralHFAttention(nn.Module):
q, k, v, attn_mask=attention_mask) q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2) out = out.transpose(1, 2)
out = out.view(batch, patches, self.n_heads * self.head_dim) out = out.reshape(batch, patches, self.n_heads * self.head_dim)
attn_output, _ = self.o_proj(out) attn_output, _ = self.o_proj(out)
return attn_output, None return attn_output, None
......
...@@ -73,6 +73,7 @@ import os ...@@ -73,6 +73,7 @@ import os
import re import re
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -311,8 +312,10 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -311,8 +312,10 @@ class Qwen2_5_VisionAttention(nn.Module):
use_flash_attn=use_flash_attn) use_flash_attn=use_flash_attn)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( if not current_platform.is_rocm():
# flash_attn_varlen_func) from vllm_flash_attn.flash_attn_interface import (
flash_attn_varlen_func)
else:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
......
...@@ -81,6 +81,7 @@ import os ...@@ -81,6 +81,7 @@ import os
import re import re
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -330,8 +331,10 @@ class Qwen2VisionAttention(nn.Module): ...@@ -330,8 +331,10 @@ class Qwen2VisionAttention(nn.Module):
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( if not current_platform.is_rocm():
# flash_attn_varlen_func) from vllm_flash_attn.flash_attn_interface import (
flash_attn_varlen_func)
else:
from flash_attn import flash_attn_varlen_func from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
......
...@@ -83,11 +83,11 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: ...@@ -83,11 +83,11 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if backend_by_env_var is not None: if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var) selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None: if selected_backend is None:
if current_platform.is_cuda(): if current_platform.is_cuda() or current_platform.is_rocm():
device_available = current_platform.has_device_capability(80) device_available = current_platform.has_device_capability(80)
if device_available and support_fa: if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available(): if is_flash_attn_2_available() or current_platform.is_rocm():
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
else: else:
logger.warning_once( logger.warning_once(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment