Unverified commit 84fa28d2, authored by vasunvidia, committed by GitHub.
Browse files

Fused RoPE with combined QKV input. (#2122)



* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca.

Cleanup.

Minor cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernels
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Misc. Cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernel performance
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Move fused_qkv_rope test to test_fused_rope.py
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* apply shared memory optimization to separate fused rope kernels
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 603dbf72
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union
from typing import Callable, Tuple, Union, List
import math
import torch
import pytest
from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
apply_fused_qkv_rotary_pos_emb,
)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(t.sum() * 2 for t in output)
else:
return output.sum() * 2
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(torch.sum(t * torch.ones_like(t)) for t in output)
else:
t = torch.ones_like(output)
return torch.sum(output * t)
......@@ -238,3 +245,131 @@ def test_fused_rope_thd(
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_qkv_rope(
    dtype: torch.dtype,
    seq_length: int,
    hidden_size: int,
    rotary_percent: float,
    margin: int,
    tensor_format: str,
    loss_func: Callable,
    cp_size: int,
    interleaved: bool,
    start_positions: bool,
) -> None:
    """Check `apply_fused_qkv_rotary_pos_emb` (combined-QKV kernel) against a
    reference built from two per-tensor `apply_rotary_pos_emb` calls on the
    Q and K slices, comparing forward outputs and (when applicable) the input
    gradient. Requires a CUDA device.
    """
    if margin == 0 and start_positions == True:
        # This makes sure that the `start_positions` offsets being applied
        # are with the maximum length of the rope embeddings.
        pytest.skip("Skipping test with margin=0 and start_positions=True")

    if start_positions == True and cp_size > 1:
        # `start_positions` is only supported for `cp_size=1` and inference.
        pytest.skip("Skipping test with cp_size>1 and start_positions=True")

    if seq_length - margin < 0:
        pytest.skip("Skipping test with seq_length - margin < 0")

    device = torch.device("cuda:0")
    batch_size, head_num = 2, 64
    # Last dim packs Q/K/V per head: 4*hidden_size for Q (4x heads after the
    # reshape below), hidden_size each for K and V -> 6*hidden_size total.
    t = torch.rand(
        (seq_length - margin, batch_size, head_num, hidden_size * 6),
        dtype=dtype,
        device=device,
    )

    # Get arbitrary offsets to be used with RoPE for all the sequences.
    # NOTE: this rebinds `start_positions` from the bool parametrization flag
    # to an actual int32 offsets tensor (or None).
    start_positions = (
        torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
        if start_positions
        else None
    )

    if tensor_format == "bshd":
        t = t.transpose(0, 1).contiguous()
    t.requires_grad = True

    # Frequencies cover the full (seq_length * cp_size) range so each CP rank
    # reads its own slice.
    rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb_q = rotary_pos_emb_q(seq_length * cp_size)
    rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
    emb_k = rotary_pos_emb_k(seq_length * cp_size)

    for cp_rank in range(cp_size):
        # Reference path: split the combined tensor into Q/K/V and apply RoPE
        # to Q and K separately via the per-tensor fused kernel (fused=True),
        # casting the result back to the input dtype for comparison.
        t_clone = t.clone()
        (query, key, value) = torch.split(
            t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
        )
        # Q's packed 4*hidden_size slice is viewed as 4x as many heads of size
        # hidden_size.
        query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
        query_unfused = apply_rotary_pos_emb(
            query,
            emb_q,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        key_unfused = apply_rotary_pos_emb(
            key,
            emb_k,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            fused=True,
            cp_size=cp_size,
            cp_rank=cp_rank,
        ).to(dtype)
        # V passes through RoPE unchanged.
        value_unfused = value

        loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])

        # `start_positions` is an inference-only feature, so gradients are only
        # checked when no offsets tensor is in use.
        if not isinstance(start_positions, torch.Tensor):
            loss_unfused.backward()
            grad_unfused = t.grad.detach().clone()
            t.grad = None

        # Fused path: single kernel on the combined QKV tensor.
        query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
            t,
            emb_q,
            emb_k,
            tensor_format=tensor_format,
            start_positions=start_positions,
            interleaved=interleaved,
            cp_size=cp_size,
            cp_rank=cp_rank,
            qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
        )
        loss_fused = loss_func([query_fused, key_fused, value_fused])
        if not isinstance(start_positions, torch.Tensor):
            loss_fused.backward()
            grad_fused = t.grad.detach().clone()
            t.grad = None

        torch.testing.assert_close(query_fused, query_unfused)
        torch.testing.assert_close(key_fused, key_unfused)
        torch.testing.assert_close(value_fused, value_unfused)
        if not isinstance(start_positions, torch.Tensor):
            torch.testing.assert_close(grad_fused, grad_unfused)
......@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
 *
 *  Produces three outputs: Q and K with RoPE applied, and V copied through
 *  unchanged from the input.
 *
 * \param[in]     qkv_input              Combined QKV input tensor for fused rope.
 * \param[in]     q_freqs                The freqs tensor for Q (float32).
 * \param[in]     k_freqs                The freqs tensor for K (float32).
 * \param[in]     start_positions        The beginning offsets for applying RoPE embeddings.
 *                                       May be an empty tensor when no offsets are used.
 * \param[out]    q_out                  Output tensor for Q.
 * \param[out]    k_out                  Output tensor for K.
 * \param[out]    v_out                  Output tensor for V.
 * \param[in]     qkv_format             QKV format (SBHD or BSHD).
 * \param[in]     interleaved            Whether to use interleaved rotary position embedding.
 * \param[in]     cp_size                Context parallel world size.
 * \param[in]     cp_rank                Context parallel rank.
 * \param[in]     s                      Length of the s dimension of input.
 * \param[in]     b                      Length of the b dimension of input.
 * \param[in]     h                      Length of the h dimension of input.
 * \param[in]     d                      Length of the d dimension of input.
 * \param[in]     d2                     Length of the d dimension of freqs.
 * \param[in]     qkv_split_arg_list_0   The hidden size for Q.
 * \param[in]     qkv_split_arg_list_1   The hidden size for K.
 * \param[in]     qkv_split_arg_list_2   The hidden size for V.
 * \param[in]     stream                 CUDA stream used for the operation.
 */
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
const NVTETensor k_freqs, const NVTETensor start_positions,
NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
/*! \brief Compute the backward of the fused qkv rope.
 *
 *  Combines the three incoming gradients (Q/K/V) back into the single
 *  gradient of the packed QKV input.
 *
 * \param[in]     q_grad_out             Incoming gradient tensor for Q.
 * \param[in]     k_grad_out             Incoming gradient tensor for K.
 * \param[in]     v_grad_out             Incoming gradient tensor for V.
 * \param[in]     q_freqs                The freqs tensor for Q (float32).
 * \param[in]     k_freqs                The freqs tensor for K (float32).
 * \param[out]    qkv_grad_input         Input gradient tensor to calculate.
 * \param[in]     qkv_format             QKV format (SBHD or BSHD).
 * \param[in]     interleaved            Whether to use interleaved rotary position embedding.
 * \param[in]     cp_size                Context parallel world size.
 * \param[in]     cp_rank                Context parallel rank.
 * \param[in]     s                      Length of the s dimension of input.
 * \param[in]     b                      Length of the b dimension of input.
 * \param[in]     h                      Length of the h dimension of input.
 * \param[in]     d                      Length of the d dimension of input.
 * \param[in]     d2                     Length of the d dimension of freqs.
 * \param[in]     qkv_split_arg_list_0   The hidden size for Q.
 * \param[in]     qkv_split_arg_list_1   The hidden size for K.
 * \param[in]     qkv_split_arg_list_2   The hidden size for V.
 * \param[in]     stream                 CUDA stream used for the operation.
 */
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
const NVTETensor v_grad_out, const NVTETensor q_freqs,
const NVTETensor k_freqs, NVTETensor qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
......@@ -5,14 +5,14 @@
"""
Rotary Position Embedding implementation of different types along with helper functions
"""
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
# Public API of this module.
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
......@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function):
return grad_input, None, None, None, None, None, None, None
class FusedQKVRoPEFunc(torch.autograd.Function):
    """
    Function for FusedQKVRoPE.

    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format.
    Q and K RoPE (freqs) tensors are the additional required inputs; they should
    be of shape (s, 1, 1, d). It produces 3 outputs: Q and K after RoPE, and V,
    which is the same as the input V slice.
    """

    @staticmethod
    def forward(
        ctx,
        qkv: torch.Tensor,
        q_freqs: torch.Tensor,
        k_freqs: torch.Tensor,
        qkv_split_arg_list: List[int],
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fused QKV RoPE forward. Returns (q_out, k_out, v_out)."""
        # The CUDA kernel expects float32 frequencies.
        if q_freqs.dtype != torch.float32:
            q_freqs = q_freqs.float()
        if k_freqs.dtype != torch.float32:
            k_freqs = k_freqs.float()
        assert tensor_format in (
            "sbhd",
            "bshd",
        ), f"Unsupported tensor_format: {tensor_format}."
        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
        output = tex.fused_qkv_rope_forward(
            qkv,
            q_freqs,
            k_freqs,
            start_positions,
            qkv_split_arg_list,
            QKVFormat[tensor_format],
            interleaved,
            cp_size,
            cp_rank,
        )
        # Only the freqs are needed for backward; all scalar settings go on ctx.
        ctx.save_for_backward(q_freqs, k_freqs)
        ctx.tensor_format = tensor_format
        ctx.qkv_split_arg_list = qkv_split_arg_list
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved

        return output

    @staticmethod
    def backward(
        ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused QKV RoPE backward. Combines the three output grads into the
        gradient of the packed QKV input."""
        q_freqs, k_freqs = ctx.saved_tensors
        # The kernel requires contiguous gradient tensors.
        grad_output_q = grad_output_q.contiguous()
        grad_output_k = grad_output_k.contiguous()
        grad_output_v = grad_output_v.contiguous()
        grad_input = tex.fused_qkv_rope_backward(
            grad_output_q,
            grad_output_k,
            grad_output_v,
            q_freqs,
            k_freqs,
            ctx.qkv_split_arg_list,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            ctx.cp_size,
            ctx.cp_rank,
        )

        # One entry per forward input (9 inputs after ctx); only qkv gets a grad.
        return grad_input, None, None, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
......@@ -393,3 +473,82 @@ def apply_rotary_pos_emb(
tensor_format,
interleaved=interleaved,
)
def apply_fused_qkv_rotary_pos_emb(
    qkv: torch.Tensor,
    q_freqs: torch.Tensor,
    k_freqs: torch.Tensor,
    qkv_split_arg_list: List[int],
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,  # pylint: disable=unused-argument
    cp_size: int = 1,
    cp_rank: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embedding to the input qkv tensor, returning the
    rotated Q and K tensors and the pass-through V tensor.

    Support matrix:
    Fused:
       Training:
          qkv_formats: "bshd", "sbhd"
          context parallel: yes
          start_positions: no
          interleaving: yes
       Inference:
          qkv_formats: "bshd", "sbhd"
          context parallelism: no
          start_positions: yes
          interleaving: yes

    Parameters
    ----------
    qkv: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
        rotary positional embedding will be applied. This tensor has q, k, v
        concatenated along the last dimension.
    q_freqs: torch.Tensor
        Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype
        'float', with `s2 >= s` and `d2 <= d`.
    k_freqs: torch.Tensor
        Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype
        'float', with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list: List[int]
        List of integers that specify the split of the qkv tensor. The list should
        have 3 elements, the first element is the number of elements in the q tensor,
        the second element is the number of elements in the k tensor, and the third
        element is the number of elements in the v tensor. The sum of the elements
        in the list should be equal to the last dimension of the qkv tensor.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
        of shape `[seq, bs, ...]`.
    start_positions: torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    cu_seqlens: torch.Tensor, default = None.
        Accepted for signature parity with `apply_rotary_pos_emb`; unused here
        ('thd' format is not supported).
    cp_size: int, default = 1.
        Context parallel world size.
    cp_rank: int, default = 0.
        Context parallel rank.
    """
    # `start_positions` is only supported for `cp_size=1` and inference.
    assert not (
        cp_size > 1 and start_positions is not None
    ), """start_positions != None with CP SIZE > 1 is not supported!"""

    assert tensor_format != "thd", "'thd' tensor_format not supported currently."

    return FusedQKVRoPEFunc.apply(
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list,
        start_positions,
        tensor_format,
        interleaved,
        cp_size,
        cp_rank,
    )
......@@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
// Forward RoPE on a combined QKV tensor; returns (q_out, k_out, v_out).
// Allocates the outputs and dispatches to nvte_fused_qkv_rope_forward.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank);
// Backward of the fused QKV RoPE: combines the per-tensor Q/K/V gradients
// into a single gradient for the packed QKV input via nvte_fused_qkv_rope_backward.
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
const at::Tensor &k_freqs,
const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank);
/***************************************************************************************************
* Miscellaneous
**************************************************************************************************/
......
......@@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
return output;
}
// Forward RoPE on a combined QKV tensor; returns (q_out, k_out, v_out).
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
    const int cp_rank) {
  // Validate all inputs up front, before allocating any outputs.
  TORCH_CHECK(qkv_split_arg_list.size() == 3, "expected 3 entries in qkv_split_arg_list");
  TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor");
  TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous");
  TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");

  // Allocate outputs. Q may pack more heads than K/V: the head count is scaled
  // by the ratio of the Q split to the K split (assumes the Q hidden size is an
  // integer multiple of the K hidden size).
  auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device());
  auto q_out_size = qkv_input.sizes().vec();
  q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1];
  q_out_size[3] = qkv_split_arg_list[1];
  auto q_out = at::empty(q_out_size, act_options);
  auto k_out_size = qkv_input.sizes().vec();
  k_out_size[3] = qkv_split_arg_list[1];
  auto k_out = at::empty(k_out_size, act_options);
  auto v_out_size = qkv_input.sizes().vec();
  v_out_size[3] = qkv_split_arg_list[2];
  auto v_out = at::empty(v_out_size, act_options);

  // Wrap the torch tensors for the NVTE C API.
  auto qkv_cu = makeTransformerEngineTensor(qkv_input);
  auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
  auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
  auto q_out_cu = makeTransformerEngineTensor(q_out);
  auto k_out_cu = makeTransformerEngineTensor(k_out);
  auto v_out_cu = makeTransformerEngineTensor(v_out);
  auto start_positions_cu = TensorWrapper();  // empty start_positions tensor when no offsets
  if (start_positions) {
    start_positions_cu = makeTransformerEngineTensor(start_positions.value());
  }

  const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
  const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1);
  const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0);
  const int h = qkv_input.size(2);
  const int d = qkv_split_arg_list[2];  // per-head dim, taken from the V split
  const int d2 = q_freqs.size(3);

  nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(),
                              start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(),
                              v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h,
                              d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1],
                              qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream());

  return std::make_tuple(q_out, k_out, v_out);
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
......@@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
return input_grads;
}
// Backward of the fused QKV RoPE: combines the per-tensor Q/K/V gradients
// into a single gradient for the packed QKV input.
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
                                   const at::Tensor &k_freqs,
                                   const std::vector<int> &qkv_split_arg_list,
                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
                                   const int cp_size, const int cp_rank) {
  // Validate inputs (mirrors the checks in fused_qkv_rope_forward).
  TORCH_CHECK(qkv_split_arg_list.size() == 3, "expected 3 entries in qkv_split_arg_list");
  TORCH_CHECK(q_grad_out.dim() == 4 && k_grad_out.dim() == 4 && v_grad_out.dim() == 4,
              "expected 4D gradient tensors");
  TORCH_CHECK(q_freqs.dim() == 4 && k_freqs.dim() == 4, "expected 4D freqs tensors");
  TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float &&
                  k_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");

  // Recover the packed [.., h, q+k+v] gradient shape from the three incoming
  // grads: total_hd is the per-token element count across q/k/v (assumes the
  // three share a head dim) and total_d the packed last-dim size, so
  // total_hd / total_d is the packed head count.
  auto act_options =
      at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device());
  auto qkv_grad_size = q_grad_out.sizes().vec();
  auto total_hd =
      (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3);
  auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2];
  qkv_grad_size[2] = total_hd / total_d;
  qkv_grad_size[3] = total_d;
  auto qkv_grad_input = at::empty(qkv_grad_size, act_options);

  const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
  const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1);
  const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0);
  const int h = qkv_grad_input.size(2);
  const int d = qkv_split_arg_list[2];  // per-head dim, taken from the V split
  const int d2 = q_freqs.size(3);

  // Wrap the torch tensors for the NVTE C API.
  auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out);
  auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out);
  auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out);
  auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
  auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
  auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input);

  nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(),
                               q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format,
                               interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0],
                               qkv_split_arg_list[1], qkv_split_arg_list[2],
                               at::cuda::getCurrentCUDAStream());

  return qkv_grad_input;
}
} // namespace transformer_engine::pytorch
......@@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
"Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
"Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
// fused router
m.def("fused_topk_with_score_function_fwd",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment