Unverified Commit dd96465f authored by Chendi.Xue's avatar Chendi.Xue Committed by GitHub
Browse files

[BugFix][QWEN-VL]fix wrong apply_rotary_emb_torch selection introduced by #24642 (#26123)


Signed-off-by: default avatarChendi Xue <Chendi.Xue@intel.com>
Signed-off-by: default avatarChendi.Xue <chendi.xue@intel.com>
Co-authored-by: default avatarRoger Wang <hey@rogerw.io>
parent 4f8f47e8
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import math import math
from functools import cache from functools import cache
from importlib.util import find_spec from importlib.util import find_spec
from typing import Callable from typing import Callable, Optional
import torch import torch
...@@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor, ...@@ -72,7 +72,9 @@ def apply_rotary_emb_dispatch(x: torch.Tensor, cos: torch.Tensor,
@cache @cache
def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]: def dispatch_rotary_emb_function(
default: Optional[Callable[..., torch.Tensor]] = None
) -> Callable[..., torch.Tensor]:
if current_platform.is_cuda(): if current_platform.is_cuda():
return apply_rotary_emb return apply_rotary_emb
...@@ -85,6 +87,9 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]: ...@@ -85,6 +87,9 @@ def dispatch_rotary_emb_function() -> Callable[..., torch.Tensor]:
"flash_attn is not installed. Falling back to PyTorch " "flash_attn is not installed. Falling back to PyTorch "
"implementation for rotary embeddings.") "implementation for rotary embeddings.")
if default is not None:
return default
else:
return apply_rotary_emb_torch return apply_rotary_emb_torch
......
...@@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor, ...@@ -276,7 +276,8 @@ def apply_rotary_emb_torch(x: torch.Tensor,
def apply_rotary_pos_emb_vision(t: torch.Tensor, def apply_rotary_pos_emb_vision(t: torch.Tensor,
freqs: torch.Tensor) -> torch.Tensor: freqs: torch.Tensor) -> torch.Tensor:
rotary_emb_function = dispatch_rotary_emb_function() rotary_emb_function = dispatch_rotary_emb_function(
default=apply_rotary_emb_torch)
t_ = t.float() t_ = t.float()
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment