Unverified commit 55dcbb4b authored by Xin Yao, committed by GitHub

[PyTorch] Let Fused RoPE support CP with THD format (#1238)



* Let Fused RoPE support THD with CP
Signed-off-by: Xin Yao <xiny@nvidia.com>

* add comment
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com>
parent b36bd0a4
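
For orientation, a minimal sketch of how the new context-parallel arguments are intended to be used from PyTorch (device, shapes, and sequence lengths below are illustrative assumptions, not part of this change):

import torch
from transformer_engine.pytorch.attention import RotaryPositionEmbedding, apply_rotary_pos_emb

cp_size, cp_rank = 2, 0  # context-parallel world size and this rank's index
# Padded cumulative sequence lengths of the full batch; every sequence length is a
# multiple of 2 * cp_size so it splits evenly across ranks.
cu_seqlens_padded = torch.tensor([0, 512, 1024], dtype=torch.int32, device="cuda")
# This rank holds 1 / cp_size of every sequence in THD layout: [t, h, d].
t = torch.rand(1024 // cp_size, 16, 128, dtype=torch.bfloat16, device="cuda")
freqs = RotaryPositionEmbedding(128)(1024)  # freqs for the full padded length
out = apply_rotary_pos_emb(
    t, freqs, tensor_format="thd", fused=True,
    cu_seqlens=cu_seqlens_padded, cp_size=cp_size, cp_rank=cp_rank,
)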
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import math
import pytest
import torch
from typing import Callable, Dict, Tuple, Union
from typing import Callable, Tuple, Union
from transformer_engine.pytorch.attention import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
)
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: torch.Tensor, freqs: torch.Tensor
) -> torch.Tensor:
if cp_size > 1:
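# With context parallelism, each rank holds two chunks of every sequence for load
# balancing: chunk cp_rank counted from the front and the mirrored chunk counted
# from the back, so gather the freqs for both chunks held by this rank.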
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
return torch.cat(
[
freqs[cp_rank * cp_seg : (cp_rank + 1) * cp_seg],
freqs[full_seqlen - (cp_rank + 1) * cp_seg : full_seqlen - cp_rank * cp_seg],
]
)
else:
return freqs[: x.size(0)]
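An illustrative check of the helper above (sizes are hypothetical, not part of the test): with cp_size = 2 and a padded sequence of 8 positions, a 4-token shard picks up these frequency positions per rank:

freqs = torch.arange(8).view(8, 1, 1, 1)  # stand-in freqs whose value equals the global position
shard = torch.empty(4, 1, 1)  # per-rank THD shard holding 4 tokens
_get_thd_freqs_on_this_cp_rank(0, 2, shard, freqs).flatten().tolist()  # [0, 1, 6, 7]
_get_thd_freqs_on_this_cp_rank(1, 2, shard, freqs).flatten().tolist()  # [2, 3, 4, 5]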
def apply_rotary_pos_emb_thd(
t: torch.Tensor, cu_seqlens: torch.Tensor, freqs: torch.Tensor
t: torch.Tensor,
cu_seqlens: torch.Tensor,
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
@@ -24,20 +45,18 @@ def apply_rotary_pos_emb_thd(
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
[apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)]
[
apply_rotary_pos_emb(
x.unsqueeze(1), _get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs)
)
for x in torch.split(t, seqlens)
]
).squeeze(1)
def get_tol(dtype: torch.dtype) -> Dict:
if dtype == torch.bfloat16:
return dict(atol=1e-2, rtol=1e-2)
elif dtype == torch.float16:
return dict(atol=1e-3, rtol=1e-3)
return dict(atol=1e-5, rtol=1.3e-6)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
@@ -84,7 +103,11 @@ def test_fused_rope(
emb = rotary_pos_emb(seq_length)
# unfused
output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False)
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
output_unfused = apply_rotary_pos_emb(
t.float(), emb, tensor_format=tensor_format, fused=False
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
@@ -102,8 +125,8 @@ def test_fused_rope(
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@@ -112,22 +135,34 @@ def test_fused_rope(
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("transpose", [None, (1, 2)])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2, 3])
def test_fused_rope_thd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
transpose: Union[Tuple, None],
loss_func: Callable,
cp_size: int,
) -> None:
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = torch.tensor(
[0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048],
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
if cp_size > 1:
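# Pad each sequence length up to a multiple of 2 * cp_size so that every sequence
# can be split evenly into 2 * cp_size chunks across context-parallel ranks.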
cu_seqlens_padded = [0]
for i in range(1, len(cu_seqlens)):
cu_seqlens_padded.append(
cu_seqlens_padded[i - 1]
+ math.ceil((cu_seqlens[i] - cu_seqlens[i - 1]) / (cp_size * 2)) * (cp_size * 2)
)
else:
cu_seqlens_padded = cu_seqlens
cu_seqlens_padded = torch.tensor(
cu_seqlens_padded,
dtype=torch.int32,
device=device,
)
t = torch.rand(
(cu_seqlens[-1], head_num, hidden_size),
(cu_seqlens_padded[-1] // cp_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
@@ -136,23 +171,34 @@ def test_fused_rope_thd(
t.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent)
emb = rotary_pos_emb(cu_seqlens[-1])
# unfused
output_unfused = apply_rotary_pos_emb_thd(t, cu_seqlens, emb)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t, emb, fused=True, tensor_format="thd", cu_seqlens=cu_seqlens
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused, **get_tol(dtype))
torch.testing.assert_close(grad_fused, grad_unfused, **get_tol(dtype))
emb = rotary_pos_emb(cu_seqlens_padded[-1])
for cp_rank in range(cp_size):
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
output_unfused = apply_rotary_pos_emb_thd(
t.float(), cu_seqlens_padded, emb, cp_size, cp_rank
).to(dtype)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
output_fused = apply_rotary_pos_emb(
t,
emb,
fused=True,
tensor_format="thd",
cu_seqlens=cu_seqlens_padded,
cp_size=cp_size,
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
@@ -4,6 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <transformer_engine/fused_rope.h>
@@ -15,11 +16,10 @@ namespace transformer_engine {
template <typename scalar_t>
__device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2, const int stride_h,
const int stride_d, const int o_stride_h,
const int o_stride_d) {
int s_id = blockIdx.x;
const int s_id, const int offset_block,
const int offset_block_dst, const int h, const int d,
const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
@@ -52,11 +52,10 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs
template <typename scalar_t>
__device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs, scalar_t *dst,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2,
const int stride_h, const int stride_d,
const int s_id, const int offset_block,
const int offset_block_dst, const int h, const int d,
const int d2, const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x;
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]);
@@ -97,8 +96,8 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const float *freq
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
stride_d, o_stride_h, o_stride_d);
fused_rope_block_forward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
@@ -111,40 +110,72 @@ __global__ void fused_rope_backward_kernel(const scalar_t *src, const float *fre
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block = s_id * stride_s + b_id * stride_b;
int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b;
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
stride_d, o_stride_h, o_stride_d);
fused_rope_block_backward(src, freqs, dst, s_id, offset_block, offset_block_dst, h, d, d2,
stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_forward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d) {
const float *freqs, scalar_t *dst, const int cp_size,
const int cp_rank, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start;
if (t_id >= end) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
fused_rope_block_forward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
stride_d, o_stride_h, o_stride_d);
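  // With context parallelism this rank holds two chunks of the sequence (chunk cp_rank
  // from the front and its mirrored chunk from the back), so map the local index s_id
  // to the corresponding global position before indexing into freqs.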
int s_id_for_freqs;
if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
}
fused_rope_block_forward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__global__ void fused_rope_thd_backward_kernel(const scalar_t *src, const int *cu_seqlens,
const float *freqs, scalar_t *dst, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d) {
const float *freqs, scalar_t *dst, const int cp_size,
const int cp_rank, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int t_id = s_id + cu_seqlens[b_id];
if (t_id >= cu_seqlens[b_id + 1]) return;
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int t_id = s_id + start;
if (t_id >= end) return;
int offset_block = t_id * stride_t;
int offset_block_dst = t_id * o_stride_t;
fused_rope_block_backward(src, freqs, dst, offset_block, offset_block_dst, h, d, d2, stride_h,
stride_d, o_stride_h, o_stride_d);
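  // Same local-to-global position mapping as in the forward THD kernel.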
int s_id_for_freqs;
if (cp_size > 1) {
int cur_seqlens = end - start;
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
}
fused_rope_block_backward(src, freqs, dst, s_id_for_freqs, offset_block, offset_block_dst, h, d,
d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
@@ -182,35 +213,37 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const float *fre
template <typename scalar_t>
void fused_rope_thd_forward_launcher(const scalar_t *input, const int *cu_seqlens,
const float *freqs, scalar_t *output, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
const float *freqs, scalar_t *output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(input, cu_seqlens, freqs, output, h,
d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
fused_rope_thd_forward_kernel<<<blocks, threads, 0, stream>>>(
input, cu_seqlens, freqs, output, cp_size, cp_rank, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_rope_thd_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
const float *freqs, scalar_t *input_grads, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
const float *freqs, scalar_t *input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(max_s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
fused_rope_thd_backward_kernel<<<blocks, threads, 0, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d);
output_grads, cu_seqlens, freqs, input_grads, cp_size, cp_rank, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
@@ -243,33 +276,34 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &freqs, Tensor
}
void fused_rope_thd_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *output, const int max_s, const int b, const int h, const int d,
const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
Tensor *output, const int cp_size, const int cp_rank, const int max_s,
const int b, const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, scalar_t,
fused_rope_thd_forward_launcher(reinterpret_cast<const scalar_t *>(input.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(output->data.dptr), max_s, b, h,
d, d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h,
o_stride_d, stream););
reinterpret_cast<scalar_t *>(output->data.dptr), cp_size,
cp_rank, max_s, b, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d, stream););
}
void fused_rope_thd_backward(const Tensor &output_grads, const Tensor &cu_seqlens,
const Tensor &freqs, Tensor *input_grads, const int max_s, const int b,
const int h, const int d, const int d2, const int stride_t,
const int stride_h, const int stride_d, const int o_stride_t,
const int o_stride_h, const int o_stride_d, cudaStream_t stream) {
const Tensor &freqs, Tensor *input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_thd_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), max_s,
b, h, d, d2, stride_t, stride_h, stride_d, o_stride_t,
o_stride_h, o_stride_d, stream););
reinterpret_cast<scalar_t *>(input_grads->data.dptr),
cp_size, cp_rank, max_s, b, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d, stream););
}
} // end namespace transformer_engine
@@ -302,30 +336,31 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr
}
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_forward);
using namespace transformer_engine;
fused_rope_thd_forward(
*reinterpret_cast<const Tensor *>(input), *reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs), reinterpret_cast<Tensor *>(output), max_s, b, h, d,
d2, stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
fused_rope_thd_forward(*reinterpret_cast<const Tensor *>(input),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(output), cp_size, cp_rank, max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
}
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream) {
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_thd_backward);
using namespace transformer_engine;
fused_rope_thd_backward(*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens),
*reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
fused_rope_thd_backward(
*reinterpret_cast<const Tensor *>(output_grads),
*reinterpret_cast<const Tensor *>(cu_seqlens), *reinterpret_cast<const Tensor *>(freqs),
reinterpret_cast<Tensor *>(input_grads), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d, stream);
}
@@ -72,6 +72,8 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] output Output tensor.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of input.
@@ -86,11 +88,11 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor fr
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor output, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
const NVTETensor freqs, NVTETensor output, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
/*! \brief Compute the backward of the fused rope in thd format.
*
@@ -98,6 +100,8 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* \param[in] freqs The freqs tensor.
* \param[out] input_grads Input gradient to calculate.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] max_s Max sequence length.
* \param[in] b Batch size.
* \param[in] h Length of the h dimension of output_grads.
@@ -112,11 +116,11 @@ void nvte_fused_rope_thd_forward(const NVTETensor input, const NVTETensor cu_seq
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_thd_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const int max_s,
const int b, const int h, const int d, const int d2,
const int stride_t, const int stride_h, const int stride_d,
const int o_stride_t, const int o_stride_h, const int o_stride_d,
cudaStream_t stream);
const NVTETensor freqs, NVTETensor input_grads, const int cp_size,
const int cp_rank, const int max_s, const int b, const int h,
const int d, const int d2, const int stride_t, const int stride_h,
const int stride_d, const int o_stride_t, const int o_stride_h,
const int o_stride_d, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
@@ -4305,6 +4305,8 @@ class FusedRoPEFunc(torch.autograd.Function):
freqs: torch.Tensor,
tensor_format: str = "sbhd",
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
if freqs.dtype != torch.float32:
freqs = freqs.float()
@@ -4313,11 +4315,13 @@
elif tensor_format == "bshd":
output = tex.fused_rope_forward(t.transpose(0, 1), freqs, True).transpose(0, 1)
elif tensor_format == "thd":
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs)
output = tex.fused_rope_thd_forward(t, cu_seqlens, freqs, cp_size, cp_rank)
else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
ctx.save_for_backward(freqs, cu_seqlens)
ctx.tensor_format = tensor_format
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
return output
@@ -4331,11 +4335,13 @@
grad_output.transpose(0, 1), freqs, True
).transpose(0, 1)
elif ctx.tensor_format == "thd":
grad_input = tex.fused_rope_thd_backward(grad_output, cu_seqlens, freqs)
grad_input = tex.fused_rope_thd_backward(
grad_output, cu_seqlens, freqs, ctx.cp_size, ctx.cp_rank
)
else:
raise ValueError(f"Unsupported tensor_format: {ctx.tensor_format}.")
return grad_input, None, None, None, None
return grad_input, None, None, None, None, None
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -4353,6 +4359,8 @@ def apply_rotary_pos_emb(
tensor_format: str = "sbhd",
fused: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None,
cp_size: int = 1,
cp_rank: int = 0,
) -> torch.Tensor:
"""
Apply rotary positional embedding tensor to the input tensor.
@@ -4373,12 +4381,17 @@
cu_seqlens: torch.Tensor, default = None.
Cumulative sum of sequence lengths in a batch for `t`, with shape [b + 1] and
dtype torch.int32. Only valid when `tensor_format` is 'thd'.
Should be `cu_seqlens_padded` when cp_size > 1.
cp_size: int, default = 1.
Context parallel world size. Only valid when `tensor_format` is 'thd' and `fused` is True.
cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
"""
if fused:
assert (
tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'."
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens)
return FusedRoPEFunc.apply(t, freqs, tensor_format, cu_seqlens, cp_size, cp_rank)
assert tensor_format in ("sbhd", "bshd"), (
"Only formats `sbhd` or `bshd` are supported for input tensor `t` "
@@ -412,10 +412,10 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
const bool transpose_output_memory);
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs);
const at::Tensor &freqs, const int cp_size, const int cp_rank);
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs);
const at::Tensor &freqs, const int cp_size, const int cp_rank);
/***************************************************************************************************
* Miscellaneous
......
@@ -121,7 +121,7 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
}
at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_seqlens,
const at::Tensor &freqs) {
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
using namespace transformer_engine;
TORCH_CHECK(input.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
@@ -165,14 +165,15 @@ at::Tensor fused_rope_thd_forward(const at::Tensor &input, const at::Tensor &cu_
auto output_cu = makeTransformerEngineTensor(output);
nvte_fused_rope_thd_forward(input_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
output_cu.data(), max_s, b, h, d, d2, stride_t, stride_h, stride_d,
o_stride_t, o_stride_h, o_stride_d, at::cuda::getCurrentCUDAStream());
output_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
return output;
}
at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Tensor &cu_seqlens,
const at::Tensor &freqs) {
const at::Tensor &freqs, const int cp_size, const int cp_rank) {
using namespace transformer_engine;
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.dim() == 1, "expected 1D tensor");
@@ -214,8 +215,8 @@ at::Tensor fused_rope_thd_backward(const at::Tensor &output_grads, const at::Ten
auto input_grads_cu = makeTransformerEngineTensor(input_grads);
nvte_fused_rope_thd_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), max_s, b, h, d, d2, stride_t, stride_h,
stride_d, o_stride_t, o_stride_h, o_stride_d,
input_grads_cu.data(), cp_size, cp_rank, max_s, b, h, d, d2,
stride_t, stride_h, stride_d, o_stride_t, o_stride_h, o_stride_d,
at::cuda::getCurrentCUDAStream());
return input_grads;