Unverified Commit 85ac82d2 authored by Isotr0py's avatar Isotr0py Committed by GitHub
Browse files

[Kernel] Make rotary_embedding ops more flexible with input shape (#12777)

parent 1e57b1ee
...@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel( ...@@ -124,18 +124,54 @@ __global__ void batched_rotary_embedding_kernel(
void rotary_embedding( void rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size] // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox) { bool is_neox) {
int64_t num_tokens = query.numel() / query.size(-1); // num_tokens = batch_size * seq_len
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
// hidden_size = num_heads * head_size
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int seq_dim_idx = positions_ndim - 1;
int num_kv_heads = key.size(-1) / head_size; int64_t query_stride = query.stride(seq_dim_idx);
int64_t query_stride = query.stride(-2); int64_t key_stride = key.stride(seq_dim_idx);
int64_t key_stride = key.stride(-2);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
...@@ -165,19 +201,58 @@ and process in batched manner. ...@@ -165,19 +201,58 @@ and process in batched manner.
void batched_rotary_embedding( void batched_rotary_embedding(
torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens]
torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size] // [num_tokens, num_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size] // [num_tokens, num_kv_heads * head_size] or
// [batch_size, seq_len, num_heads, head_size] or
// [num_tokens, num_heads, head_size]
int64_t head_size, int64_t head_size,
torch::Tensor& cos_sin_cache, // [max_position, rot_dim] torch::Tensor& cos_sin_cache, // [max_position, rot_dim]
bool is_neox, int64_t rot_dim, bool is_neox, int64_t rot_dim,
torch::Tensor& cos_sin_cache_offsets // [num_tokens] torch::Tensor& cos_sin_cache_offsets // [num_tokens] or [batch_size]
) { ) {
// num_tokens = batch_size * seq_len
int64_t num_tokens = cos_sin_cache_offsets.size(0); int64_t num_tokens = cos_sin_cache_offsets.size(0);
int num_heads = query.size(-1) / head_size; TORCH_CHECK(
int num_kv_heads = key.size(-1) / head_size; positions.size(0) == num_tokens || positions.numel() == num_tokens,
int64_t query_stride = query.stride(-2); "positions must have the same num_tokens or batch_size as "
int64_t key_stride = key.stride(-2); "cos_sin_cache_offsets");
int positions_ndim = positions.dim();
// Make sure num_tokens dim is consistent across positions, query, and key.
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
// Make sure head_size is valid for query and key
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have concistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int seq_dim_idx = positions_ndim - 1;
int64_t query_stride = query.stride(seq_dim_idx);
int64_t key_stride = key.stride(seq_dim_idx);
dim3 grid(num_tokens); dim3 grid(num_tokens);
dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512)); dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from itertools import accumulate, product from itertools import accumulate, product
from typing import Dict, List, Optional from typing import Callable, Dict, List, Optional
import pytest import pytest
import torch import torch
...@@ -24,7 +24,21 @@ CUDA_DEVICES = [ ...@@ -24,7 +24,21 @@ CUDA_DEVICES = [
] ]
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads * head_size)
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
head_size: int) -> tuple[int, ...]:
return (batch_size, seq_len, num_heads, head_size)
TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
...@@ -36,6 +50,7 @@ CUDA_DEVICES = [ ...@@ -36,6 +50,7 @@ CUDA_DEVICES = [
@torch.inference_mode() @torch.inference_mode()
def test_rotary_embedding( def test_rotary_embedding(
is_neox_style: bool, is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
num_heads: int, num_heads: int,
...@@ -58,10 +73,8 @@ def test_rotary_embedding( ...@@ -58,10 +73,8 @@ def test_rotary_embedding(
rope = rope.to(dtype=dtype) rope = rope.to(dtype=dtype)
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
seq_len, query = torch.randn(query_shape, dtype=dtype)
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query)
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
...@@ -80,6 +93,7 @@ def test_rotary_embedding( ...@@ -80,6 +93,7 @@ def test_rotary_embedding(
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("tensor_shape_fn", TENSORS_SHAPES_FN)
@pytest.mark.parametrize("batch_size", BATCH_SIZES) @pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS) @pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("num_heads", NUM_HEADS)
...@@ -91,6 +105,7 @@ def test_rotary_embedding( ...@@ -91,6 +105,7 @@ def test_rotary_embedding(
@torch.inference_mode() @torch.inference_mode()
def test_batched_rotary_embedding( def test_batched_rotary_embedding(
is_neox_style: bool, is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
batch_size: int, batch_size: int,
seq_len: int, seq_len: int,
num_heads: int, num_heads: int,
...@@ -113,10 +128,8 @@ def test_batched_rotary_embedding( ...@@ -113,10 +128,8 @@ def test_batched_rotary_embedding(
rope = rope.to(dtype=dtype) rope = rope.to(dtype=dtype)
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size)
seq_len, query = torch.randn(query_shape, dtype=dtype)
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query)
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
......
...@@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -424,24 +424,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
def apply_pure_rope(
self,
input_positions: torch.Tensor,
q_pe: torch.Tensor,
k_pe: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
seq_len = input_positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe, k_pe = self.rotary_emb(
input_positions,
q_pe.reshape(seq_len, -1),
k_pe.reshape(seq_len, -1),
)
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
return q_pe, k_pe
def forward( def forward(
self, self,
layer: AttentionLayer, layer: AttentionLayer,
...@@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -466,14 +448,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# Restore head dim (for rotary embedding) # Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1) k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions") assert hasattr(attn_metadata, "input_positions")
rope_fn = (self.rotary_emb
if self.use_yarn_rope else self.apply_pure_rope)
if is_decode: if is_decode:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim) .view(-1, self.num_heads, self.qk_rope_head_dim)
q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe) q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe,
k_pe)
else: else:
assert is_prefill assert is_prefill
q = self.q_proj(hidden_states_or_q_c)[0]\ q = self.q_proj(hidden_states_or_q_c)[0]\
...@@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -481,7 +462,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# TODO(lucas): there must be a nicer way to write this line # TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \ q[..., self.qk_nope_head_dim:], k_pe = \
rope_fn( self.rotary_emb(
attn_metadata.input_positions, attn_metadata.input_positions,
q[..., self.qk_nope_head_dim:], k_pe) q[..., self.qk_nope_head_dim:], k_pe)
......
...@@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module): ...@@ -257,9 +257,7 @@ class DeepseekV2Attention(nn.Module):
prefix=f"{prefix}.o_proj") prefix=f"{prefix}.o_proj")
if rope_scaling: if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn' rope_scaling["rope_type"] = 'deepseek_yarn'
self.use_normal_rope = False
else:
self.use_normal_rope = True
self.rotary_emb = get_rope(qk_rope_head_dim, self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim, rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
...@@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module): ...@@ -309,17 +307,8 @@ class DeepseekV2Attention(nn.Module):
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank:] k_pe = latent_cache[:, :, self.kv_lora_rank:]
if self.use_normal_rope:
seq_len = positions.size(0)
ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
q_pe = q_pe.reshape(seq_len, -1)
k_pe = k_pe.reshape(seq_len, -1)
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
if self.use_normal_rope:
q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
q[..., self.qk_nope_head_dim:] = q_pe q[..., self.qk_nope_head_dim:] = q_pe
k = torch.empty_like(q) k = torch.empty_like(q)
k[..., :self.qk_nope_head_dim] = k_nope k[..., :self.qk_nope_head_dim] = k_nope
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment