import torch
import triton
import triton.language as tl


@triton.jit
def _triton_qwen2vl_mrope(
    q_ptr,
    k_ptr,
    cos,
    sin,
    sl,
    bs: tl.constexpr,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    # Apply the multimodal rotary embedding in place to the q and k rows of one
    # token. One program instance handles one token (one [n_heads, hd] slab of
    # q and of k).
    #
    # q_ptr / k_ptr : base pointers of q / k, addressed as rows of n_qh * hd and
    #                 n_kh * hd elements per token.
    # cos / sin     : base pointers of three consecutive sections (t, h, w),
    #                 each addressed as bs * sl rows of hd elements.
    # sl, bs        : sequence length and batch size (only used to compute the
    #                 stride between the cos/sin sections).
    # pad_*         : power-of-two padded extents used for tl.arange.
    # mrope_section_t / mrope_section_h : number of leading rotary indices taken
    #                 from the t section and, after those, from the h section;
    #                 the remaining indices up to hd // 2 come from the w section.
    # BLOCK_SIZE    : accepted for interface compatibility; not read in the body.
    # BACKWARD_PASS : when True, apply the transposed rotation (sign of sin
    #                 flipped) so the same kernel computes input gradients.

    # Widen the program id to 64 bits before it is multiplied into element
    # offsets: pid * n_qh * hd can exceed the int32 range for large
    # batch * seq_len, which would wrap and address the wrong memory.
    pid = tl.program_id(0).to(tl.int64)

    # locate start address
    q_ptr = q_ptr + pid * (n_qh * hd)
    k_ptr = k_ptr + pid * (n_kh * hd)

    # ####################################################################
    # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position
    # m of this program instance
    # ####################################################################

    # 1. program instances are laid out in a 1D vector of size bsz * seq_len, in
    #    the same order as the rows of q / k and of each cos / sin section, so
    #    pid directly selects this token's row inside a section. Consecutive
    #    sections (t, h, w) are bs * sl * hd elements apart.
    # 2. We only need the left half of cos and sin matrix because the right half is just
    #    a clone of the left half.
    t_end = mrope_section_t
    h_end = t_end + mrope_section_h

    t_cos = cos + pid * hd
    h_cos = t_cos + bs * sl * hd
    w_cos = h_cos + bs * sl * hd
    t_sin = sin + pid * hd
    h_sin = t_sin + bs * sl * hd
    w_sin = h_sin + bs * sl * hd

    cos_offsets = tl.arange(0, pad_hd // 2)
    # The three masks are disjoint and together cover [0, hd // 2), so summing
    # the masked loads (masked-out lanes read 0) stitches the sections together.
    t_mask = cos_offsets < t_end
    h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
    w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2)
    t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
    h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
    w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0)
    t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0)
    h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0)
    w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0)
    cos_row = t_cos_row + h_cos_row + w_cos_row
    sin_row = t_sin_row + h_sin_row + w_sin_row

    # ####################################################################
    # Load the left and right half of q and k for the current
    # program instance (i.e. for the current token) separately
    # ####################################################################
    # left half of the head
    first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)
    # Tiles are converted to the dtype of the sin/cos rows before the rotation.
    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)
    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)

    # right half of the head
    second_half_q_offsets = first_half_q_offsets + (hd // 2)
    second_half_k_offsets = first_half_k_offsets + (hd // 2)
    second_q_mask = first_q_mask
    second_k_mask = first_k_mask
    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)
    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)

    # All four tiles are loaded before any store, so writing the results back
    # to the same addresses below is safe.
    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
    else:
        # with some math, we can get:
        # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin]
        new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row
        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
        new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row
        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)

        new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row
        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
        new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row
        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)


def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    """Launch the M-RoPE kernel in forward mode.

    Args:
        q: tensor whose dims 1 and 2 are swapped on entry; after the swap it is
            unpacked as (batch, seq_len, n_q_head, head_dim).
        k: same layout as ``q`` with ``n_kv_head`` heads.
        cos, sin: rotary tables; the kernel addresses them as three consecutive
            sections of ``batch * seq_len`` rows with ``head_dim`` elements.
        mrope_section: indexable; elements 0 and 1 are passed to the kernel as
            the sizes of the first two sections.

    Returns:
        ``(q, k, cos, sin)`` where q and k are transposed back to the input
        layout and cos and sin are the contiguous tensors given to the kernel.

    Note:
        The kernel writes its result into the buffers it is given. Since
        ``.contiguous()`` returns the same tensor when no copy is needed, the
        caller's q / k storage is overwritten in that case.
    """
    # transpose it back to the physical shape because Triton looks at the physical storage
    # note: q and k are non-contiguous before the transformation and will become contiguous after transpose
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]
    # tl.arange needs power-of-two extents, so pad each tiled dimension.
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    # one program instance per token
    n_row = batch_size * seq_len

    # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous
    q = q.contiguous()
    k = k.contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    _triton_qwen2vl_mrope[(n_row,)](
        q,
        k,
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=False,
    )
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    """Launch the M-RoPE kernel in backward mode on the incoming gradients.

    Args:
        dq, dk: gradients with the same layout as the forward q / k inputs.
        cos, sin: rotary tables, passed to the kernel as given (no
            ``.contiguous()`` call here); the autograd wrapper in this file
            supplies the tensors returned by ``qwen2vl_mrope_forward``.
        mrope_section: indexable; elements 0 and 1 are passed to the kernel.

    Returns:
        ``(dq, dk)`` transposed back to the input layout.
    """
    dq = dq.transpose(1, 2)
    dk = dk.transpose(1, 2)

    batch_size, seq_len, n_q_head, head_dim = dq.shape
    n_kv_head = dk.shape[2]
    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)

    n_row = batch_size * seq_len

    # ensure dq and dk are contiguous
    dq = dq.contiguous()
    dk = dk.contiguous()

    # backward is similar to forward except swapping few ops
    _triton_qwen2vl_mrope[(n_row,)](
        dq,
        dk,
        cos,
        sin,
        seq_len,
        batch_size,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=BLOCK_SIZE,
        BACKWARD_PASS=True,
    )
    return dq.transpose(1, 2), dk.transpose(1, 2)


class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        ``unsqueeze_dim`` is accepted but not used by this implementation.
        Returns the rotated ``(q, k)``.
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        # Save the contiguous cos / sin returned by the forward launcher so the
        # backward kernel can index them directly.
        ctx.save_for_backward(cos, sin)
        ctx.mrope_section = mrope_section
        return q, k

    @staticmethod
    def backward(ctx, dq, dk):
        """
        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, bsz, seq_len, head_dim)
        sin size: (3, bsz, seq_len, head_dim)

        Returns one gradient slot per forward input: gradients for q and k,
        and None for cos, sin, mrope_section and unsqueeze_dim.
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None