Unverified Commit 84fa28d2 authored by vasunvidia's avatar vasunvidia Committed by GitHub
Browse files

Fused RoPE with combined QKV input. (#2122)



* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca.

Cleanup.

Minor cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernels
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Misc. Cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernel performance
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Move fused_qkv_rope test to test_fused_rope.py
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* apply shared memory optimization to separate fused rope kernels
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 603dbf72
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
from typing import Callable, Tuple, Union from typing import Callable, Tuple, Union, List
import math import math
import torch import torch
import pytest import pytest
from transformer_engine.pytorch.attention.rope import ( from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding, RotaryPositionEmbedding,
apply_rotary_pos_emb, apply_rotary_pos_emb,
apply_fused_qkv_rotary_pos_emb,
) )
# Gradient is a broadcasted scalar # Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(t.sum() * 2 for t in output)
else:
return output.sum() * 2 return output.sum() * 2
# Gradient is a full tensor # Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(torch.sum(t * torch.ones_like(t)) for t in output)
else:
t = torch.ones_like(output) t = torch.ones_like(output)
return torch.sum(output * t) return torch.sum(output * t)
...@@ -238,3 +245,131 @@ def test_fused_rope_thd( ...@@ -238,3 +245,131 @@ def test_fused_rope_thd(
torch.testing.assert_close(grad_fused, grad_unfused) torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous() assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_qkv_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    tensor_format: str,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
    """Compare `apply_fused_qkv_rotary_pos_emb` against a reference built by
    splitting the packed QKV tensor and rotating Q and K separately with
    `apply_rotary_pos_emb`.

    Q/K/V outputs must match for every combination of dtype, layout,
    rotary fraction, context-parallel rank, and interleaving; input
    gradients are compared only when no `start_positions` offsets are used
    (backward is skipped otherwise).
    """
    if margin == 0 and start_positions == True:
        # This makes sure that the `start_positions` offsets being applied
        # are within the maximum length of the rope embeddings.
        pytest.skip("Skipping test with margin=0 and start_positions=True")

    if start_positions == True and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    if seq_length - margin < 0:
        pytest.skip("Skipping test with seq_length - margin < 0")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    # Packed QKV input: the last dim holds Q (4x hidden_size), K and V
    # (hidden_size each) concatenated — hidden_size * (4 + 1 + 1).
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size * 6),
        dtype=dtype,
        device=device,
    )

    # Get arbitrary offsets to be used with RoPE for all the sequences.
    # NOTE: this rebinds the boolean `start_positions` parameter to a
    # per-sequence int32 tensor (or None) for the rest of the test.
    start_positions = (
        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )

    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    t.requires_grad = True

    # Frequencies cover the full (cp_size * seq_length) sequence; each CP rank
    # consumes its own slice inside the kernels.
    rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb_q = rotary_pos_emb_q(seq_length * cp_size)
    rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb_k = rotary_pos_emb_k(seq_length * cp_size)

    for cp_rank in range(cp_size):
        # Reference path: split QKV and rotate Q/K with the separate-tensor
        # kernels. NOTE(review): this was labelled "unfused" but `fused=True`
        # is passed below — the reference is the separate fused kernel, not a
        # pure-PyTorch path; confirm the float32 remark against that kernel.
        # The fused kernel computes in float32 internally, so we force the unfused func to use float32
        # for more accurate comparison
        t_clone = t.clone()
        (query, key, value) = torch.split(
            t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
        )
        # Fold the 4x-wide Q slice into 4x as many heads of size hidden_size.
        query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
        query_unfused = apply_rotary_pos_emb(
            query,
            emb_q,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)

        key_unfused = apply_rotary_pos_emb(
            key,
            emb_k,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)

        # V passes through RoPE unchanged.
        value_unfused = value
        loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])

        # Gradients are only checked in the no-offset mode.
        if not isinstance(start_positions, torch.Tensor):
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()
            t.grad = None

        # Fused path: single kernel consumes the packed QKV tensor directly.
        query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
            t,
            emb_q,
            emb_k,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            cp_size=cp_size,
            cp_rank=cp_rank,
            qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
        )
        loss_fused = loss_func([query_fused, key_fused, value_fused])

        if not isinstance(start_positions, torch.Tensor):
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
            t.grad = None

        torch.testing.assert_close(query_fused, query_unfused)
        torch.testing.assert_close(key_fused, key_unfused)
        torch.testing.assert_close(value_fused, value_unfused)
        if not isinstance(start_positions, torch.Tensor):
            torch.testing.assert_close(grad_fused, grad_unfused)
...@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu ...@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
const int stride_b, const int stride_h, const int stride_d, const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream); cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
 *
 *  Q and K slices of the packed input are rotated with their respective freqs
 *  tensors; the V slice is passed through unchanged (see the PyTorch wrapper's
 *  contract).
 *
 *  \param[in]     qkv_input            Combined QKV input tensor for fused rope.
 *  \param[in]     q_freqs              The freqs tensor for Q (float32).
 *  \param[in]     k_freqs              The freqs tensor for K (float32).
 *  \param[in]     start_positions      The beginning offsets for applying RoPE embeddings.
 *                                      May be an empty tensor when no offsets are used.
 *  \param[out]    q_out                Output tensor for Q.
 *  \param[out]    k_out                Output tensor for K.
 *  \param[out]    v_out                Output tensor for V.
 *  \param[in]     qkv_format           QKV format (SBHD or BSHD).
 *  \param[in]     interleaved          Whether to use interleaved rotary position embedding.
 *  \param[in]     cp_size              Context parallel world size.
 *  \param[in]     cp_rank              Context parallel rank.
 *  \param[in]     s                    Length of the s dimension of input.
 *  \param[in]     b                    Length of the b dimension of input.
 *  \param[in]     h                    Length of the h dimension of input.
 *  \param[in]     d                    Length of the d dimension of input.
 *  \param[in]     d2                   Length of the d dimension of freqs.
 *  \param[in]     qkv_split_arg_list_0 Size of the Q slice along the last dimension.
 *  \param[in]     qkv_split_arg_list_1 Size of the K slice along the last dimension.
 *  \param[in]     qkv_split_arg_list_2 Size of the V slice along the last dimension.
 *  \param[in]     stream               CUDA stream used for the operation.
 */
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
                                 const NVTETensor k_freqs, const NVTETensor start_positions,
                                 NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
                                 const int cp_size, const int cp_rank, const int s, const int b,
                                 const int h, const int d, const int d2,
                                 const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
                                 const int qkv_split_arg_list_2, cudaStream_t stream);

/*! \brief Compute the backward of the fused qkv rope.
 *
 *  Combines the three incoming Q/K/V gradients into a single gradient tensor
 *  matching the packed QKV input of the forward pass.
 *
 *  \param[in]     q_grad_out           Incoming gradient tensor for Q.
 *  \param[in]     k_grad_out           Incoming gradient tensor for K.
 *  \param[in]     v_grad_out           Incoming gradient tensor for V.
 *  \param[in]     q_freqs              The freqs tensor for Q (float32).
 *  \param[in]     k_freqs              The freqs tensor for K (float32).
 *  \param[out]    qkv_grad_input       Input gradient tensor to calculate.
 *  \param[in]     qkv_format           QKV format (SBHD or BSHD).
 *  \param[in]     interleaved          Whether to use interleaved rotary position embedding.
 *  \param[in]     cp_size              Context parallel world size.
 *  \param[in]     cp_rank              Context parallel rank.
 *  \param[in]     s                    Length of the s dimension of input.
 *  \param[in]     b                    Length of the b dimension of input.
 *  \param[in]     h                    Length of the h dimension of input.
 *  \param[in]     d                    Length of the d dimension of input.
 *  \param[in]     d2                   Length of the d dimension of freqs.
 *  \param[in]     qkv_split_arg_list_0 Size of the Q slice along the last dimension.
 *  \param[in]     qkv_split_arg_list_1 Size of the K slice along the last dimension.
 *  \param[in]     qkv_split_arg_list_2 Size of the V slice along the last dimension.
 *  \param[in]     stream               CUDA stream used for the operation.
 */
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
                                  const NVTETensor v_grad_out, const NVTETensor q_freqs,
                                  const NVTETensor k_freqs, NVTETensor qkv_grad_input,
                                  const NVTE_QKV_Format qkv_format, const bool interleaved,
                                  const int cp_size, const int cp_rank, const int s, const int b,
                                  const int h, const int d, const int d2,
                                  const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
                                  const int qkv_split_arg_list_2, cudaStream_t stream);
#ifdef __cplusplus #ifdef __cplusplus
} // extern "C" } // extern "C"
#endif #endif
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
""" """
Rotary Position Embedding implementation of different types along with helper functions Rotary Position Embedding implementation of different types along with helper functions
""" """
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union, List
import torch import torch
import transformer_engine_torch as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] __all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module): class RotaryPositionEmbedding(torch.nn.Module):
...@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function): ...@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function):
return grad_input, None, None, None, None, None, None, None return grad_input, None, None, None, None, None, None, None
class FusedQKVRoPEFunc(torch.autograd.Function):
    """
    Function for FusedQKVRoPE.

    This implementation accepts a combined QKV tensor in `bshd` or `sbhd`
    format. Q and K RoPE tensors are the additional required inputs; they
    should be of shape (s, 1, 1, d). It produces 3 outputs: Q and K after
    RoPE, and V which is the same as the input's V slice.
    """

    @staticmethod
    def forward(
        ctx,
        qkv: torch.Tensor,
        q_freqs: torch.Tensor,
        k_freqs: torch.Tensor,
        qkv_split_arg_list: List[int],
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fused RoPE forward. Returns the (q, k, v) tensors."""
        # The CUDA extension requires float32 frequencies.
        if q_freqs.dtype != torch.float32:
            q_freqs = q_freqs.float()
        if k_freqs.dtype != torch.float32:
            k_freqs = k_freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
        ), f"Unsupported tensor_format: {tensor_format}."
        # The kernel indexes raw memory, so all inputs must be contiguous.
        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
        output = tex.fused_qkv_rope_forward(
            qkv,
            q_freqs,
            k_freqs,
            start_positions,
            qkv_split_arg_list,
            QKVFormat[tensor_format],
            interleaved,
            cp_size,
            cp_rank,
        )
        # Only the freqs are needed to rotate the incoming gradients back.
        ctx.save_for_backward(q_freqs, k_freqs)
        ctx.tensor_format = tensor_format
        ctx.qkv_split_arg_list = qkv_split_arg_list
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved

        return output

    @staticmethod
    def backward(
        ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused RoPE backward. Packs the three output gradients into one QKV gradient."""
        q_freqs, k_freqs = ctx.saved_tensors
        grad_output_q = grad_output_q.contiguous()
        grad_output_k = grad_output_k.contiguous()
        grad_output_v = grad_output_v.contiguous()
        grad_input = tex.fused_qkv_rope_backward(
            grad_output_q,
            grad_output_k,
            grad_output_v,
            q_freqs,
            k_freqs,
            ctx.qkv_split_arg_list,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            ctx.cp_size,
            ctx.cp_rank,
        )
        # One gradient slot per forward input: grad for `qkv`, then None for
        # each of the 8 non-differentiable arguments.
        return grad_input, None, None, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even] """Change sign so the last dimension becomes [-odd, +even]
...@@ -393,3 +473,82 @@ def apply_rotary_pos_emb( ...@@ -393,3 +473,82 @@ def apply_rotary_pos_emb(
tensor_format, tensor_format,
interleaved=interleaved, interleaved=interleaved,
) )
def apply_fused_qkv_rotary_pos_emb(
    qkv: torch.Tensor,
    q_freqs: torch.Tensor,
    k_freqs: torch.Tensor,
    qkv_split_arg_list: List[int],
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,  # pylint: disable=unused-argument
    cp_size: int = 1,
    cp_rank: int = 0,
) -> torch.Tensor:
    """
    Rotate the Q and K slices of a packed QKV tensor with a single fused kernel.

    Support matrix:
    Fused:
        Training:
            qkv_formats: "bshd", "sbhd"
            context parallel: yes
            start_positions: no
            interleaving: yes
        Inference:
            qkv_formats: "bshd", "sbhd"
            context parallelism: no
            start_positions: yes
            interleaving: yes

    Parameters
    ----------
    qkv: torch.Tensor
        Tensor of shape `[s, b, h, d]` or `[b, s, h, d]` whose last dimension
        packs q, k, and v; rotary positional embedding is applied to the
        q and k slices, while v passes through unchanged.
    q_freqs: torch.Tensor
        Rotary embedding tensor for Q of shape `[s2, 1, 1, d2]` and dtype
        'float', with `s2 >= s` and `d2 <= d`.
    k_freqs: torch.Tensor
        Rotary embedding tensor for K of shape `[s2, 1, 1, d2]` and dtype
        'float', with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list: List[int]
        Three integers giving the sizes of the q, k, and v slices along the
        last dimension of `qkv`; their sum must equal that dimension.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        `bshd` when `qkv` is laid out `[bs, seq, ...]`, `sbhd` when it is
        `[seq, bs, ...]`.
    start_positions: torch.Tensor, default = None.
        Per-sequence position offsets: tokens of sequence `i` are encoded
        starting at position `start_positions[i]`. `None` means no offset.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    cu_seqlens: torch.Tensor, default = None.
        Unused; accepted for signature compatibility.
    cp_size: int, default = 1.
        Context parallel world size.
    cp_rank: int, default = 0.
        Context parallel rank.
    """
    # Offsets are an inference-only feature; they cannot be combined with
    # context parallelism.
    assert (
        cp_size == 1 or start_positions is None
    ), """start_positions != None with CP SIZE > 1 is not supported!"""
    # Packed/ragged layouts are not implemented for the combined-QKV kernel.
    assert tensor_format != "thd", "'thd' tensor_format not supported currently."

    return FusedQKVRoPEFunc.apply(
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list,
        start_positions,
        tensor_format,
        interleaved,
        cp_size,
        cp_rank,
    )
...@@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor ...@@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
const std::optional<at::Tensor> cu_seqlens, const int cp_size, const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank); const int cp_rank);
// Apply RoPE to the Q and K slices of a packed QKV tensor. Returns the
// (q, k, v) output tensors; per the Python-side contract, V is passed
// through without rotation.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank);

// Backward of fused_qkv_rope_forward: packs the three incoming Q/K/V
// gradients into a single gradient tensor shaped like the packed QKV input.
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
                                   const at::Tensor &k_freqs,
                                   const std::vector<int> &qkv_split_arg_list,
                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
                                   const int cp_size, const int cp_rank);
/*************************************************************************************************** /***************************************************************************************************
* Miscellaneous * Miscellaneous
**************************************************************************************************/ **************************************************************************************************/
......
...@@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, ...@@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
return output; return output;
} }
/*! Apply RoPE to the Q and K slices of a packed QKV tensor on the current
 *  CUDA stream. Returns newly allocated (q, k, v) outputs. */
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
    const int cp_rank) {
  // Validate all inputs up front. The size arithmetic below indexes
  // qkv_input.sizes()[2]/[3] and qkv_split_arg_list[0..2], so these checks
  // must run before any allocation (previously the input checks ran after
  // the outputs were sized, yielding confusing failures on bad input).
  TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor");
  TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous");
  TORCH_CHECK(qkv_split_arg_list.size() == 3, "expected 3 elements in qkv_split_arg_list");
  TORCH_CHECK(
      qkv_input.size(3) == qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2],
      "expected the last dim of the input to equal the sum of qkv_split_arg_list");
  TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");

  // Allocate outputs. Convention: qkv_split_arg_list holds the Q/K/V slice
  // sizes along the last dim; K and V keep the input head count, while Q's
  // head count scales by split[0] / split[1].
  auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device());
  auto q_out_size = qkv_input.sizes().vec();
  q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1];
  q_out_size[3] = qkv_split_arg_list[1];
  auto q_out = at::empty(q_out_size, act_options);
  auto k_out_size = qkv_input.sizes().vec();
  k_out_size[3] = qkv_split_arg_list[1];
  auto k_out = at::empty(k_out_size, act_options);
  auto v_out_size = qkv_input.sizes().vec();
  v_out_size[3] = qkv_split_arg_list[2];
  auto v_out = at::empty(v_out_size, act_options);

  auto qkv_cu = makeTransformerEngineTensor(qkv_input);
  auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
  auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
  auto q_out_cu = makeTransformerEngineTensor(q_out);
  auto k_out_cu = makeTransformerEngineTensor(k_out);
  auto v_out_cu = makeTransformerEngineTensor(v_out);
  // An empty tensor signals "no start_positions offsets" to the kernel.
  auto start_positions_cu = TensorWrapper();
  if (start_positions) {
    start_positions_cu = makeTransformerEngineTensor(start_positions.value());
  }

  const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
  const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1);
  const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0);
  const int h = qkv_input.size(2);
  const int d = qkv_split_arg_list[2];  // head dim, taken from the V split
  const int d2 = q_freqs.size(3);      // rotary dim, from the freqs tensor

  nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(),
                              start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(),
                              v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h,
                              d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1],
                              qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream());

  return std::make_tuple(q_out, k_out, v_out);
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved, const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size, const std::optional<at::Tensor> cu_seqlens, const int cp_size,
...@@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor ...@@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
return input_grads; return input_grads;
} }
/*! Backward of fused_qkv_rope_forward: combines the three Q/K/V gradients
 *  into one gradient tensor shaped like the packed QKV input. */
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
                                   const at::Tensor &k_freqs,
                                   const std::vector<int> &qkv_split_arg_list,
                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
                                   const int cp_size, const int cp_rank) {
  // Validate inputs (mirrors fused_qkv_rope_forward); previously malformed
  // gradients reached the size arithmetic and kernel launch unchecked.
  TORCH_CHECK(q_grad_out.dim() == 4 && k_grad_out.dim() == 4 && v_grad_out.dim() == 4,
              "expected 4D gradient tensors");
  TORCH_CHECK(qkv_split_arg_list.size() == 3, "expected 3 elements in qkv_split_arg_list");
  TORCH_CHECK(
      q_freqs.scalar_type() == at::ScalarType::Float &&
          k_freqs.scalar_type() == at::ScalarType::Float,
      "Dtype of the freqs tensor must be float");

  auto act_options =
      at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device());
  // Reassemble the packed QKV gradient shape: total (heads * head_dim)
  // elements across the three gradients, divided by the packed last-dim
  // size, gives back the original head count.
  auto qkv_grad_size = q_grad_out.sizes().vec();
  auto total_hd =
      (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3);
  auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2];
  qkv_grad_size[2] = total_hd / total_d;
  qkv_grad_size[3] = total_d;
  auto qkv_grad_input = at::empty(qkv_grad_size, act_options);

  const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
  const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1);
  const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0);
  const int h = qkv_grad_input.size(2);
  const int d = qkv_split_arg_list[2];  // head dim, taken from the V split
  const int d2 = q_freqs.size(3);      // rotary dim, from the freqs tensor

  auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out);
  auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out);
  auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out);
  auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
  auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
  auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input);

  nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(),
                               q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format,
                               interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0],
                               qkv_split_arg_list[1], qkv_split_arg_list[2],
                               at::cuda::getCurrentCUDAStream());

  return qkv_grad_input;
}
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
...@@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>()); "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>()); "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
"Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
"Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
// fused router // fused router
m.def("fused_topk_with_score_function_fwd", m.def("fused_topk_with_score_function_fwd",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment