Unverified Commit ece5a8b0 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

Make the _apply_rotary_emb compatible with dynamo (#17435)

parent 54072f31
...@@ -32,6 +32,9 @@ from transformers import PretrainedConfig ...@@ -32,6 +32,9 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
if current_platform.is_cuda_alike():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2] x1 = x[..., :x.shape[-1] // 2]
...@@ -78,7 +81,6 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ...@@ -78,7 +81,6 @@ def _apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
positional embeddings. positional embeddings.
""" """
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
return apply_rotary_emb(x.unsqueeze(0), cos, sin, return apply_rotary_emb(x.unsqueeze(0), cos, sin,
not is_neox_style).squeeze(0) not is_neox_style).squeeze(0)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment