Commit 4fd5389b authored by zhuwenwen's avatar zhuwenwen
Browse files

update qwen2&2.5-vl prefill interface and fix pixtral run error

parent 504a12b8
......@@ -1045,7 +1045,7 @@ class PixtralHFAttention(nn.Module):
q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2)
out = out.view(batch, patches, self.n_heads * self.head_dim)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
attn_output, _ = self.o_proj(out)
return attn_output, None
......
......@@ -73,6 +73,7 @@ import os
import re
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -311,9 +312,11 @@ class Qwen2_5_VisionAttention(nn.Module):
use_flash_attn=use_flash_attn)
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
if not current_platform.is_rocm():
from vllm_flash_attn.flash_attn_interface import (
flash_attn_varlen_func)
else:
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
......@@ -1192,4 +1195,4 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model="language_model",
connector="visual.merger.",
tower_model="visual.",
)
)
\ No newline at end of file
......@@ -81,6 +81,7 @@ import os
import re
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -330,9 +331,11 @@ class Qwen2VisionAttention(nn.Module):
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
if not current_platform.is_rocm():
from vllm_flash_attn.flash_attn_interface import (
flash_attn_varlen_func)
else:
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
......@@ -1469,4 +1472,4 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model="language_model",
connector="visual.merger.",
tower_model="visual.",
)
)
\ No newline at end of file
......@@ -83,11 +83,11 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_rocm():
device_available = current_platform.has_device_capability(80)
if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available():
if is_flash_attn_2_available() or current_platform.is_rocm():
selected_backend = _Backend.FLASH_ATTN
else:
logger.warning_once(
......@@ -146,4 +146,4 @@ def resolve_visual_encoder_outputs(
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
return torch.cat(hs_pool, dim=-1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment