Commit 9b35bbfa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev' into v0.9.2-step3v

parents 7f54652e 4ba4b755
...@@ -39,10 +39,11 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -39,10 +39,11 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
import vllm.envs as envs import vllm.envs as envs
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
from vllm.utils import SUPPORT_TC
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm(): if current_platform.is_rocm() and SUPPORT_TC:
from flash_attn.layers.rotary import apply_rotary_emb from flash_attn.layers.rotary import apply_rotary_emb
...@@ -91,8 +92,11 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ...@@ -91,8 +92,11 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
positional embeddings. positional embeddings.
""" """
if current_platform.is_cuda(): if current_platform.is_cuda():
return apply_rotary_emb(x.unsqueeze(0), cos, sin, if SUPPORT_TC:
not is_neox_style).squeeze(0) return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0)
else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
else: else:
return _apply_rotary_emb_torch(x, cos, sin, is_neox_style) return _apply_rotary_emb_torch(x, cos, sin, is_neox_style)
......
...@@ -55,6 +55,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper, ...@@ -55,6 +55,7 @@ from .utils import (AutoWeightsLoader, WeightsMapper,
maybe_prefix, merge_multimodal_embeddings) maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vit_attn_backend from .vision import get_vit_attn_backend
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import SUPPORT_TC
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -331,10 +332,11 @@ def apply_rotary_pos_emb_flashatt( ...@@ -331,10 +332,11 @@ def apply_rotary_pos_emb_flashatt(
cos = cos.chunk(2, dim=-1)[0].contiguous() cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous() sin = sin.chunk(2, dim=-1)[0].contiguous()
if not current_platform.is_rocm(): if SUPPORT_TC:
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb if not current_platform.is_rocm():
else: from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
from flash_attn.layers.rotary import apply_rotary_emb else:
from flash_attn.layers.rotary import apply_rotary_emb
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q) q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k) k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
......
...@@ -85,6 +85,7 @@ import re ...@@ -85,6 +85,7 @@ import re
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import SUPPORT_TC
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -246,7 +247,7 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, ...@@ -246,7 +247,7 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor,
apply_rotary_emb = apply_rotary_emb_torch apply_rotary_emb = apply_rotary_emb_torch
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
if current_platform.is_rocm(): if current_platform.is_rocm() and SUPPORT_TC:
from flash_attn.layers.rotary import apply_rotary_emb from flash_attn.layers.rotary import apply_rotary_emb
output = apply_rotary_emb(t_, cos, sin).type_as(t) output = apply_rotary_emb(t_, cos, sin).type_as(t)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment