Unverified Commit ba605f18 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[PyTorch][Common] Refactor RoPE (#1626)



* refactor to add cp support for sbhd/bshd
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* support interleaved
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* format
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* add interleaved to RotaryPositionEmbedding in test
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* update
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* merge sbhd/bshd and thd functions
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
parent ff884e20
...@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import ( ...@@ -11,52 +11,6 @@ from transformer_engine.pytorch.dot_product_attention.rope import (
) )
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
def apply_rotary_pos_emb_thd(
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[
apply_rotary_pos_emb(
x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
# Gradient is a broadcasted scalar # Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2 return output.sum() * 2
...@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: ...@@ -76,6 +30,8 @@ def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
@pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)]) @pytest.mark.parametrize("transpose", [None, (0, 1), (2, 3)])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) @pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope( def test_fused_rope(
dtype: torch.dtype, dtype: torch.dtype,
seq_length: int, seq_length: int,
...@@ -85,6 +41,8 @@ def test_fused_rope( ...@@ -85,6 +41,8 @@ def test_fused_rope(
transpose: Union[Tuple, None], transpose: Union[Tuple, None],
tensor_format: str, tensor_format: str,
loss_func: Callable, loss_func: Callable,
cp_size: int,
interleaved: bool,
) -> None: ) -> None:
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
...@@ -99,14 +57,22 @@ def test_fused_rope( ...@@ -99,14 +57,22 @@ def test_fused_rope(
t = t.transpose(*transpose).contiguous().transpose(*transpose) t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(seq_length) emb = rotary_pos_emb(seq_length * cp_size)
assert emb.is_contiguous()
for cp_rank in range(cp_size):
# unfused # unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32 # The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison # for more accurate comparison
output_unfused = apply_rotary_pos_emb( output_unfused = apply_rotary_pos_emb(
t.float(), emb, tensor_format=tensor_format, fused=False t.float(),
emb,
tensor_format=tensor_format,
interleaved=interleaved,
fused=False,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward() loss_unfused.backward()
...@@ -118,7 +84,10 @@ def test_fused_rope( ...@@ -118,7 +84,10 @@ def test_fused_rope(
t, t,
emb, emb,
tensor_format=tensor_format, tensor_format=tensor_format,
interleaved=interleaved,
fused=True, fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
) )
loss_fused = loss_func(output_fused) loss_fused = loss_func(output_fused)
loss_fused.backward() loss_fused.backward()
...@@ -135,7 +104,8 @@ def test_fused_rope( ...@@ -135,7 +104,8 @@ def test_fused_rope(
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) @pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)]) @pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) @pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3]) @pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_rope_thd( def test_fused_rope_thd(
dtype: torch.dtype, dtype: torch.dtype,
hidden_size: int, hidden_size: int,
...@@ -143,6 +113,7 @@ def test_fused_rope_thd( ...@@ -143,6 +113,7 @@ def test_fused_rope_thd(
transpose: Union[Tuple, None], transpose: Union[Tuple, None],
loss_func: Callable, loss_func: Callable,
cp_size: int, cp_size: int,
interleaved: bool,
) -> None: ) -> None:
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
...@@ -170,15 +141,23 @@ def test_fused_rope_thd( ...@@ -170,15 +141,23 @@ def test_fused_rope_thd(
t = t.transpose(*transpose).contiguous().transpose(*transpose) t = t.transpose(*transpose).contiguous().transpose(*transpose)
t.requires_grad = True t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent) rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(cu_seqlens_padded[-1]) emb = rotary_pos_emb(cu_seqlens_padded[-1])
assert emb.is_contiguous()
for cp_rank in range(cp_size): for cp_rank in range(cp_size):
# unfused # unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32 # The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison # for more accurate comparison
output_unfused = apply_rotary_pos_emb_thd( output_unfused = apply_rotary_pos_emb(
t.float(), cu_seqlens_padded, emb, cp_size, cp_rank t.float(),
emb,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens_padded,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward() loss_unfused.backward()
...@@ -189,6 +168,7 @@ def test_fused_rope_thd( ...@@ -189,6 +168,7 @@ def test_fused_rope_thd(
output_fused = apply_rotary_pos_emb( output_fused = apply_rotary_pos_emb(
t, t,
emb, emb,
interleaved=interleaved,
fused=True, fused=True,
tensor_format="thd", tensor_format="thd",
cu_seqlens=cu_seqlens_padded, cu_seqlens=cu_seqlens_padded,
......
...@@ -16,10 +16,11 @@ namespace transformer_engine { ...@@ -16,10 +16,11 @@ namespace transformer_engine {
template <typename scalar_t> template <typename scalar_t>
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst, __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int s_id, const int offset_block, const bool interleaved, const int s_id,
const int offset_block_dst, const int h, const int d, const int offset_block, const int offset_block_dst,
const int d2, const int stride_h, const int stride_d, const int h, const int d, const int d2, const int stride_h,
const int o_stride_h, const int o_stride_d) { const int stride_d, const int o_stride_h,
const int o_stride_d) {
#pragma unroll #pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin; float v_cos, v_sin;
...@@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs ...@@ -29,9 +30,18 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src]; float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2) float v_src_rotate;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d]) ? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]); : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? -static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} }
} }
...@@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs ...@@ -52,22 +62,39 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
template <typename scalar_t> template <typename scalar_t>
__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst, __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int s_id, const int offset_block, const bool interleaved, const int s_id,
const int offset_block_dst, const int h, const int d, const int offset_block, const int offset_block_dst,
const int d2, const int stride_h, const int stride_d, const int h, const int d, const int d2,
const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) { const int o_stride_h, const int o_stride_d) {
#pragma unroll #pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]); float v_cos = cosf(freqs[s_id * d2 + d_id]);
float v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) float v_sin;
if (!interleaved) {
v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
} else {
v_sin =
(d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]);
}
#pragma unroll #pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src]; float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2) ? src[offset_src + (d2 / 2) * stride_d] float v_src_rotate;
: src[offset_src + (d2 / 2 - d2) * stride_d]; if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} }
} }
...@@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq ...@@ -87,51 +114,33 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst, __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b, const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_backward_kernel(const scalar_t *src, const float *freqs, scalar_t *dst,
const int h, const int d, const int d2,
const int stride_s, const int stride_b,
const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const int cp_size,
const int cp_rank, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) { const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y; int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
if (cu_seqlens != nullptr) { // THD
int start = cu_seqlens[b_id] / cp_size; int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size; int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start; int t_id = s_id + start;
if (t_id >= end) return; if (t_id >= end) return;
int offset_block = t_id * stride_t; offset_block = t_id * stride_s_or_t;
int offset_block_dst = t_id * o_stride_t; offset_block_dst = t_id * o_stride_s_or_t;
cur_seqlens = end - start;
} else { // SBHD/BSHD
offset_block = s_id * stride_s_or_t + b_id * stride_b;
offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
cur_seqlens = s;
}
int s_id_for_freqs; int s_id_for_freqs;
if (cp_size > 1) { if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0); assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) { if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
...@@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu ...@@ -142,28 +151,37 @@ __global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu
} else { } else {
s_id_for_freqs = s_id; s_id_for_freqs = s_id;
} }
fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d); fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
} }
template <typename scalar_t> template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens, __global__ void fused_rope_backward_kernel(
const float *freqs, scalar_t *dst, const int cp_size, const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst,
const int cp_rank, const int h, const int d, const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h,
const int d2, const int stride_t, const int stride_h, const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_t, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h,
const int o_stride_h, const int o_stride_d) { const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y; int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
if (cu_seqlens != nullptr) { // THD
int start = cu_seqlens[b_id] / cp_size; int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size; int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start; int t_id = s_id + start;
if (t_id >= end) return; if (t_id >= end) return;
int offset_block = t_id * stride_t; offset_block = t_id * stride_s_or_t;
int offset_block_dst = t_id * o_stride_t; offset_block_dst = t_id * o_stride_s_or_t;
cur_seqlens = end - start;
} else { // SBHD/BSHD
offset_block = s_id * stride_s_or_t + b_id * stride_b;
offset_block_dst = s_id * o_stride_s_or_t + b_id * o_stride_b;
cur_seqlens = s;
}
int s_id_for_freqs; int s_id_for_freqs;
if (cp_size > 1) { if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0); assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) { if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
...@@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c ...@@ -174,193 +192,136 @@ __global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *c
} else { } else {
s_id_for_freqs = s_id; s_id_for_freqs = s_id;
} }
fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d); fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
} }
template <typename scalar_t> template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const float *freqs, scalar_t *output, void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
scalar_t *output, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2, const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b, const int stride_d, cudaStream_t stream) {
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8; int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b); dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block); dim3 threads(THREADS_PER_WARP, warps_per_block);
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
o_stride_s_or_t = h * d;
o_stride_b = 0;
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
o_stride_s_or_t = b * h * d;
o_stride_b = h * d;
} else {
o_stride_s_or_t = h * d;
o_stride_b = s * h * d;
}
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>( fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
input, freqs, output, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, input, cu_seqlens, freqs, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t,
o_stride_b, o_stride_h, o_stride_d); stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <typename scalar_t> template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const float *freqs, void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
scalar_t *input_grads, const int s, const int b, const int h, const float *freqs, scalar_t *input_grads,
const int d, const int d2, const int stride_s, const int stride_b, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_h, const int stride_d, const int o_stride_s, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_b, const int o_stride_h, const int o_stride_d, const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) { cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8; int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b); dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block); dim3 threads(THREADS_PER_WARP, warps_per_block);
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
o_stride_s_or_t = h * d;
o_stride_b = 0;
} else if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
o_stride_s_or_t = b * h * d;
o_stride_b = h * d;
} else {
o_stride_s_or_t = h * d;
o_stride_b = s * h * d;
}
const int o_stride_h = d;
const int o_stride_d = 1;
fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>( fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, freqs, input_grads, h, d, d2, stride_s, stride_b, stride_h, stride_d, output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
o_stride_s, o_stride_b, o_stride_h, o_stride_d); stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
NVTE_CHECK_CUDA(cudaGetLastError()); o_stride_d);
}
template <typename scalar_t>
void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
const float *freqs, scalar_t *output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
template <typename scalar_t> void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved,
const float *freqs, scalar_t *input_grads, const int cp_size, const int cp_size, const int cp_rank, const int s, const int b, const int h,
const int cp_rank, const int max_s, const int b, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b,
const int d, const int d2, const int stride_t, const int stride_h, const int stride_d, cudaStream_t stream) {
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &freqs, Tensor *output, const int s,
const int b, const int h, const int d, const int d2, const int stride_s,
const int stride_b, const int stride_h, const int stride_d,
const int o_stride_s, const int o_stride_b, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t, input.data.dtype, scalar_t,
fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr), fused_rope_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), s, b, h, d, d2,
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, stream););
}
void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor *input_grads,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s, const int o_stride_b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), s, b, h, d,
d2, stride_s, stride_b, stride_h, stride_d, o_stride_s,
o_stride_b, o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const int cp_size, const int cp_rank, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr), reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr), reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), cp_size, reinterpret_cast<scalar_t *>(output->data.dptr), qkv_format,
cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
o_stride_t, o_stride_h, o_stride_d, stream);); stride_b, stride_h, stride_d, stream););
} }
void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens, void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
const Tensor &freqs, Tensor *input_grads, const int cp_size, Tensor *input_grads, const NVTE_QKV_Format qkv_format,
const int cp_rank, const int max_s, const int b, const int h, const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int d, const int d2, const int stride_t, const int stride_h, const int b, const int h, const int d, const int d2,
const int stride_d, const int o_stride_t, const int o_stride_h, const int stride_s_or_t, const int stride_b, const int stride_h,
const int o_stride_d, cudaStream_t stream) { const int stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t, output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr), fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr), reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr), reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format,
cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_d, o_stride_t, o_stride_h, o_stride_d, stream);); stride_b, stride_h, stride_d, stream););
} }
} // end namespace transformer_engine } // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const int s, const int b, const int h, const int d, const int d2, const NVTETensor freqs, NVTETensor output,
const int stride_s, const int stride_b, const int stride_h, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_d, const int o_stride_s, const int o_stride_b, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) { const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_forward); NVTE_API_CALL(nvte_fused_rope_forward);
using namespace transformer_engine; using namespace transformer_engine;
fused_rope_forward(*reinterpret_cast<const Tensor *>(input), fused_rope_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output), *reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output),
s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
o_stride_h, o_stride_d, stream); stride_b, stride_h, stride_d, stream);
} }
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
NVTETensor input_grads, const int s, const int b, const int h, const NVTETensor freqs, NVTETensor input_grads,
const int d, const int d2, const int stride_s, const int stride_b, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_h, const int stride_d, const int o_stride_s, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_b, const int o_stride_h, const int o_stride_d, const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward); NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine; using namespace transformer_engine;
fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads), fused_rope_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), s, b, h, d, d2, stride_s, stride_b,
stride_h, stride_d, o_stride_s, o_stride_b, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), *reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), cp_size, cp_rank, max_s, b, h, d, d2, reinterpret_cast<Tensor *>(input_grads), qkv_format, interleaved, cp_size,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream); cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
fused_rope_thd_backward(
*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
} }
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_ #ifndef TRANSFORMER_ENGINE_FUSED_ROPE_H_
#define TRANSFORMER_ENGINE_FUSED_ROPE_H_ #define TRANSFORMER_ENGINE_FUSED_ROPE_H_
#include "fused_attn.h"
#include "transformer_engine.h" #include "transformer_engine.h"
#ifdef __cplusplus #ifdef __cplusplus
...@@ -16,112 +17,63 @@ extern "C" { ...@@ -16,112 +17,63 @@ extern "C" {
/*! \brief Apply rotary positional embedding to the input tensor. /*! \brief Apply rotary positional embedding to the input tensor.
* *
* \param[in] input Input tensor for fused rope. * \param[in] input Input tensor for fused rope.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor. * \param[in] freqs The freqs tensor.
* \param[out] output Output tensor. * \param[out] output Output tensor.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input. * \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input. * \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input. * \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input. * \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs. * \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of input. * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of input.
* \param[in] stride_b Stride of the b dimension of input. * \param[in] stride_b Stride of the b dimension of input. (0 for thd).
* \param[in] stride_h Stride of the h dimension of input. * \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input. * \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_s Stride of the s dimension of output.
* \param[in] o_stride_b Stride of the b dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor freqs, NVTETensor output, void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const int s, const int b, const int h, const int d, const int d2, const NVTETensor freqs, NVTETensor output,
const int stride_s, const int stride_b, const int stride_h, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_d, const int o_stride_s, const int o_stride_b, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_h, const int o_stride_d, cudaStream_t stream); const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
/*! \brief Compute the backward of the fused rope. /*! \brief Compute the backward of the fused rope.
* *
* \param[in] output_grads Incoming gradient tensor for backward. * \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor. * \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient tensor to calculate. * \param[out] input_grads Input gradient tensor to calculate.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of output_grads. * \param[in] s Length of the s dimension of output_grads.
* \param[in] b Length of the b dimension of output_grads. * \param[in] b Length of the b dimension of output_grads.
* \param[in] h Length of the h dimension of output_grads. * \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads. * \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs. * \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_s Stride of the s dimension of output_grads. * \param[in] stride_s_or_t Stride of the s (sbhd/bshd)/t (thd) dimension of output_grads.
* \param[in] stride_b Stride of the b dimension of output_grads. * \param[in] stride_b Stride of the b dimension of output_grads. (0 for thd).
* \param[in] stride_h Stride of the h dimension of output_grads. * \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads. * \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_s Stride of the s dimension of input_grads.
* \param[in] o_stride_b Stride of the b dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor freqs, void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
NVTETensor input_grads, const int s, const int b, const int h, const NVTETensor freqs, NVTETensor input_grads,
const int d, const int d2, const int stride_s, const int stride_b, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int stride_h, const int stride_d, const int o_stride_s, const int cp_size, const int cp_rank, const int s, const int b,
const int o_stride_b, const int o_stride_h, const int o_stride_d, const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the input tensor in thd format.
*
* \param[in] input Input tensor for fused rope.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of input.
* \param[in] stride_h Stride of the h dimension of input.
* \param[in] stride_d Stride of the d dimension of input.
* \param[in] o_stride_t Stride of the t dimension of output.
* \param[in] o_stride_h Stride of the h dimension of output.
* \param[in] o_stride_d Stride of the d dimension of output.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
* \param[in] output_grads Incoming gradient tensor for backward.
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient to calculate.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of output_grads.
* \param[in] d Length of the d dimension of output_grads.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] stride_t Stride of the t dimension of output_grads.
* \param[in] stride_h Stride of the h dimension of output_grads.
* \param[in] stride_d Stride of the d dimension of output_grads.
* \param[in] o_stride_t Stride of the t dimension of input_grads.
* \param[in] o_stride_h Stride of the h dimension of input_grads.
* \param[in] o_stride_d Stride of the d dimension of input_grads.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -265,16 +265,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio ...@@ -265,16 +265,14 @@ void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reductio
**************************************************************************************************/ **************************************************************************************************/
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory); const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory); const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens, const int cp_rank);
const at::Tensor &freqs, const int cp_size, const int cp_rank);
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank);
/*************************************************************************************************** /***************************************************************************************************
* Miscellaneous * Miscellaneous
......
...@@ -7,138 +7,38 @@ ...@@ -7,138 +7,38 @@
#include "extensions.h" #include "extensions.h"
at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const bool transpose_output_memory) { const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(input.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1"); "expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(3) >= freqs.size(3),
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float, TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float"); "Dtype of the freqs tensor must be float");
// input sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = input.size(0);
const int b = input.size(1);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = input.stride(0);
const int stride_b = input.stride(1);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
// output // output
auto act_options = input.options().requires_grad(false); auto act_options = at::TensorOptions().dtype(input.scalar_type()).device(input.device());
at::Tensor output; auto output = at::empty(input.sizes(), act_options);
if (transpose_output_memory) {
output = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
output = torch::empty({s, b, h, d}, act_options);
}
// output strides
const int o_stride_s = output.stride(0);
const int o_stride_b = output.stride(1);
const int o_stride_h = output.stride(2);
const int o_stride_d = output.stride(3);
auto input_cu = makeTransformerEngineTensor(input); auto input_cu = makeTransformerEngineTensor(input);
auto freqs_cu = makeTransformerEngineTensor(freqs); auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output); auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_forward(input_cu.data(), freqs_cu.data(), output_cu.data(), s, b, h, d, d2, if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const bool transpose_output_memory) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(output_grads.size(0) <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(output_grads.size(3) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = output_grads.size(0);
const int b = output_grads.size(1);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s = output_grads.stride(0);
const int stride_b = output_grads.stride(1);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false);
at::Tensor input_grads;
if (transpose_output_memory) {
input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1);
} else {
input_grads = torch::empty({s, b, h, d}, act_options);
}
const int o_stride_s = input_grads.stride(0);
const int o_stride_b = input_grads.stride(1);
const int o_stride_h = input_grads.stride(2);
const int o_stride_d = input_grads.stride(3);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_backward(output_grads_cu.data(), freqs_cu.data(), input_grads_cu.data(), s, b, h,
d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b,
o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
return input_grads;
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
using namespace transformer_engine::pytorch;
TORCH_CHECK(input.dim() == 3, "expected 3D tensor"); TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor"); TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(input.size(2) >= freqs.size(3), TORCH_CHECK(input.size(2) >= freqs.size(3),
"expected the last dim of the input tensor equals or is " "expected the last dim of the input tensor equals or is "
"greater than the freqs tensor"); "greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// input sizes: (t, h, d) // input sizes: (t, h, d)
// t: cumulative sum of sequence lengths // t: cumulative sum of sequence lengths
// h: head num // h: head num
// d: dim of each head // d: dim of each head
const int t = input.size(0); // const int t = input.size(0);
const int h = input.size(1); const int h = input.size(1);
const int d = input.size(2); const int d = input.size(2);
// input strides // input strides
...@@ -146,51 +46,86 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_ ...@@ -146,51 +46,86 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
const int stride_h = input.stride(1); const int stride_h = input.stride(1);
const int stride_d = input.stride(2); const int stride_d = input.stride(2);
// batch size // batch size
const int b = cu_seqlens.size(0) - 1; const int b = cu_seqlens.value().size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2) // freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0); const int max_s = freqs.size(0);
const int d2 = freqs.size(3); const int d2 = freqs.size(3);
// output auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
auto act_options = input.options().requires_grad(false);
auto output = torch::empty({t, h, d}, act_options);
// output strides
const int o_stride_t = output.stride(0);
const int o_stride_h = output.stride(1);
const int o_stride_d = output.stride(2);
auto input_cu = makeTransformerEngineTensor(input); nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); output_cu.data(), qkv_format, interleaved, cp_size, cp_rank, max_s, b,
auto freqs_cu = makeTransformerEngineTensor(freqs); h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
return output;
}
TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
// input sizes: (s, b, h, d) or (b, s, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(0) : input.size(1);
const int b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.size(1) : input.size(0);
const int h = input.size(2);
const int d = input.size(3);
// input strides
const int stride_s = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(0) : input.stride(1);
const int stride_b = qkv_format == NVTE_QKV_Format::NVTE_SBHD ? input.stride(1) : input.stride(0);
const int stride_h = input.stride(2);
const int stride_d = input.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than input");
TORCH_CHECK(d >= d2,
"expected the last dim of the input tensor equals or is "
"greater than the freqs tensor");
auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), output_cu.data(),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s,
stride_b, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return output; return output;
} }
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const at::Tensor &freqs, const int cp_size, const int cp_rank) { const NVTE_QKV_Format qkv_format, const bool interleaved,
const c10::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) {
using namespace transformer_engine::pytorch; using namespace transformer_engine::pytorch;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor"); TORCH_CHECK(freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1, TORCH_CHECK(freqs.size(1) == 1 && freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1"); "expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
auto act_options =
at::TensorOptions().dtype(output_grads.scalar_type()).device(output_grads.device());
auto input_grads = at::empty(output_grads.sizes(), act_options);
auto output_grads_cu = makeTransformerEngineTensor(output_grads);
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
TORCH_CHECK(cu_seqlens.value().dim() == 1, "expected 1D tensor");
TORCH_CHECK(output_grads.size(2) >= freqs.size(3), TORCH_CHECK(output_grads.size(2) >= freqs.size(3),
"expected the last dim of the output_grads tensor equals or is " "expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor"); "greater than the freqs tensor");
TORCH_CHECK(freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output_grads sizes: (t, h, d) // output_grads sizes: (t, h, d)
// t: cumulative sum of sequence lengths // t: cumulative sum of sequence lengths
// h: head num // h: head num
// d: dim of each head // d: dim of each head
const int t = output_grads.size(0); // const int t = output_grads.size(0);
const int h = output_grads.size(1); const int h = output_grads.size(1);
const int d = output_grads.size(2); const int d = output_grads.size(2);
// output_grads strides // output_grads strides
...@@ -198,25 +133,54 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten ...@@ -198,25 +133,54 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
const int stride_h = output_grads.stride(1); const int stride_h = output_grads.stride(1);
const int stride_d = output_grads.stride(2); const int stride_d = output_grads.stride(2);
// batch size // batch size
const int b = cu_seqlens.size(0) - 1; const int b = cu_seqlens.value().size(0) - 1;
// freqs' shape is (max_s, 1, 1, d2) // freqs' shape is (max_s, 1, 1, d2)
const int max_s = freqs.size(0); const int max_s = freqs.size(0);
const int d2 = freqs.size(3); const int d2 = freqs.size(3);
auto act_options = output_grads.options().requires_grad(false); auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
auto input_grads = torch::empty({t, h, d}, act_options);
const int o_stride_t = input_grads.stride(0);
const int o_stride_h = input_grads.stride(1);
const int o_stride_d = input_grads.stride(2);
auto output_grads_cu = makeTransformerEngineTensor(output_grads); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens); input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank,
auto freqs_cu = makeTransformerEngineTensor(freqs); max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d,
auto input_grads_cu = makeTransformerEngineTensor(input_grads); at::cuda::getCurrentCUDAStream());
return input_grads;
}
TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
// output_grads sizes: (s, b, h, d)
// s: sequence length
// b: batch size
// h: head num
// d: dim of each head
const int s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(0) : output_grads.size(1);
const int b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.size(1) : output_grads.size(0);
const int h = output_grads.size(2);
const int d = output_grads.size(3);
// output_grads strides
const int stride_s =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(0) : output_grads.stride(1);
const int stride_b =
qkv_format == NVTE_QKV_Format::NVTE_SBHD ? output_grads.stride(1) : output_grads.stride(0);
const int stride_h = output_grads.stride(2);
const int stride_d = output_grads.stride(3);
// freqs' shape is always (s, 1, 1, d2), so the strides are same under
// different memory formats
const int d2 = freqs.size(3);
TORCH_CHECK(s * cp_size <= freqs.size(0),
"expected freqs tensor has a longer sequence length than output_grads");
TORCH_CHECK(d >= d2,
"expected the last dim of the output_grads tensor equals or is "
"greater than the freqs tensor");
nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), auto cu_seqlens_cu = transformer_engine::TensorWrapper(); // empty cu_seqlens tensor
input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b,
h, d, d2, stride_s, stride_b, stride_h, stride_d,
at::cuda::getCurrentCUDAStream()); at::cuda::getCurrentCUDAStream());
return input_grads; return input_grads;
......
...@@ -229,10 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -229,10 +229,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD", m.def("fused_rope_backward", &fused_rope_backward, "Fused Apply RoPE BWD",
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_forward", &fused_rope_thd_forward, "Fused Apply RoPE FWD for thd format",
py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_thd_backward", &fused_rope_thd_backward, "Fused Apply RoPE BWD for thd format",
py::call_guard<py::gil_scoped_release>());
// Misc // Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version", m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version",
......
...@@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu ...@@ -7,7 +7,12 @@ Rotary Position Embedding implementation of different types along with helper fu
""" """
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module): class RotaryPositionEmbedding(torch.nn.Module):
...@@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -22,19 +27,24 @@ class RotaryPositionEmbedding(torch.nn.Module):
seq_len_interpolation_factor: Optional[int] = None, seq_len_interpolation_factor: Optional[int] = None,
pretrained_max_position_embeddings: Optional[int] = None, pretrained_max_position_embeddings: Optional[int] = None,
rotary_base: float = 10000.0, rotary_base: float = 10000.0,
interleaved: bool = False,
): ):
""" """
Parameters Parameters
---------- ----------
dim: int dim: int
rotary embedding dimension Rotary embedding dimension.
rotary_percent: float rotary_percent: float, default = 1.0
Percent of rotary dimension to use for rotary position embeddings. Percent of rotary dimension to use for rotary position embeddings.
seq_len_interpolation_factor: int seq_len_interpolation_factor: int, default = None
if not None, discrete positions will be interpolated by this factor via the trick in If not None, discrete positions will be interpolated by this factor via the trick in
https://arxiv.org/abs/2306.15595 https://arxiv.org/abs/2306.15595
pretrained_max_position_embeddings: int pretrained_max_position_embeddings: int, default = None
pre-trained max_position_embeddings before position interpolation Pre-trained max_position_embeddings before position interpolation.
rotary_base: float, default = 10000.0
Base of the rotary position embedding.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
""" """
super().__init__() super().__init__()
if rotary_percent < 1.0: if rotary_percent < 1.0:
...@@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -50,17 +60,18 @@ class RotaryPositionEmbedding(torch.nn.Module):
) )
self.register_buffer("inv_freq", inv_freq) self.register_buffer("inv_freq", inv_freq)
self.pretrained_max_position_embeddings = pretrained_max_position_embeddings self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
self.interleaved = interleaved
def forward(self, max_seq_len: int, offset: int = 0): def forward(self, max_seq_len: int, offset: int = 0):
""" """
Create rotary position embedding frequencies Create rotary position embedding frequencies.
Parameters Parameters
---------- ----------
max_seq_len: int max_seq_len: int
sequence length of a sample Sequence length of a sample.
offset: int, default = 0 offset: int, default = 0
fixed offset for freqencies Fixed offset for frequencies.
""" """
seq = ( seq = (
torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
...@@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module): ...@@ -84,7 +95,12 @@ class RotaryPositionEmbedding(torch.nn.Module):
freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) freqs = torch.einsum("i , j -> i j", seq, self.inv_freq)
# first part even vector components, second part odd vector components, # first part even vector components, second part odd vector components,
# 2 * dim in dimension size # 2 * dim in dimension size
if not self.interleaved:
emb = torch.cat((freqs, freqs), dim=-1) emb = torch.cat((freqs, freqs), dim=-1)
else:
emb = torch.stack((freqs.view(-1, 1), freqs.view(-1, 1)), dim=-1).view(
freqs.shape[0], -1
)
# emb [seq_length, .., dim] # emb [seq_length, .., dim]
return emb.reshape(emb.size(0), 1, 1, emb.size(1)) return emb.reshape(emb.size(0), 1, 1, emb.size(1))
...@@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -104,61 +120,146 @@ class FusedRoPEFunc(torch.autograd.Function):
t: torch.Tensor, t: torch.Tensor,
freqs: torch.Tensor, freqs: torch.Tensor,
tensor_format: str = "sbhd", tensor_format: str = "sbhd",
interleaved: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None, cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1, cp_size: int = 1,
cp_rank: int = 0, cp_rank: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
# pylint: disable=missing-function-docstring """Fused RoPE forward."""
if freqs.dtype != torch.float32: if freqs.dtype != torch.float32:
freqs = freqs.float() freqs = freqs.float()
if tensor_format == "sbhd": assert tensor_format in (
output = tex.fused_rope_forward(t, freqs, False) "sbhd",
elif tensor_format == "bshd": "bshd",
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1) "thd",
elif tensor_format == "thd": ), f"Unsupported tensor_format: {tensor_format}."
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank) output = tex.fused_rope_forward(
else: t, freqs, QKVFormat[tensor_format], interleaved, cu_seqlens, cp_size, cp_rank
raise ValueError(f"Unsupported tensor_format: {tensor_format}.") )
ctx.save_for_backward(freqs, cu_seqlens) ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format ctx.tensor_format = tensor_format
ctx.cp_size = cp_size ctx.cp_size = cp_size
ctx.cp_rank = cp_rank ctx.cp_rank = cp_rank
ctx.interleaved = interleaved
return output return output
@staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
    """Fused RoPE backward.

    Dispatches to the single fused backward kernel for all layouts; the layout
    is passed as a QKVFormat enum instead of using per-layout entry points.

    Parameters
    ----------
    grad_output: torch.Tensor
        Gradient w.r.t. the forward output.

    Returns
    -------
    Tuple[Union[torch.Tensor, None], ...]
        Gradient w.r.t. `t`, followed by one None per non-differentiable
        forward argument (freqs, tensor_format, interleaved, cu_seqlens,
        cp_size, cp_rank).
    """
    freqs, cu_seqlens = ctx.saved_tensors
    grad_input = tex.fused_rope_backward(
        grad_output,
        freqs,
        QKVFormat[ctx.tensor_format],
        ctx.interleaved,
        cu_seqlens,
        ctx.cp_size,
        ctx.cp_rank,
    )
    # One slot per forward() input: t, freqs, tensor_format, interleaved,
    # cu_seqlens, cp_size, cp_rank.
    return grad_input, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor: def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
""" """Change sign so the last dimension becomes [-odd, +even]
change sign so the last dimension becomes [-odd, +even]
Args:
x: torch.Tensor. Input tensor.
interleaved: bool. Whether to use interleaved rotary position embedding.
Returns:
Tensor: Tensor rotated half.
""" """
x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2))) if not interleaved:
x1, x2 = x.unbind(dim=-2) x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1) return torch.cat((-x2, x1), dim=-1)
# interleaved
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def _apply_rotary_pos_emb_base(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
) -> torch.Tensor:
    """
    Base implementation of applying rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
        embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
        `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    """
    seq_dim = 1 if tensor_format == "bshd" else 0
    cur_seq_len = t.shape[seq_dim]
    max_seq_len = freqs.shape[0]
    # Only apply the rotary embeddings up to the sequence length of the running input.
    assert (
        cur_seq_len <= max_seq_len
    ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"

    emb = freqs[:cur_seq_len]
    if tensor_format == "bshd":
        emb = emb.transpose(0, 1)  # [seq, 1, 1, dim] -> [1, seq, 1, dim]

    # Take cos/sin in float first, then convert dtype, for better precision.
    cos_ = emb.cos().to(t.dtype)
    sin_ = emb.sin().to(t.dtype)

    # Ideally `tail` is empty so the rotary embedding covers the whole of `t`.
    rot_dim = emb.shape[-1]
    rotated, tail = t[..., :rot_dim], t[..., rot_dim:]

    # Cosine term plus sign-flipped (rotated-half) sine term.
    rotated = rotated * cos_ + _rotate_half(rotated, interleaved) * sin_
    return torch.cat((rotated, tail), dim=-1)
def _get_freqs_on_this_cp_rank(
freqs: torch.Tensor, seqlen: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
freqs: torch.Tensor. Positional embedding tensor in shape `[s2, 1, 1, d2]`.
seqlen: int. Length of the current sequence.
cp_size: int. Context parallel world size.
cp_rank: int. Context parallel rank.
"""
if cp_size > 1:
cp_seg = seqlen // 2
full_seqlen = cp_size * seqlen
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
# cp_size == 1
return freqs[:seqlen]
def apply_rotary_pos_emb(
    t: torch.Tensor,
    freqs: torch.Tensor,
    tensor_format: str = "sbhd",
    interleaved: bool = False,
    fused: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Apply rotary positional embedding tensor to the input tensor.

    Parameters
    ----------
    t: torch.Tensor
        Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
        rotary positional embedding will be applied.
    freqs: torch.Tensor
        Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    tensor_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
        is `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is
        of shape `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    fused: bool, default = False
        Whether to use a fused applying RoPE implementation.
    cu_seqlens: torch.Tensor, default = None.
        Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
        dtype torch.int32. Only valid when `tensor_format` is 'thd'.
    cp_size: int, default = 1.
        Context parallel world size.
    cp_rank: int, default = 0.
        Context parallel rank.
    """
    assert (
        tensor_format != "thd" or cu_seqlens is not None
    ), "cu_seqlens must not be None when tensor_format is 'thd'."

    if fused:
        return FusedRoPEFunc.apply(
            t, freqs, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
        )

    # Unfused THD format: split the packed batch back into individual sequences,
    # apply RoPE to each with its own CP-rank-local frequencies, then re-pack.
    if tensor_format == "thd":
        # cu_seqlens describes the full (un-sharded) sequences; each CP rank
        # holds 1/cp_size of every sequence.
        cu_seqlens = cu_seqlens // cp_size
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        return torch.cat(
            [
                _apply_rotary_pos_emb_base(
                    # [t_i, h, d] -> [t_i, 1, h, d] so the sbhd base impl applies.
                    x.unsqueeze(1),
                    _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank),
                    interleaved=interleaved,
                )
                for x in torch.split(t, seqlens)
            ]
        ).squeeze(1)

    # Unfused SBHD/BSHD format
    if tensor_format == "sbhd":
        seqlen = t.size(0)
    elif tensor_format == "bshd":
        seqlen = t.size(1)
    else:
        raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
    return _apply_rotary_pos_emb_base(
        t,
        _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
        tensor_format,
        interleaved=interleaved,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment