import torch
from typing import Optional, Tuple

from . import op


def rms_rotary_embedding_fuse(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor],
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
    weight_q: torch.Tensor,
    weight_k: torch.Tensor,
    residual_q: Optional[torch.Tensor],
    residual_k: Optional[torch.Tensor],
    epsilon: float = 1e-5,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Invoke the backend ``op.rms_rotary_embedding_fuse`` and return ``query``/``key``.

    Every argument is forwarded positionally, in declaration order, to the
    backend op. The op's own return value is discarded; the ``query`` and
    ``key`` objects supplied by the caller are handed back unchanged at the
    Python level.

    NOTE(review): because the op's result is ignored, it presumably writes
    its output into ``query`` / ``key`` (and possibly the residual tensors)
    in place -- confirm against the op implementation.

    Args:
        positions: Forwarded as-is; presumably token position indices used
            to index ``cos_sin_cache`` -- TODO confirm.
        query: Forwarded as-is and returned as the first tuple element.
        key: Forwarded as-is (may be ``None``) and returned as the second
            tuple element.
        head_size: Forwarded as-is; presumably the per-head dimension.
        cos_sin_cache: Forwarded as-is; presumably precomputed rotary
            cos/sin values -- TODO confirm layout.
        is_neox: Forwarded as-is; presumably selects the NeoX-style rotary
            layout -- TODO confirm.
        weight_q: Forwarded as-is; presumably the RMS-norm weight for
            ``query``.
        weight_k: Forwarded as-is; presumably the RMS-norm weight for
            ``key``.
        residual_q: Forwarded as-is; may be ``None``.
        residual_k: Forwarded as-is; may be ``None``.
        epsilon: Forwarded as-is; defaults to ``1e-5``.

    Returns:
        The tuple ``(query, key)`` -- the same objects that were passed in.
    """
    # Positional order must mirror the backend op's signature exactly.
    op.rms_rotary_embedding_fuse(
        positions, query, key,
        head_size, cos_sin_cache, is_neox,
        weight_q, weight_k,
        residual_q, residual_k,
        epsilon,
    )
    return (query, key)