Unverified Commit e4bfa628 authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

[Feature] Enable rope application with offsets for training (#2188)



* enable applying rope offsets in backwared
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add tests for rope offsets for thd/bshd/sbhd formats
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* minor fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent f8693d2b
...@@ -58,10 +58,6 @@ def test_fused_rope( ...@@ -58,10 +58,6 @@ def test_fused_rope(
# are with the maximum length of the rope embeddings. # are with the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True") pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
t = torch.rand( t = torch.rand(
...@@ -102,11 +98,8 @@ def test_fused_rope( ...@@ -102,11 +98,8 @@ def test_fused_rope(
cp_rank=cp_rank, cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
if not isinstance(start_positions, torch.Tensor): grad_unfused = t.grad.detach().clone()
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None t.grad = None
# fused # fused
...@@ -121,17 +114,12 @@ def test_fused_rope( ...@@ -121,17 +114,12 @@ def test_fused_rope(
cp_rank=cp_rank, cp_rank=cp_rank,
) )
loss_fused = loss_func(output_fused) loss_fused = loss_func(output_fused)
loss_fused.backward()
if not isinstance(start_positions, torch.Tensor): grad_fused = t.grad.detach().clone()
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None t.grad = None
torch.testing.assert_close(output_fused, output_unfused) torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous() assert output_fused.is_contiguous()
...@@ -156,10 +144,6 @@ def test_fused_rope_thd( ...@@ -156,10 +144,6 @@ def test_fused_rope_thd(
margin: int, margin: int,
) -> None: ) -> None:
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
device = torch.device("cuda:0") device = torch.device("cuda:0")
batch_size, head_num = 2, 64 batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
...@@ -214,10 +198,8 @@ def test_fused_rope_thd( ...@@ -214,10 +198,8 @@ def test_fused_rope_thd(
cp_rank=cp_rank, cp_rank=cp_rank,
).to(dtype) ).to(dtype)
loss_unfused = loss_func(output_unfused) loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
if not isinstance(start_positions, torch.Tensor): grad_unfused = t.grad.detach().clone()
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None t.grad = None
# fused # fused
...@@ -233,18 +215,142 @@ def test_fused_rope_thd( ...@@ -233,18 +215,142 @@ def test_fused_rope_thd(
cp_rank=cp_rank, cp_rank=cp_rank,
) )
loss_fused = loss_func(output_fused) loss_fused = loss_func(output_fused)
loss_fused.backward()
if not isinstance(start_positions, torch.Tensor): grad_fused = t.grad.detach().clone()
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None t.grad = None
torch.testing.assert_close(output_fused, output_unfused) torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous() @pytest.mark.parametrize("start_positions", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad])
@pytest.mark.parametrize("cp_size", [2])
@pytest.mark.parametrize("interleaved", [False, True])
def test_unfused_rope_thd_vs_bshd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
"""
This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD
formats are the same.
"""
device = torch.device("cuda:0")
seqlen, max_seqlen = 16, 2048
batch_size, head_num = 4, 256
# NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and
# that causes unexpected issues.
seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32)
cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to(
device=device, dtype=torch.int32
)
# Create a tensor in THD format
thd = torch.rand(
(cu_seqlens[-1] // cp_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
thd.requires_grad = True
# Clone the tensor to create a tensor in BSHD format
bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach()
bshd = bshd.to(dtype=dtype, device=device)
bshd.requires_grad = True
# Clone the tensor to create a tensor in SBHD format
sbhd = bshd.transpose(1, 0).clone().detach()
sbhd = sbhd.to(dtype=dtype, device=device)
sbhd.requires_grad = True
rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(max_seqlen)
assert emb.is_contiguous()
start_positions = cu_seqlens[:-1] if start_positions else None
for cp_rank in range(cp_size):
# unfused bshd
output_unfused_bshd = apply_rotary_pos_emb(
bshd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="bshd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_bshd = loss_func(output_unfused_bshd)
loss_unfused_bshd.backward()
grad_unfused_bshd = bshd.grad.detach().clone()
bshd.grad = None
# unfused sbhd
output_unfused_sbhd = apply_rotary_pos_emb(
sbhd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="sbhd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_sbhd = loss_func(output_unfused_sbhd)
loss_unfused_sbhd.backward()
grad_unfused_sbhd = sbhd.grad.detach().clone()
sbhd.grad = None
# unfused thd
output_unfused_thd = apply_rotary_pos_emb(
thd.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_thd = loss_func(output_unfused_thd)
loss_unfused_thd.backward()
grad_unfused_thd = thd.grad.detach().clone()
thd.grad = None
torch.testing.assert_close(
output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd
)
torch.testing.assert_close(
output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape),
output_unfused_thd,
)
torch.testing.assert_close(
grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd
)
torch.testing.assert_close(
grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd
)
assert output_unfused_thd.is_contiguous()
assert output_unfused_bshd.is_contiguous()
assert output_unfused_sbhd.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False]) @pytest.mark.parametrize("start_positions", [True, False])
......
...@@ -155,18 +155,18 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq ...@@ -155,18 +155,18 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
cur_seqlens = s; cur_seqlens = s;
} }
int s_id_for_freqs; // Offset the RoPE embedding by start_positions if provided.
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
int s_id_for_freqs = s_id + begin_offset;
// If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order.
if (cp_size > 1) { if (cp_size > 1) {
assert(cur_seqlens % 2 == 0); assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) { if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; s_id_for_freqs += cp_rank * cur_seqlens / 2;
} else { } else {
s_id_for_freqs = s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
} }
} else {
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
} }
fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
...@@ -175,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq ...@@ -175,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
template <typename scalar_t> template <typename scalar_t>
__global__ void fused_rope_backward_kernel( __global__ void fused_rope_backward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b,
const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_d) { const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y; int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst; int offset_block, offset_block_dst;
int cur_seqlens; int cur_seqlens;
...@@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel( ...@@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel(
cur_seqlens = s; cur_seqlens = s;
} }
int s_id_for_freqs; // Offset the RoPE embedding by start_positions if provided.
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
int s_id_for_freqs = s_id + begin_offset;
// If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order.
if (cp_size > 1) { if (cp_size > 1) {
assert(cur_seqlens % 2 == 0); assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) { if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; s_id_for_freqs += cp_rank * cur_seqlens / 2;
} else { } else {
s_id_for_freqs = s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
} }
} else {
s_id_for_freqs = s_id;
} }
fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
...@@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c ...@@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
template <typename scalar_t> template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
const float *freqs, scalar_t *input_grads, const float *freqs, const int *start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved, scalar_t *input_grads, const NVTE_QKV_Format qkv_format,
const int cp_size, const int cp_rank, const int s, const int b, const bool interleaved, const int cp_size, const int cp_rank,
const int h, const int d, const int d2, const int stride_s_or_t, const int s, const int b, const int h, const int d, const int d2,
const int stride_b, const int stride_h, const int stride_d, const int stride_s_or_t, const int stride_b, const int stride_h,
cudaStream_t stream) { const int stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8; int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b); dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block); dim3 threads(THREADS_PER_WARP, warps_per_block);
...@@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se ...@@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
const int o_stride_d = 1; const int o_stride_d = 1;
fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>( fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b,
o_stride_d); o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten ...@@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten
} }
void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *input_grads, const NVTE_QKV_Format qkv_format, const Tensor &start_positions, Tensor *input_grads,
const bool interleaved, const int cp_size, const int cp_rank, const int s, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int b, const int h, const int d, const int d2, const int cp_size, const int cp_rank, const int s, const int b,
const int stride_s_or_t, const int stride_b, const int stride_h, const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_d, cudaStream_t stream) { const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t, output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr), fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr), reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr), reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format, reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);); stride_b, stride_h, stride_d, stream););
...@@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens ...@@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
} }
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const NVTETensor freqs, const NVTETensor start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved, NVTETensor input_grads, const NVTE_QKV_Format qkv_format,
const int cp_size, const int cp_rank, const int s, const int b, const bool interleaved, const int cp_size, const int cp_rank,
const int h, const int d, const int d2, const int stride_s_or_t, const int s, const int b, const int h, const int d, const int d2,
const int stride_b, const int stride_h, const int stride_d, const int stride_s_or_t, const int stride_b, const int stride_h,
cudaStream_t stream) { const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward); NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine; using namespace transformer_engine;
fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens), fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
*convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads), *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size,
stride_b, stride_h, stride_d, stream); cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
} }
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
......
...@@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens ...@@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats) * (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor. * \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] input_grads Input gradient tensor to calculate. * \param[out] input_grads Input gradient tensor to calculate.
* \param[in] qkv_format QKV format. * \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding. * \param[in] interleaved Whether to use interleaved rotary position embedding.
...@@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens ...@@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
* \param[in] stream CUDA stream used for the operation. * \param[in] stream CUDA stream used for the operation.
*/ */
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads, const NVTETensor freqs, const NVTETensor start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved, NVTETensor input_grads, const NVTE_QKV_Format qkv_format,
const int cp_size, const int cp_rank, const int s, const int b, const bool interleaved, const int cp_size, const int cp_rank,
const int h, const int d, const int d2, const int stride_s_or_t, const int s, const int b, const int h, const int d, const int d2,
const int stride_b, const int stride_h, const int stride_d, const int stride_s_or_t, const int stride_b, const int stride_h,
cudaStream_t stream); const int stride_d, cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the combined QKV input tensor. /*! \brief Apply rotary positional embedding to the combined QKV input tensor.
* *
......
...@@ -149,7 +149,7 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -149,7 +149,7 @@ class FusedRoPEFunc(torch.autograd.Function):
cp_size, cp_size,
cp_rank, cp_rank,
) )
ctx.save_for_backward(freqs, cu_seqlens) ctx.save_for_backward(freqs, cu_seqlens, start_positions)
ctx.tensor_format = tensor_format ctx.tensor_format = tensor_format
ctx.cp_size = cp_size ctx.cp_size = cp_size
ctx.cp_rank = cp_rank ctx.cp_rank = cp_rank
...@@ -160,10 +160,11 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -160,10 +160,11 @@ class FusedRoPEFunc(torch.autograd.Function):
@staticmethod @staticmethod
def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused RoPE backward.""" """Fused RoPE backward."""
freqs, cu_seqlens = ctx.saved_tensors freqs, cu_seqlens, start_positions = ctx.saved_tensors
grad_input = tex.fused_rope_backward( grad_input = tex.fused_rope_backward(
grad_output, grad_output,
freqs, freqs,
start_positions,
QKVFormat[ctx.tensor_format], QKVFormat[ctx.tensor_format],
ctx.interleaved, ctx.interleaved,
cu_seqlens, cu_seqlens,
...@@ -171,7 +172,7 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -171,7 +172,7 @@ class FusedRoPEFunc(torch.autograd.Function):
ctx.cp_rank, ctx.cp_rank,
) )
return grad_input, None, None, None, None, None, None, None return grad_input, None, None, None, None, None, None, None, None
class FusedQKVRoPEFunc(torch.autograd.Function): class FusedQKVRoPEFunc(torch.autograd.Function):
...@@ -278,7 +279,6 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: ...@@ -278,7 +279,6 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
def _apply_rotary_pos_emb_base( def _apply_rotary_pos_emb_base(
t: torch.Tensor, t: torch.Tensor,
freqs: torch.Tensor, freqs: torch.Tensor,
start_positions: torch.Tensor = None,
tensor_format: str = "sbhd", tensor_format: str = "sbhd",
interleaved: bool = False, interleaved: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
...@@ -291,45 +291,19 @@ def _apply_rotary_pos_emb_base( ...@@ -291,45 +291,19 @@ def _apply_rotary_pos_emb_base(
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional
embedding will be applied. embedding will be applied.
freqs: torch.Tensor freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]`
with `s2 >= s` and `d2 <= d`. and dtype 'float', with `s2 >= s` and `d2 <= d`.
start_positions: torch.Tensor, default = None.
Tokens in a sequence `i` should be applied with position encoding offset by
`start_positions[i]`. If `start_positions=None`, there's no offset.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape
`[seq, bs, ...]`. `[seq, bs, ...]`.
interleaved: bool, default = False interleaved: bool, default = False
Whether to use interleaved rotary position embedding. Whether to use interleaved rotary position embedding.
""" """
max_seq_len = freqs.shape[0]
cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0]
# In case `start_positions` are provided, create a staggered `freqs` tensor
# offset by the values in `start_positions`.
# `start_positions` is only supported for `cp_size=1` and inference.
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + cur_seq_len <= max_seq_len
), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!"
# Stack staggered rope embeddings along the batch dimension
freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1)
# Note that from this point, `freqs` has a shape `(s,b,1,d)`.
# Only apply the rotary embeddings up to the sequence length of the running
# input.
assert (
cur_seq_len <= max_seq_len
), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!"
freqs = freqs[:cur_seq_len]
# [seq, 1, 1, dim] -> [1, seq, 1, dim] or # [seq, 1, 1, dim] -> [1, seq, 1, dim] or
# [seq, b, 1, dim] -> [b, seq, 1, dim] # [seq, b, 1, dim] -> [b, seq, 1, dim]
if tensor_format == "bshd": if tensor_format == "bshd":
freqs = freqs.transpose(0, 1) freqs = freqs.transpose(0, 1)
# cos/sin first then dtype conversion for better precision # cos/sin first then dtype conversion for better precision
cos_ = torch.cos(freqs).to(t.dtype) cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype)
...@@ -366,7 +340,7 @@ def _get_freqs_on_this_cp_rank( ...@@ -366,7 +340,7 @@ def _get_freqs_on_this_cp_rank(
) )
# cp_size == 1 # cp_size == 1
return freqs return freqs[:seqlen]
def apply_rotary_pos_emb( def apply_rotary_pos_emb(
...@@ -388,13 +362,13 @@ def apply_rotary_pos_emb( ...@@ -388,13 +362,13 @@ def apply_rotary_pos_emb(
Training: Training:
qkv_formats: "thd", "bshd", "sbhd" qkv_formats: "thd", "bshd", "sbhd"
context parallel: yes context parallel: yes
start_positions: no start_positions: yes
interleaving: yes interleaving: yes
Inference: Inference:
qkv_formats: "thd", "bshd", "sbhd" qkv_formats: "thd", "bshd", "sbhd"
context parallelism: no context parallelism: no
start_positions: yes start_positions: yes
interleaving: yes interleaving: yes
Parameters Parameters
---------- ----------
...@@ -423,22 +397,17 @@ def apply_rotary_pos_emb( ...@@ -423,22 +397,17 @@ def apply_rotary_pos_emb(
cp_rank: int, default = 0. cp_rank: int, default = 0.
Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True.
""" """
# `start_positions` is only supported for `cp_size=1` and inference.
assert not (
cp_size > 1 and start_positions is not None
), """start_positions != None with CP SIZE > 1 is not supported!"""
assert ( assert (
tensor_format != "thd" or cu_seqlens is not None tensor_format != "thd" or cu_seqlens is not None
), "cu_seqlens must not be None when tensor_format is 'thd'." ), "cu_seqlens must not be None when tensor_format is 'thd'."
# Fused apply rope logic for THD/BSHD/SBHD formats
if fused: if fused:
return FusedRoPEFunc.apply( return FusedRoPEFunc.apply(
t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank
) )
# Unfused THD format # Unfused apply rope logic for THD format
if tensor_format == "thd": if tensor_format == "thd":
cu_seqlens = cu_seqlens // cp_size cu_seqlens = cu_seqlens // cp_size
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
...@@ -447,15 +416,18 @@ def apply_rotary_pos_emb( ...@@ -447,15 +416,18 @@ def apply_rotary_pos_emb(
# `s1hd` tensors (for each sequence) and applies rotary embedding to # `s1hd` tensors (for each sequence) and applies rotary embedding to
# those sequences individually. # those sequences individually.
# Note that if `start_positions` is not `None`, then for each sequence, # Note that if `start_positions` is not `None`, then for each sequence,
# it's corresponding rope offset is also supplied from `start_positions` # the freqs supplied are offset by the corresponding `start_positions` value.
# individually.
return torch.cat( return torch.cat(
[ [
_apply_rotary_pos_emb_base( _apply_rotary_pos_emb_base(
x.unsqueeze(1), x.unsqueeze(1),
_get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), _get_freqs_on_this_cp_rank(
start_positions=( (
start_positions[idx : idx + 1] if start_positions is not None else None freqs[start_positions[idx] :] if start_positions is not None else freqs
), # offset the freqs
x.size(0),
cp_size,
cp_rank,
), ),
interleaved=interleaved, interleaved=interleaved,
) )
...@@ -463,17 +435,28 @@ def apply_rotary_pos_emb( ...@@ -463,17 +435,28 @@ def apply_rotary_pos_emb(
] ]
).squeeze(1) ).squeeze(1)
# Unfused SBHD/BSHD format # Unfused apply rope logic for SBHD/BSHD format follows ...
if tensor_format == "sbhd": if tensor_format == "sbhd":
seqlen = t.size(0) seqlen = t.size(0)
elif tensor_format == "bshd": elif tensor_format == "bshd":
seqlen = t.size(1) seqlen = t.size(1)
else: else:
raise ValueError(f"Unsupported tensor_format: {tensor_format}.") raise ValueError(f"Unsupported tensor_format: {tensor_format}.")
if start_positions is not None:
max_offset = torch.max(start_positions)
assert (
max_offset + seqlen * cp_size <= freqs.shape[0]
), f"Rotary Embeddings only suppported up to {freqs.shape[0]} sequence length!"
# Stack staggered rope embeddings along the batch dimension
freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1)
# Note that from this point, `freqs` has a shape `(s,b,1,d)`.
return _apply_rotary_pos_emb_base( return _apply_rotary_pos_emb_base(
t, t,
_get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank),
start_positions,
tensor_format, tensor_format,
interleaved=interleaved, interleaved=interleaved,
) )
...@@ -505,7 +488,7 @@ def apply_fused_qkv_rotary_pos_emb( ...@@ -505,7 +488,7 @@ def apply_fused_qkv_rotary_pos_emb(
qkv_formats: "bshd", "sbhd" qkv_formats: "bshd", "sbhd"
context parallelism: no context parallelism: no
start_positions: yes start_positions: yes
interleaving: yes interleaving: yes
Parameters Parameters
---------- ----------
......
...@@ -346,6 +346,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, ...@@ -346,6 +346,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
const int cp_rank); const int cp_rank);
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const std::optional<at::Tensor> start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved, const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size, const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank); const int cp_rank);
......
...@@ -163,6 +163,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward( ...@@ -163,6 +163,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
} }
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const std::optional<at::Tensor> start_positions,
const NVTE_QKV_Format qkv_format, const bool interleaved, const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size, const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank) { const int cp_rank) {
...@@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor ...@@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
auto freqs_cu = makeTransformerEngineTensor(freqs); auto freqs_cu = makeTransformerEngineTensor(freqs);
auto input_grads_cu = makeTransformerEngineTensor(input_grads); auto input_grads_cu = makeTransformerEngineTensor(input_grads);
auto start_positions_cu = TensorWrapper(); // empty start_positions tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor");
}
if (qkv_format == NVTE_QKV_Format::NVTE_THD) { if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor");
TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor");
...@@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor ...@@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value());
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, start_positions_cu.data(), input_grads_cu.data(), qkv_format,
max_s, b, h, d, d2, stride_t, interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t,
/*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return input_grads; return input_grads;
...@@ -246,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor ...@@ -246,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor
nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(),
input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, start_positions_cu.data(), input_grads_cu.data(), qkv_format,
h, d, d2, stride_s, stride_b, stride_h, stride_d, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b,
at::cuda::getCurrentCUDAStream()); stride_h, stride_d, at::cuda::getCurrentCUDAStream());
return input_grads; return input_grads;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment