Unverified Commit 84fa28d2 authored by vasunvidia, committed by GitHub

Fused RoPE with combined QKV input. (#2122)



* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca.

Cleanup.

Minor cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernels
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Misc. Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernel performance
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Move fused_qkv_rope test to test_fused_rope.py
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* apply shared memory optimization to separate fused rope kernels
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 603dbf72
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
-from typing import Callable, Tuple, Union
+from typing import Callable, Tuple, Union, List
import math
import torch
import pytest
from transformer_engine.pytorch.attention.rope import (
RotaryPositionEmbedding,
apply_rotary_pos_emb,
apply_fused_qkv_rotary_pos_emb,
)
# Gradient is a broadcasted scalar
-def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
+def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(t.sum() * 2 for t in output)
else:
return output.sum() * 2
# Gradient is a full tensor
-def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
+def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
if isinstance(output, List):
return sum(torch.sum(t * torch.ones_like(t)) for t in output)
else:
t = torch.ones_like(output)
return torch.sum(output * t)
@@ -238,3 +245,131 @@ def test_fused_rope_thd(
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_qkv_rope(
dtype: torch.dtype,
seq_length: int,
hidden_size: int,
rotary_percent: float,
margin: int,
tensor_format: str,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
if margin == 0 and start_positions == True:
        # This ensures that the `start_positions` offsets being applied
        # stay within the maximum length of the RoPE embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
if seq_length - margin < 0:
pytest.skip("Skipping test with seq_length - margin < 0")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
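    # Combined QKV layout per packed head: Q occupies 4 * hidden_size (four
    # query heads per K/V head, GQA-style), while K and V occupy hidden_size
    # each, hence the 6 * hidden_size last dimension.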
t = torch.rand(
(seq_length - margin, batch_size, head_num, hidden_size * 6),
dtype=dtype,
device=device,
)
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
if start_positions
else None
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
t.requires_grad = True
rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb_q = rotary_pos_emb_q(seq_length * cp_size)
rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb_k = rotary_pos_emb_k(seq_length * cp_size)
for cp_rank in range(cp_size):
        # Reference path: split the combined tensor and apply RoPE to Q and K
        # separately. The fused kernel computes in float32 internally, so the
        # reference outputs are cast back to `dtype` for a fair comparison.
t_clone = t.clone()
(query, key, value) = torch.split(
t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
)
query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
query_unfused = apply_rotary_pos_emb(
query,
emb_q,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
key_unfused = apply_rotary_pos_emb(
key,
emb_k,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
value_unfused = value
loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
t,
emb_q,
emb_k,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
cp_size=cp_size,
cp_rank=cp_rank,
qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
)
loss_fused = loss_func([query_fused, key_fused, value_fused])
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(query_fused, query_unfused)
torch.testing.assert_close(key_fused, key_unfused)
torch.testing.assert_close(value_fused, value_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
@@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs,
const int h, const int d, const int d2, const int stride_h,
const int stride_d, const int o_stride_h,
const int o_stride_d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos, v_sin;
sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos);
extern __shared__ float shared_mem_cos_sin[];
float *shared_mem_cos = shared_mem_cos_sin;
float *shared_mem_sin = shared_mem_cos_sin + d2;
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
__syncthreads();
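  // All heads handled by this block reuse the cos/sin values staged above,
  // avoiding a per-element sincosf in the inner loop.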
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = shared_mem_cos[d_id];
float v_sin = shared_mem_sin[d_id];
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
@@ -48,13 +57,13 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs,
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
dst[offset_dst] = src[offset_src];
}
}
}
@@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq
const int h, const int d, const int d2,
const int stride_h, const int stride_d,
const int o_stride_h, const int o_stride_d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
float v_cos = cosf(freqs[s_id * d2 + d_id]);
float v_sin;
if (!interleaved) {
v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2])
: -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]);
} else {
v_sin =
(d_id % 2 == 0) ? sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]);
extern __shared__ float shared_mem_cos_sin[];
float *shared_mem_cos = shared_mem_cos_sin;
float *shared_mem_sin = shared_mem_cos_sin + d2;
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
__syncthreads();
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
float v_src = src[offset_src];
float v_src_rotate;
float v_cos = shared_mem_cos[d_id];
float v_src_rotate, v_sin;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
: static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
if (d_id + d2 / 2 < d2) {
v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2) * stride_d]);
v_sin = shared_mem_sin[d_id + d2 / 2];
} else {
v_src_rotate = (d_id % 2 == 0)
// d_id + 1
? static_cast<float>(src[offset_src + stride_d])
// d_id - 1
: static_cast<float>(src[offset_src - stride_d]);
v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
}
} else {
if (d_id % 2 == 0) {
v_src_rotate = static_cast<float>(src[offset_src + stride_d]);
v_sin = shared_mem_sin[d_id + 1];
} else {
v_src_rotate = static_cast<float>(src[offset_src - stride_d]);
v_sin = -shared_mem_sin[d_id - 1];
}
}
dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
}
}
// handle the tail
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_head = offset_block + h_id * stride_h;
int offset_head_dst = offset_block_dst + h_id * o_stride_h;
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d];
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
dst[offset_dst] = src[offset_src];
}
}
}
@@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel(
offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out,
const bool interleaved, const int s_id,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2,
const int row_offset, const int in_row_length,
const int out_row_length) {
extern __shared__ float shared_mem_cos_sin_qk[];
// Split the shared memory into cos and sin parts for q or k
float *shared_mem_cos = nullptr;
float *shared_mem_sin = nullptr;
if (row_offset == 0) { // q
shared_mem_cos = shared_mem_cos_sin_qk;
shared_mem_sin = shared_mem_cos_sin_qk + d2;
} else { // k
shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
}
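  // row_offset selects the Q, K, or V slice within each combined QKV row;
  // freqs == nullptr means no rotation is applied (the V pass-through path).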
if (freqs != nullptr) {
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
}
__syncthreads();
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
if (freqs != nullptr) {
float v_cos, v_sin;
v_cos = shared_mem_cos[d_id];
v_sin = shared_mem_sin[d_id];
float v_src = src[offset_src];
float v_src_rotate;
if (!interleaved) {
v_src_rotate = (d_id + d2 / 2 < d2)
? -static_cast<float>(src[offset_src + (d2 / 2)])
: static_cast<float>(src[offset_src + (d2 / 2 - d2)]);
} else {
v_src_rotate = (d_id % 2 == 0) ? -static_cast<float>(src[offset_src + 1])
: static_cast<float>(src[offset_src - 1]);
}
out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} else {
out[offset_dst] = src[offset_src];
}
}
}
}
// copy the rest
if (d > d2) {
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
out[offset_dst] = src[offset_src];
}
}
}
}
}
template <typename scalar_t>
__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs,
scalar_t *out, const bool interleaved, const int s_id,
const int offset_block, const int offset_block_dst,
const int h, const int d, const int d2,
const int row_offset, const int in_row_length,
const int out_row_length) {
extern __shared__ float shared_mem_cos_sin_qk[];
float *shared_mem_cos = nullptr;
float *shared_mem_sin = nullptr;
// Split the shared memory into cos and sin parts for q or k
if (row_offset == 0) { // q
shared_mem_cos = shared_mem_cos_sin_qk;
shared_mem_sin = shared_mem_cos_sin_qk + d2;
} else { // k
shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
}
if (freqs != nullptr) {
int tid = threadIdx.x * blockDim.y + threadIdx.y;
for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
}
}
__syncthreads();
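  // The backward pass applies the transpose of the forward rotation: the sine
  // is gathered from the partner index with its sign flipped.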
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
float v_src = grad_out[offset_src];
if (freqs != nullptr) {
float v_cos, v_sin;
v_cos = shared_mem_cos[d_id];
float v_src_rotate;
if (!interleaved) {
if (d_id + d2 / 2 < d2) {
v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2)]);
v_sin = shared_mem_sin[d_id + d2 / 2];
} else {
v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2 - d2)]);
v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
}
} else {
if (d_id % 2 == 0) {
v_src_rotate = static_cast<float>(grad_out[offset_src + 1]);
v_sin = shared_mem_sin[d_id + 1];
} else {
v_src_rotate = static_cast<float>(grad_out[offset_src - 1]);
v_sin = -shared_mem_sin[d_id - 1];
}
}
out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
} else {
out[offset_dst] = grad_out[offset_src];
}
}
}
}
// copy the rest
if (d > d2) {
#pragma unroll
for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
out[offset_dst] = grad_out[offset_src];
}
}
}
}
}
template <typename scalar_t>
__global__ void fused_qkv_rope_forward_kernel(
const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs,
const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2, const int q_split_arg,
const int k_split_arg, const int v_split_arg) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int cur_seqlens = s;
int total_d = q_split_arg + k_split_arg + v_split_arg;
int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
offset_block = s_id * b * h * total_d + b_id * h * total_d;
offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
} else {
offset_block = b_id * s * h * total_d + s_id * h * total_d;
offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
}
int q_limit = q_split_arg;
int k_limit = q_limit + k_split_arg;
int s_id_for_freqs;
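  // With context parallelism, each rank holds two chunks taken from opposite
  // ends of the full sequence (load-balanced split), so the frequency index
  // must be remapped from the local s_id to the global position.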
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
}
fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block,
offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg);
fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block,
offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg);
fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block,
offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg);
}
template <typename scalar_t>
__global__ void fused_qkv_rope_backward_kernel(
const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v,
const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2, const int q_split_arg,
const int k_split_arg, const int v_split_arg) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int cur_seqlens = s;
int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
int total_d = q_split_arg + k_split_arg + v_split_arg;
if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
offset_block = s_id * b * h * total_d + b_id * h * total_d;
offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
} else {
offset_block = b_id * s * h * total_d + s_id * h * total_d;
offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
}
int q_limit = q_split_arg;
int k_limit = q_limit + k_split_arg;
int s_id_for_freqs;
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
}
fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs,
offset_block, offset_block_dst_q, h, d, d2, 0, total_d,
q_split_arg);
fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs,
offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d,
k_split_arg);
fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs,
offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d,
v_split_arg);
}
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
const int *start_positions, scalar_t *output,
@@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
@@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c
const int o_stride_h = d;
const int o_stride_d = 1;
-  fused_rope_forward_kernel<<<blocks, threads, 0, stream>>>(
+  fused_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
@@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin
int o_stride_s_or_t, o_stride_b;
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
@@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
const int o_stride_h = d;
const int o_stride_d = 1;
-  fused_rope_backward_kernel<<<blocks, threads, 0, stream>>>(
+  fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs,
const float *k_freqs, const int *start_positions,
scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
const int THREADS_PER_WARP = 32;
int warps_per_block = (h <= 8) ? h : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
  const int shared_mem_size = 4 * d2 * sizeof(float);  // cos and sin, for both Q and K
fused_qkv_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved,
cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2);
NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out,
const scalar_t *v_grad_out, const float *q_freqs,
const float *k_freqs, scalar_t *qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s,
const int b, const int h, const int d, const int d2,
const int qkv_split_arg_list_0,
const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
const int THREADS_PER_WARP = 32;
const int warps_per_block = (h <= 8) ? h : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
  const int shared_mem_size = 4 * d2 * sizeof(float);  // cos and sin, for both Q and K
fused_qkv_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved,
cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2);
NVTE_CHECK_CUDA(cudaGetLastError());
}
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
const Tensor &start_positions, Tensor *output,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
@@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c
stride_b, stride_h, stride_d, stream););
}
void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs,
const Tensor &start_positions, Tensor *q_out, Tensor *k_out,
Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int qkv_split_arg_list_0,
const int qkv_split_arg_list_1, const int qkv_split_arg_list_2,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
qkv_input.data.dtype, scalar_t,
fused_qkv_rope_forward_launcher(reinterpret_cast<const scalar_t *>(qkv_input.data.dptr),
reinterpret_cast<const float *>(q_freqs.data.dptr),
reinterpret_cast<const float *>(k_freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(q_out->data.dptr),
reinterpret_cast<scalar_t *>(k_out->data.dptr),
reinterpret_cast<scalar_t *>(v_out->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2,
qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2, stream););
}
void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out,
const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs,
Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
q_grad_out.data.dtype, scalar_t,
fused_qkv_rope_backward_launcher(reinterpret_cast<const scalar_t *>(q_grad_out.data.dptr),
reinterpret_cast<const scalar_t *>(k_grad_out.data.dptr),
reinterpret_cast<const scalar_t *>(v_grad_out.data.dptr),
reinterpret_cast<const float *>(q_freqs.data.dptr),
reinterpret_cast<const float *>(k_freqs.data.dptr),
reinterpret_cast<scalar_t *>(qkv_grad_input->data.dptr),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2, stream););
}
} // end namespace transformer_engine
void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
@@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
}
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
const NVTETensor k_freqs, const NVTETensor start_positions,
NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_qkv_rope_forward);
using namespace transformer_engine;
fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs),
*convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions),
convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out),
convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank,
s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
qkv_split_arg_list_2, stream);
}
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
const NVTETensor v_grad_out, const NVTETensor q_freqs,
const NVTETensor k_freqs, NVTETensor qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_qkv_rope_backward);
using namespace transformer_engine;
fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out),
*convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs),
*convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream);
}
@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
*
* \param[in] qkv_input Combined QKV input tensor for fused rope.
* \param[in] q_freqs The freqs tensor for Q.
* \param[in] k_freqs The freqs tensor for K.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] q_out Output tensor for Q.
* \param[out] k_out Output tensor for K.
* \param[out] v_out Output tensor for V.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
 * \param[in] qkv_split_arg_list_0 Per-head hidden size of Q in the combined input.
 * \param[in] qkv_split_arg_list_1 Per-head hidden size of K in the combined input.
 * \param[in] qkv_split_arg_list_2 Per-head hidden size of V in the combined input.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
const NVTETensor k_freqs, const NVTETensor start_positions,
NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
/*! \brief Compute the backward pass of the fused QKV RoPE.
*
* \param[in] q_grad_out Incoming gradient tensor for Q.
* \param[in] k_grad_out Incoming gradient tensor for K.
* \param[in] v_grad_out Incoming gradient tensor for V.
* \param[in] q_freqs The freqs tensor for Q.
* \param[in] k_freqs The freqs tensor for K.
 * \param[out] qkv_grad_input Computed gradient of the combined QKV input.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] qkv_split_arg_list_0 The hidden size for Q.
* \param[in] qkv_split_arg_list_1 The hidden size for K.
* \param[in] qkv_split_arg_list_2 The hidden size for V.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
const NVTETensor v_grad_out, const NVTETensor q_freqs,
const NVTETensor k_freqs, NVTETensor qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
#ifdef __cplusplus
} // extern "C"
#endif
......
@@ -5,14 +5,14 @@
"""
Rotary Position Embedding implementation of different types along with helper functions
"""
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"]
__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]
class RotaryPositionEmbedding(torch.nn.Module):
@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function):
return grad_input, None, None, None, None, None, None, None
class FusedQKVRoPEFunc(torch.autograd.Function):
"""
Function for FusedQKVRoPE
    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format,
    along with separate RoPE frequency tensors for Q and K as additional required
    inputs. The RoPE tensors must be of shape (s, 1, 1, d). It produces three
    outputs: Q and K with RoPE applied, and V passed through unchanged.
"""
@staticmethod
def forward(
ctx,
qkv: torch.Tensor,
q_freqs: torch.Tensor,
k_freqs: torch.Tensor,
qkv_split_arg_list: List[int],
start_positions: Union[torch.Tensor, None] = None,
tensor_format: str = "sbhd",
interleaved: bool = False,
cp_size: int = 1,
cp_rank: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Fused RoPE forward."""
if q_freqs.dtype != torch.float32:
q_freqs = q_freqs.float()
if k_freqs.dtype != torch.float32:
k_freqs = k_freqs.float()
assert tensor_format in (
"sbhd",
"bshd",
), f"Unsupported tensor_format: {tensor_format}."
assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."
output = tex.fused_qkv_rope_forward(
qkv,
q_freqs,
k_freqs,
start_positions,
qkv_split_arg_list,
QKVFormat[tensor_format],
interleaved,
cp_size,
cp_rank,
)
ctx.save_for_backward(q_freqs, k_freqs)
ctx.tensor_format = tensor_format
ctx.qkv_split_arg_list = qkv_split_arg_list
ctx.cp_size = cp_size
ctx.cp_rank = cp_rank
ctx.interleaved = interleaved
return output
@staticmethod
def backward(
ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Fused RoPE backward."""
q_freqs, k_freqs = ctx.saved_tensors
grad_output_q = grad_output_q.contiguous()
grad_output_k = grad_output_k.contiguous()
grad_output_v = grad_output_v.contiguous()
grad_input = tex.fused_qkv_rope_backward(
grad_output_q,
grad_output_k,
grad_output_v,
q_freqs,
k_freqs,
ctx.qkv_split_arg_list,
QKVFormat[ctx.tensor_format],
ctx.interleaved,
ctx.cp_size,
ctx.cp_rank,
)
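        # Only the combined QKV input receives a gradient; all other forward
        # arguments are non-differentiable.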
return grad_input, None, None, None, None, None, None, None, None
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
@@ -393,3 +473,82 @@ def apply_rotary_pos_emb(
tensor_format,
interleaved=interleaved,
)
def apply_fused_qkv_rotary_pos_emb(
qkv: torch.Tensor,
q_freqs: torch.Tensor,
k_freqs: torch.Tensor,
qkv_split_arg_list: List[int],
tensor_format: str = "sbhd",
start_positions: Union[torch.Tensor, None] = None,
interleaved: bool = False,
cu_seqlens: Union[torch.Tensor, None] = None, # pylint: disable=unused-argument
cp_size: int = 1,
cp_rank: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Apply rotary positional embedding tensor to the input qkv tensor.
Support matrix:
Fused:
Training:
qkv_formats: "bshd", "sbhd"
context parallel: yes
start_positions: no
interleaving: yes
Inference:
qkv_formats: "bshd", "sbhd"
context parallelism: no
start_positions: yes
interleaving: yes
Parameters
----------
qkv: torch.Tensor
Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
rotary positional embedding will be applied. This tensor has q, k, v concatenated
along the last dimension.
q_freqs: torch.Tensor
Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
k_freqs: torch.Tensor
Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
with `s2 >= s` and `d2 <= d`.
qkv_split_arg_list: List[int]
        Split widths of the `qkv` tensor along its last dimension. The list must
        contain 3 elements: the per-head widths of Q, K, and V, in that order.
        Their sum must equal the last dimension of the `qkv` tensor.
start_positions: torch.Tensor, default = None.
        Tokens in sequence `i` have their position encodings offset by
        `start_positions[i]`. If `start_positions=None`, there is no offset.
tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
of shape `[seq, bs, ...]`.
interleaved: bool, default = False
Whether to use interleaved rotary position embedding.
cp_size: int, default = 1.
Context parallel world size.
cp_rank: int, default = 0.
Context parallel rank.
"""
# `start_positions` is only supported for `cp_size=1` and inference.
assert not (
cp_size > 1 and start_positions is not None
), """start_positions != None with CP SIZE > 1 is not supported!"""
assert tensor_format != "thd", "'thd' tensor_format not supported currently."
return FusedQKVRoPEFunc.apply(
qkv,
q_freqs,
k_freqs,
qkv_split_arg_list,
start_positions,
tensor_format,
interleaved,
cp_size,
cp_rank,
)
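# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of the new API; the sizes below are hypothetical and a
# CUDA device is assumed.
import torch
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_fused_qkv_rotary_pos_emb,
)

s, b, h, d = 512, 2, 8, 128  # hypothetical sequence, batch, packed-head, head-dim sizes
# Per packed head: Q takes 4 * d (four query heads), K and V take d each.
qkv = torch.randn(s, b, h, d * 6, dtype=torch.bfloat16, device="cuda", requires_grad=True)

rope = RotaryPositionEmbedding(d)
q_freqs = rope(s)  # shape (s, 1, 1, d), float32
k_freqs = rope(s)

q, k, v = apply_fused_qkv_rotary_pos_emb(
    qkv,
    q_freqs,
    k_freqs,
    qkv_split_arg_list=[d * 4, d, d],
    tensor_format="sbhd",
)
# q: (s, b, 4 * h, d) with RoPE applied; k: (s, b, h, d); v is copied through.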
@@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
const int cp_rank);
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank);
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
const at::Tensor &k_freqs,
const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank);
/***************************************************************************************************
* Miscellaneous
**************************************************************************************************/
......
@@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
return output;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
const int cp_rank) {
TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor");
TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1,
"expected the second and third dims of the freqs tensor equal 1");
TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float,
"Dtype of the freqs tensor must be float");
// output
auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device());
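  // Output sizing derived from the split widths: Q and K share the head dim
  // qkv_split_arg_list[1], and Q's head count scales by list[0] / list[1] to
  // cover multiple query heads per packed input head.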
auto q_out_size = qkv_input.sizes().vec();
q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1];
q_out_size[3] = qkv_split_arg_list[1];
auto q_out = at::empty(q_out_size, act_options);
auto k_out_size = qkv_input.sizes().vec();
k_out_size[3] = qkv_split_arg_list[1];
auto k_out = at::empty(k_out_size, act_options);
auto v_out_size = qkv_input.sizes().vec();
v_out_size[3] = qkv_split_arg_list[2];
auto v_out = at::empty(v_out_size, act_options);
auto qkv_cu = makeTransformerEngineTensor(qkv_input);
auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
auto q_out_cu = makeTransformerEngineTensor(q_out);
auto k_out_cu = makeTransformerEngineTensor(k_out);
auto v_out_cu = makeTransformerEngineTensor(v_out);
  auto start_positions_cu = TensorWrapper();  // empty start_positions tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
}
TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor");
TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous");
const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1);
const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0);
const int h = qkv_input.size(2);
const int d = qkv_split_arg_list[2];
const int d2 = q_freqs.size(3);
nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(),
start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(),
v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h,
d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1],
qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream());
return std::make_tuple(q_out, k_out, v_out);
}
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const std::optional<at::Tensor> cu_seqlens, const int cp_size,
@@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor
return input_grads;
}
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
const at::Tensor &k_freqs,
const std::vector<int> &qkv_split_arg_list,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank) {
auto act_options =
at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device());
auto qkv_grad_size = q_grad_out.sizes().vec();
auto total_hd =
(q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3);
auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2];
qkv_grad_size[2] = total_hd / total_d;
qkv_grad_size[3] = total_d;
auto qkv_grad_input = at::empty(qkv_grad_size, act_options);
const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1);
const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0);
const int h = qkv_grad_input.size(2);
const int d = qkv_split_arg_list[2];
const int d2 = q_freqs.size(3);
auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out);
auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out);
auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out);
auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input);
nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(),
q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0],
qkv_split_arg_list[1], qkv_split_arg_list[2],
at::cuda::getCurrentCUDAStream());
return qkv_grad_input;
}
} // namespace transformer_engine::pytorch
@@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
"Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
"Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
"Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
// fused router
m.def("fused_topk_with_score_function_fwd",
......