Commit 4fd5389b authored by zhuwenwen
Browse files

Update Qwen2-VL and Qwen2.5-VL prefill interface and fix Pixtral run error

parent 504a12b8
...@@ -1045,7 +1045,7 @@ class PixtralHFAttention(nn.Module): ...@@ -1045,7 +1045,7 @@ class PixtralHFAttention(nn.Module):
q, k, v, attn_mask=attention_mask) q, k, v, attn_mask=attention_mask)
out = out.transpose(1, 2) out = out.transpose(1, 2)
out = out.view(batch, patches, self.n_heads * self.head_dim) out = out.reshape(batch, patches, self.n_heads * self.head_dim)
attn_output, _ = self.o_proj(out) attn_output, _ = self.o_proj(out)
return attn_output, None return attn_output, None
......
...@@ -73,6 +73,7 @@ import os ...@@ -73,6 +73,7 @@ import os
import re import re
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -311,9 +312,11 @@ class Qwen2_5_VisionAttention(nn.Module): ...@@ -311,9 +312,11 @@ class Qwen2_5_VisionAttention(nn.Module):
use_flash_attn=use_flash_attn) use_flash_attn=use_flash_attn)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( if not current_platform.is_rocm():
# flash_attn_varlen_func) from vllm_flash_attn.flash_attn_interface import (
from flash_attn import flash_attn_varlen_func flash_attn_varlen_func)
else:
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
...@@ -1192,4 +1195,4 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1192,4 +1195,4 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model="language_model", language_model="language_model",
connector="visual.merger.", connector="visual.merger.",
tower_model="visual.", tower_model="visual.",
) )
\ No newline at end of file
...@@ -81,6 +81,7 @@ import os ...@@ -81,6 +81,7 @@ import os
import re import re
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -330,9 +331,11 @@ class Qwen2VisionAttention(nn.Module): ...@@ -330,9 +331,11 @@ class Qwen2VisionAttention(nn.Module):
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.attn_backend == _Backend.FLASH_ATTN: if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import ( if not current_platform.is_rocm():
# flash_attn_varlen_func) from vllm_flash_attn.flash_attn_interface import (
from flash_attn import flash_attn_varlen_func flash_attn_varlen_func)
else:
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
...@@ -1469,4 +1472,4 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, ...@@ -1469,4 +1472,4 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal,
language_model="language_model", language_model="language_model",
connector="visual.merger.", connector="visual.merger.",
tower_model="visual.", tower_model="visual.",
) )
\ No newline at end of file
...@@ -83,11 +83,11 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend: ...@@ -83,11 +83,11 @@ def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
if backend_by_env_var is not None: if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var) selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None: if selected_backend is None:
if current_platform.is_cuda(): if current_platform.is_cuda() or current_platform.is_rocm():
device_available = current_platform.has_device_capability(80) device_available = current_platform.has_device_capability(80)
if device_available and support_fa: if device_available and support_fa:
from transformers.utils import is_flash_attn_2_available from transformers.utils import is_flash_attn_2_available
if is_flash_attn_2_available(): if is_flash_attn_2_available() or current_platform.is_rocm():
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
else: else:
logger.warning_once( logger.warning_once(
...@@ -146,4 +146,4 @@ def resolve_visual_encoder_outputs( ...@@ -146,4 +146,4 @@ def resolve_visual_encoder_outputs(
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
if post_layer_norm is not None and uses_last_layer: if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs) hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1) return torch.cat(hs_pool, dim=-1)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment