qwen2vl_mrope.py 1.01 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from liger_kernel.ops import LigerQwen2VLMRopeFunction


def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """
    Rotate query and key states with Multimodal Rotary Positional Embedding (M-RoPE).

    This is a thin functional entry point: every argument is forwarded, unchanged and
    in order, to ``LigerQwen2VLMRopeFunction.apply``.

    Args:
        q (torch.Tensor): Query states, shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key states, shape (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine table, shape (3, bsz, seq_len, head_dim).
        sin (torch.Tensor): Sine table, shape (3, bsz, seq_len, head_dim).
        mrope_section (List[int]): Multimodal rope section for the channel dimension of
            temporal, height and width in the rope calculation.
        unsqueeze_dim (int, optional): Dimension to unsqueeze. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors with M-RoPE applied.
    """
    # Positional order must match what LigerQwen2VLMRopeFunction.apply expects.
    rope_args = (q, k, cos, sin, mrope_section, unsqueeze_dim)
    return LigerQwen2VLMRopeFunction.apply(*rope_args)