Unverified Commit 84fa28d2 authored by vasunvidia, committed by GitHub

Fused RoPE with combined QKV input. (#2122)



* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca.

Cleanup.

Minor cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernels
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Misc. Cleanup
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernel performance
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* Move fused_qkv_rope test to test_fused_rope.py
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* apply shared memory optimization to separate fused rope kernels
Signed-off-by: Xin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 603dbf72
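In brief, this commit adds a single fused kernel that takes the packed output of a QKV projection and returns RoPE-rotated Q and K plus an untouched V. A minimal sketch of the intended Python-level usage, mirroring the shapes in the new test below (hypothetical driver code, not part of this commit):

import torch
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_fused_qkv_rotary_pos_emb,
)

s, b, h, d = 2048, 2, 64, 128  # sequence, batch, KV heads, head size
# Packed projection output: 4*d Q channels, d K channels, d V channels per head
qkv = torch.rand(s, b, h, 6 * d, device="cuda", requires_grad=True)
freqs = RotaryPositionEmbedding(d)(s)  # [s, 1, 1, d] float32 angles
q, k, v = apply_fused_qkv_rotary_pos_emb(
    qkv, freqs, freqs, qkv_split_arg_list=[4 * d, d, d], tensor_format="sbhd"
)
# q: [s, b, 4*h, d] and k: [s, b, h, d] have RoPE applied; v is a plain copy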
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from typing import Callable, Tuple, Union, List
import math
import torch
import pytest
from transformer_engine.pytorch.attention.rope import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
    apply_fused_qkv_rotary_pos_emb,
)
# Gradient is a broadcasted scalar
def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    if isinstance(output, List):
        return sum(t.sum() * 2 for t in output)
    else:
        return output.sum() * 2


# Gradient is a full tensor
def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    if isinstance(output, List):
        return sum(torch.sum(t * torch.ones_like(t)) for t in output)
    else:
        t = torch.ones_like(output)
        return torch.sum(output * t)
@@ -238,3 +245,131 @@ def test_fused_rope_thd(
    torch.testing.assert_close(grad_fused, grad_unfused)
    assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [True, False])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096])
@pytest.mark.parametrize("hidden_size", [64, 128, 256])
@pytest.mark.parametrize("rotary_percent", [0.5, 1.0])
@pytest.mark.parametrize("margin", [0, 10])
@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"])
@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad])
@pytest.mark.parametrize("cp_size", [1, 2])
@pytest.mark.parametrize("interleaved", [True, False])
def test_fused_qkv_rope(
dtype: torch.dtype,
seq_length: int,
hidden_size: int,
rotary_percent: float,
margin: int,
tensor_format: str,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
if margin == 0 and start_positions == True:
# This makes sure that the `start_positions` offsets being applied
# are with the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")
if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")
if seq_length - margin < 0:
pytest.skip("Skipping test with seq_length - margin < 0")
device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
(seq_length - margin, batch_size, head_num, hidden_size * 6),
dtype=dtype,
device=device,
)
# Get arbitrary offsets to be used with RoPE for all the sequences
start_positions = (
torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device)
if start_positions
else None
)
if tensor_format == "bshd":
t = t.transpose(0, 1).contiguous()
t.requires_grad = True
rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb_q = rotary_pos_emb_q(seq_length * cp_size)
rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb_k = rotary_pos_emb_k(seq_length * cp_size)
for cp_rank in range(cp_size):
# unfused
# The fused kernel computes in float32 internally, so we force the unfused func to use float32
# for more accurate comparison
t_clone = t.clone()
(query, key, value) = torch.split(
t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3
)
query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size)
query_unfused = apply_rotary_pos_emb(
query,
emb_q,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
key_unfused = apply_rotary_pos_emb(
key,
emb_k,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
fused=True,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
value_unfused = value
loss_unfused = loss_func([query_unfused, key_unfused, value_unfused])
if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None
# fused
query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb(
t,
emb_q,
emb_k,
tensor_format=tensor_format,
start_positions=start_positions,
interleaved=interleaved,
cp_size=cp_size,
cp_rank=cp_rank,
qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size],
)
loss_fused = loss_func([query_fused, key_fused, value_fused])
if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None
torch.testing.assert_close(query_fused, query_unfused)
torch.testing.assert_close(key_fused, key_unfused)
torch.testing.assert_close(value_fused, value_unfused)
if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)
@@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs,
                                         const int h, const int d, const int d2, const int stride_h,
                                         const int stride_d, const int o_stride_h,
                                         const int o_stride_d) {
  extern __shared__ float shared_mem_cos_sin[];
  float *shared_mem_cos = shared_mem_cos_sin;
  float *shared_mem_sin = shared_mem_cos_sin + d2;
  int tid = threadIdx.x * blockDim.y + threadIdx.y;
  for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
    sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
  }
  __syncthreads();
#pragma unroll
  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
    for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
      float v_cos = shared_mem_cos[d_id];
      float v_sin = shared_mem_sin[d_id];
      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
      float v_src = src[offset_src];
@@ -48,13 +57,13 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs,
  // copy the rest
  if (d > d2) {
#pragma unroll
    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
      for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
        int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
        int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
        dst[offset_dst] = src[offset_src];
      }
    }
  }
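For reference, the rotation this block computes is standard rotate-half RoPE; only the cos/sin staging in shared memory is new in this commit. A PyTorch sketch of the same forward math (an illustration written for this note, not code from the diff):

import torch

def rope_forward_reference(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Non-interleaved RoPE on x: [s, h, d2] with freqs: [s, d2] (d2 even)."""
    cos = freqs.cos().unsqueeze(1)  # [s, 1, d2], broadcast over heads
    sin = freqs.sin().unsqueeze(1)
    x1, x2 = x.chunk(2, dim=-1)  # split the rotary width in half
    x_rot = torch.cat((-x2, x1), dim=-1)  # "rotate half", matching v_src_rotate
    return x * cos + x_rot * sin  # same formula as dst[offset_dst] above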
@@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freqs,
                                          const int h, const int d, const int d2,
                                          const int stride_h, const int stride_d,
                                          const int o_stride_h, const int o_stride_d) {
  extern __shared__ float shared_mem_cos_sin[];
  float *shared_mem_cos = shared_mem_cos_sin;
  float *shared_mem_sin = shared_mem_cos_sin + d2;
  int tid = threadIdx.x * blockDim.y + threadIdx.y;
  for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
    sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
  }
  __syncthreads();
#pragma unroll
  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
    for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
      int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
      int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
      float v_src = src[offset_src];
      float v_cos = shared_mem_cos[d_id];
      float v_src_rotate, v_sin;
      if (!interleaved) {
        if (d_id + d2 / 2 < d2) {
          v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2) * stride_d]);
          v_sin = shared_mem_sin[d_id + d2 / 2];
        } else {
          v_src_rotate = static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
          v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
        }
      } else {
        if (d_id % 2 == 0) {
          v_src_rotate = static_cast<float>(src[offset_src + stride_d]);
          v_sin = shared_mem_sin[d_id + 1];
        } else {
          v_src_rotate = static_cast<float>(src[offset_src - stride_d]);
          v_sin = -shared_mem_sin[d_id - 1];
        }
      }
      dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }
  // copy the rest
  if (d > d2) {
#pragma unroll
    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
      for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
        int offset_src = offset_block + h_id * stride_h + d_id * stride_d;
        int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d;
        dst[offset_dst] = src[offset_src];
      }
    }
  }
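The backward pass multiplies by the transpose of the rotation: the cos term is unchanged, while the sin term is gathered from the rotated index with the opposite sign pattern, which is why v_sin above indexes shared_mem_sin at d_id +- d2/2 rather than d_id. Equivalent PyTorch, as an illustration:

import torch

def rope_backward_reference(grad_out: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Gradient of non-interleaved rotate-half RoPE w.r.t. its input."""
    cos = freqs.cos().unsqueeze(1)
    sin = freqs.sin().unsqueeze(1)
    gs1, gs2 = (grad_out * sin).chunk(2, dim=-1)
    # Adjoint of cat((-x2, x1)): apply cat((v2, -v1)) to grad*sin
    return grad_out * cos + torch.cat((gs2, -gs1), dim=-1)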
@@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel(
      offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d);
}
template <typename scalar_t>
__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out,
                                             const bool interleaved, const int s_id,
                                             const int offset_block, const int offset_block_dst,
                                             const int h, const int d, const int d2,
                                             const int row_offset, const int in_row_length,
                                             const int out_row_length) {
  extern __shared__ float shared_mem_cos_sin_qk[];
  // Split the shared memory into cos and sin parts for q or k
  float *shared_mem_cos = nullptr;
  float *shared_mem_sin = nullptr;
  if (row_offset == 0) {  // q
    shared_mem_cos = shared_mem_cos_sin_qk;
    shared_mem_sin = shared_mem_cos_sin_qk + d2;
  } else {  // k
    shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
    shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
  }
  if (freqs != nullptr) {
    int tid = threadIdx.x * blockDim.y + threadIdx.y;
    for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
      sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
    }
  }
  __syncthreads();
#pragma unroll
  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
    for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
      for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
        int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
        int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
        if (freqs != nullptr) {
          float v_cos, v_sin;
          v_cos = shared_mem_cos[d_id];
          v_sin = shared_mem_sin[d_id];
          float v_src = src[offset_src];
          float v_src_rotate;
          if (!interleaved) {
            v_src_rotate = (d_id + d2 / 2 < d2)
                               ? -static_cast<float>(src[offset_src + (d2 / 2)])
                               : static_cast<float>(src[offset_src + (d2 / 2 - d2)]);
          } else {
            v_src_rotate = (d_id % 2 == 0) ? -static_cast<float>(src[offset_src + 1])
                                           : static_cast<float>(src[offset_src - 1]);
          }
          out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
        } else {
          out[offset_dst] = src[offset_src];
        }
      }
    }
  }
  // copy the rest
  if (d > d2) {
#pragma unroll
    for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
#pragma unroll
      for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
        for (int i = 0; i < out_row_length; i += d) {
          int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
          int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id;
          out[offset_dst] = src[offset_src];
        }
      }
    }
  }
}
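The row_offset / in_row_length / out_row_length arithmetic above is an in-kernel version of torch.split along the packed last dimension: Q starts at offset 0, K at q_split, V at q_split + k_split, each head row being in_row_length wide on input and out_row_length wide on output. A sketch of the equivalence (illustrative values):

import torch

d = 128
q_split, k_split, v_split = 4 * d, d, d  # qkv_split_arg_list
qkv_row = torch.arange(q_split + k_split + v_split)  # one packed h-row
q_row, k_row, v_row = torch.split(qkv_row, [q_split, k_split, v_split])
# row_offset for K in the kernel is q_split; the same slice by hand:
assert torch.equal(k_row, qkv_row[q_split : q_split + k_split])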
template <typename scalar_t>
__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs,
                                              scalar_t *out, const bool interleaved,
                                              const int s_id, const int offset_block,
                                              const int offset_block_dst, const int h, const int d,
                                              const int d2, const int row_offset,
                                              const int in_row_length, const int out_row_length) {
  extern __shared__ float shared_mem_cos_sin_qk[];
  float *shared_mem_cos = nullptr;
  float *shared_mem_sin = nullptr;
  // Split the shared memory into cos and sin parts for q or k
  if (row_offset == 0) {  // q
    shared_mem_cos = shared_mem_cos_sin_qk;
    shared_mem_sin = shared_mem_cos_sin_qk + d2;
  } else {  // k
    shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2;
    shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2;
  }
  if (freqs != nullptr) {
    int tid = threadIdx.x * blockDim.y + threadIdx.y;
    for (int i = tid; i < d2; i += blockDim.x * blockDim.y) {
      sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]);
    }
  }
  __syncthreads();
#pragma unroll
  for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
    for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
      for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) {
        int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
        int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
        float v_src = grad_out[offset_src];
        if (freqs != nullptr) {
          float v_cos, v_sin;
          v_cos = shared_mem_cos[d_id];
          float v_src_rotate;
          if (!interleaved) {
            if (d_id + d2 / 2 < d2) {
              v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2)]);
              v_sin = shared_mem_sin[d_id + d2 / 2];
            } else {
              v_src_rotate = static_cast<float>(grad_out[offset_src + (d2 / 2 - d2)]);
              v_sin = -shared_mem_sin[d_id + d2 / 2 - d2];
            }
          } else {
            if (d_id % 2 == 0) {
              v_src_rotate = static_cast<float>(grad_out[offset_src + 1]);
              v_sin = shared_mem_sin[d_id + 1];
            } else {
              v_src_rotate = static_cast<float>(grad_out[offset_src - 1]);
              v_sin = -shared_mem_sin[d_id - 1];
            }
          }
          out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin;
        } else {
          out[offset_dst] = grad_out[offset_src];
        }
      }
    }
  }
  // copy the rest
  if (d > d2) {
#pragma unroll
    for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) {
#pragma unroll
      for (int i = 0; i < out_row_length; i += d) {
#pragma unroll
        for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) {
          int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id;
          int offset_src = offset_block_dst + h_id * out_row_length + i + d_id;
          out[offset_dst] = grad_out[offset_src];
        }
      }
    }
  }
}
template <typename scalar_t>
__global__ void fused_qkv_rope_forward_kernel(
    const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs,
    const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank,
    const int s, const int b, const int h, const int d, const int d2, const int q_split_arg,
    const int k_split_arg, const int v_split_arg) {
  int s_id = blockIdx.x, b_id = blockIdx.y;
  int cur_seqlens = s;
  int total_d = q_split_arg + k_split_arg + v_split_arg;
  int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
  if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
    offset_block = s_id * b * h * total_d + b_id * h * total_d;
    offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
    offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
    offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
  } else {
    offset_block = b_id * s * h * total_d + s_id * h * total_d;
    offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
    offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
    offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
  }
  int q_limit = q_split_arg;
  int k_limit = q_limit + k_split_arg;
  int s_id_for_freqs;
  if (cp_size > 1) {
    assert(cur_seqlens % 2 == 0);
    if (s_id < cur_seqlens / 2) {
      s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
    } else {
      s_id_for_freqs =
          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
    }
  } else {
    int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
    s_id_for_freqs = s_id + begin_offset;
  }
  fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block,
                               offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg);
  fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block,
                               offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg);
  fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block,
                               offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg);
}
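The cp_size > 1 branch assumes the usual load-balanced context-parallel layout: each rank holds the r-th chunk from the front of the global sequence plus the mirrored chunk from the back, so a local index must be remapped before indexing freqs. The mapping restated in Python (illustrative):

def global_rope_position(s_id: int, s_local: int, cp_size: int, cp_rank: int) -> int:
    """Map a local CP sequence index to the global position used for freqs."""
    if cp_size == 1:
        return s_id
    assert s_local % 2 == 0
    half = s_local // 2
    if s_id < half:
        # front chunk: rank r owns global positions [r * half, (r + 1) * half)
        return s_id + cp_rank * half
    # back chunk: mirrored from the end of the global sequence
    return s_local * cp_size - (cp_rank + 1) * half + (s_id - half)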
template <typename scalar_t>
__global__ void fused_qkv_rope_backward_kernel(
    const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v,
    const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank,
    const int s, const int b, const int h, const int d, const int d2, const int q_split_arg,
    const int k_split_arg, const int v_split_arg) {
  int s_id = blockIdx.x, b_id = blockIdx.y;
  int cur_seqlens = s;
  int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v;
  int total_d = q_split_arg + k_split_arg + v_split_arg;
  if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) {
    offset_block = s_id * b * h * total_d + b_id * h * total_d;
    offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg;
    offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg;
    offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg;
  } else {
    offset_block = b_id * s * h * total_d + s_id * h * total_d;
    offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg;
    offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg;
    offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg;
  }
  int q_limit = q_split_arg;
  int k_limit = q_limit + k_split_arg;
  int s_id_for_freqs;
  if (cp_size > 1) {
    assert(cur_seqlens % 2 == 0);
    if (s_id < cur_seqlens / 2) {
      s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
    } else {
      s_id_for_freqs =
          cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
    }
  } else {
    s_id_for_freqs = s_id;
  }
  fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs,
                                offset_block, offset_block_dst_q, h, d, d2, 0, total_d,
                                q_split_arg);
  fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs,
                                offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d,
                                k_split_arg);
  fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs,
                                offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d,
                                v_split_arg);
}
template <typename scalar_t>
void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
                                 const int *start_positions, scalar_t *output,
@@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(THREADS_PER_WARP, warps_per_block);
  const int shared_mem_size = 2 * d2 * sizeof(float);  // cos, sin
  int o_stride_s_or_t, o_stride_b;
  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
    NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
@@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs,
  const int o_stride_h = d;
  const int o_stride_d = 1;
  fused_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
      input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2,
      stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
      o_stride_d);
@@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
  int warps_per_block = h < 16 ? 4 : 8;
  dim3 blocks(s, b);
  dim3 threads(THREADS_PER_WARP, warps_per_block);
  const int shared_mem_size = 2 * d2 * sizeof(float);  // cos, sin
  int o_stride_s_or_t, o_stride_b;
  if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
    NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format");
@@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
  const int o_stride_h = d;
  const int o_stride_d = 1;
  fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
      output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
      stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
      o_stride_d);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
template <typename scalar_t>
void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs,
                                     const float *k_freqs, const int *start_positions,
                                     scalar_t *q_out, scalar_t *k_out, scalar_t *v_out,
                                     const NVTE_QKV_Format qkv_format, const bool interleaved,
                                     const int cp_size, const int cp_rank, const int s, const int b,
                                     const int h, const int d, const int d2,
                                     const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
                                     const int qkv_split_arg_list_2, cudaStream_t stream) {
  const int THREADS_PER_WARP = 32;
  int warps_per_block = (h <= 8) ? h : 8;
  dim3 blocks(s, b);
  dim3 threads(THREADS_PER_WARP, warps_per_block);
  const int shared_mem_size = 4 * d2 * sizeof(float);  // cos and sin for q and k

  fused_qkv_rope_forward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
      qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved,
      cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
      qkv_split_arg_list_2);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

template <typename scalar_t>
void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out,
                                      const scalar_t *v_grad_out, const float *q_freqs,
                                      const float *k_freqs, scalar_t *qkv_grad_input,
                                      const NVTE_QKV_Format qkv_format, const bool interleaved,
                                      const int cp_size, const int cp_rank, const int s,
                                      const int b, const int h, const int d, const int d2,
                                      const int qkv_split_arg_list_0,
                                      const int qkv_split_arg_list_1,
                                      const int qkv_split_arg_list_2, cudaStream_t stream) {
  const int THREADS_PER_WARP = 32;
  const int warps_per_block = (h <= 8) ? h : 8;
  dim3 blocks(s, b);
  dim3 threads(THREADS_PER_WARP, warps_per_block);
  const int shared_mem_size = 4 * d2 * sizeof(float);  // cos and sin for q and k

  fused_qkv_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
      q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved,
      cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
      qkv_split_arg_list_2);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
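Launch shape for the QKV launchers: one block per (sequence position, batch) pair, 32 lanes across the rotary width, and up to 8 warps across heads, with dynamic shared memory sized for cos/sin tables of both Q and K. The sizing arithmetic, restated in Python for clarity (illustrative):

def qkv_rope_launch_config(s: int, b: int, h: int, d2: int):
    """Mirror the launcher's grid, block, and shared-memory choices."""
    threads_per_warp = 32
    warps_per_block = h if h <= 8 else 8
    grid = (s, b)  # dim3 blocks(s, b)
    block = (threads_per_warp, warps_per_block)
    shared_mem_bytes = 4 * d2 * 4  # cos+sin for Q and for K, float32
    return grid, block, shared_mem_bytes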
void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs,
                        const Tensor &start_positions, Tensor *output,
                        const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
@@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
                            stride_b, stride_h, stride_d, stream););
}
void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs,
                            const Tensor &start_positions, Tensor *q_out, Tensor *k_out,
                            Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved,
                            const int cp_size, const int cp_rank, const int s, const int b,
                            const int h, const int d, const int d2, const int qkv_split_arg_list_0,
                            const int qkv_split_arg_list_1, const int qkv_split_arg_list_2,
                            cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      qkv_input.data.dtype, scalar_t,
      fused_qkv_rope_forward_launcher(reinterpret_cast<const scalar_t *>(qkv_input.data.dptr),
                                      reinterpret_cast<const float *>(q_freqs.data.dptr),
                                      reinterpret_cast<const float *>(k_freqs.data.dptr),
                                      reinterpret_cast<const int *>(start_positions.data.dptr),
                                      reinterpret_cast<scalar_t *>(q_out->data.dptr),
                                      reinterpret_cast<scalar_t *>(k_out->data.dptr),
                                      reinterpret_cast<scalar_t *>(v_out->data.dptr), qkv_format,
                                      interleaved, cp_size, cp_rank, s, b, h, d, d2,
                                      qkv_split_arg_list_0, qkv_split_arg_list_1,
                                      qkv_split_arg_list_2, stream););
}

void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out,
                             const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs,
                             Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format,
                             const bool interleaved, const int cp_size, const int cp_rank,
                             const int s, const int b, const int h, const int d, const int d2,
                             const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
                             const int qkv_split_arg_list_2, cudaStream_t stream) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      q_grad_out.data.dtype, scalar_t,
      fused_qkv_rope_backward_launcher(reinterpret_cast<const scalar_t *>(q_grad_out.data.dptr),
                                       reinterpret_cast<const scalar_t *>(k_grad_out.data.dptr),
                                       reinterpret_cast<const scalar_t *>(v_grad_out.data.dptr),
                                       reinterpret_cast<const float *>(q_freqs.data.dptr),
                                       reinterpret_cast<const float *>(k_freqs.data.dptr),
                                       reinterpret_cast<scalar_t *>(qkv_grad_input->data.dptr),
                                       qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
                                       qkv_split_arg_list_0, qkv_split_arg_list_1,
                                       qkv_split_arg_list_2, stream););
}
}  // end namespace transformer_engine

void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens,
@@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
                            qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
                            stride_b, stride_h, stride_d, stream);
}
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
                                 const NVTETensor k_freqs, const NVTETensor start_positions,
                                 NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
                                 const NVTE_QKV_Format qkv_format, const bool interleaved,
                                 const int cp_size, const int cp_rank, const int s, const int b,
                                 const int h, const int d, const int d2,
                                 const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
                                 const int qkv_split_arg_list_2, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_qkv_rope_forward);
  using namespace transformer_engine;
  fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs),
                         *convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions),
                         convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out),
                         convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank,
                         s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1,
                         qkv_split_arg_list_2, stream);
}

void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
                                  const NVTETensor v_grad_out, const NVTETensor q_freqs,
                                  const NVTETensor k_freqs, NVTETensor qkv_grad_input,
                                  const NVTE_QKV_Format qkv_format, const bool interleaved,
                                  const int cp_size, const int cp_rank, const int s, const int b,
                                  const int h, const int d, const int d2,
                                  const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
                                  const int qkv_split_arg_list_2, cudaStream_t stream) {
  NVTE_API_CALL(nvte_fused_qkv_rope_backward);
  using namespace transformer_engine;
  fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out),
                          *convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs),
                          *convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input),
                          qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2,
                          qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream);
}
@@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
                              const int stride_b, const int stride_h, const int stride_d,
                              cudaStream_t stream);
/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
*
* \param[in] qkv_input Combined QKV input tensor for fused rope.
* \param[in] q_freqs The freqs tensor for Q.
* \param[in] k_freqs The freqs tensor for K.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] q_out Output tensor for Q.
* \param[out] k_out Output tensor for K.
* \param[out] v_out Output tensor for V.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] qkv_split_arg_list_0 The hidden size for Q.
* \param[in] qkv_split_arg_list_1 The hidden size for K.
* \param[in] qkv_split_arg_list_2 The hidden size for V.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
const NVTETensor k_freqs, const NVTETensor start_positions,
NVTETensor q_out, NVTETensor k_out, NVTETensor v_out,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
/*! \brief Compute the backward pass of the fused QKV RoPE.
*
* \param[in] q_grad_out Incoming gradient tensor for Q.
* \param[in] k_grad_out Incoming gradient tensor for K.
* \param[in] v_grad_out Incoming gradient tensor for V.
* \param[in] q_freqs The freqs tensor for Q.
* \param[in] k_freqs The freqs tensor for K.
 *  \param[out]    qkv_grad_input          Computed gradient tensor with respect to the combined QKV input.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
* \param[in] cp_size Context parallel world size.
* \param[in] cp_rank Context parallel rank.
* \param[in] s Length of the s dimension of input.
* \param[in] b Length of the b dimension of input.
* \param[in] h Length of the h dimension of input.
* \param[in] d Length of the d dimension of input.
* \param[in] d2 Length of the d dimension of freqs.
* \param[in] qkv_split_arg_list_0 The hidden size for Q.
* \param[in] qkv_split_arg_list_1 The hidden size for K.
* \param[in] qkv_split_arg_list_2 The hidden size for V.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out,
const NVTETensor v_grad_out, const NVTETensor q_freqs,
const NVTETensor k_freqs, NVTETensor qkv_grad_input,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2,
const int qkv_split_arg_list_0, const int qkv_split_arg_list_1,
const int qkv_split_arg_list_2, cudaStream_t stream);
#ifdef __cplusplus
}  // extern "C"
#endif
@@ -5,14 +5,14 @@
"""
Rotary Position Embedding implementation of different types along with helper functions
"""
from typing import Optional, Tuple, Union, List
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat

__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"]


class RotaryPositionEmbedding(torch.nn.Module):
@@ -170,6 +170,86 @@ class FusedRoPEFunc(torch.autograd.Function):
        return grad_input, None, None, None, None, None, None, None
class FusedQKVRoPEFunc(torch.autograd.Function):
    """
    Function for FusedQKVRoPE

    This implementation accepts a combined QKV tensor in `bshd` or `sbhd` format. The Q and K
    RoPE frequency tensors are the additional required inputs, each of shape (s, 1, 1, d).
    It produces 3 outputs: Q and K with RoPE applied, and V unchanged from the input.
    """

    @staticmethod
    def forward(
        ctx,
        qkv: torch.Tensor,
        q_freqs: torch.Tensor,
        k_freqs: torch.Tensor,
        qkv_split_arg_list: List[int],
        start_positions: Union[torch.Tensor, None] = None,
        tensor_format: str = "sbhd",
        interleaved: bool = False,
        cp_size: int = 1,
        cp_rank: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Fused QKV RoPE forward."""
        if q_freqs.dtype != torch.float32:
            q_freqs = q_freqs.float()
        if k_freqs.dtype != torch.float32:
            k_freqs = k_freqs.float()

        assert tensor_format in (
            "sbhd",
            "bshd",
        ), f"Unsupported tensor_format: {tensor_format}."

        assert qkv.is_contiguous(), "QKV Tensor should be contiguous."
        assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous."
        assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous."

        output = tex.fused_qkv_rope_forward(
            qkv,
            q_freqs,
            k_freqs,
            start_positions,
            qkv_split_arg_list,
            QKVFormat[tensor_format],
            interleaved,
            cp_size,
            cp_rank,
        )

        ctx.save_for_backward(q_freqs, k_freqs)
        ctx.tensor_format = tensor_format
        ctx.qkv_split_arg_list = qkv_split_arg_list
        ctx.cp_size = cp_size
        ctx.cp_rank = cp_rank
        ctx.interleaved = interleaved

        return output

    @staticmethod
    def backward(
        ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Fused QKV RoPE backward."""
        q_freqs, k_freqs = ctx.saved_tensors
        grad_output_q = grad_output_q.contiguous()
        grad_output_k = grad_output_k.contiguous()
        grad_output_v = grad_output_v.contiguous()
        grad_input = tex.fused_qkv_rope_backward(
            grad_output_q,
            grad_output_k,
            grad_output_v,
            q_freqs,
            k_freqs,
            ctx.qkv_split_arg_list,
            QKVFormat[ctx.tensor_format],
            ctx.interleaved,
            ctx.cp_size,
            ctx.cp_rank,
        )

        return grad_input, None, None, None, None, None, None, None, None
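Functionally, this autograd op matches splitting the packed tensor and applying the existing per-tensor RoPE, which is exactly how the new test validates it. A reference sketch (illustrative only, using torch.split and the apply_rotary_pos_emb helper from later in this module; not part of the diff):

def unfused_qkv_rope_reference(qkv, q_freqs, k_freqs, splits, **rope_kwargs):
    """Split packed QKV, RoPE Q and K separately, pass V through."""
    q, k, v = torch.split(qkv, splits, dim=-1)
    q = q.reshape(*q.shape[:2], -1, splits[1])  # unfold the wide Q slice back into heads
    q = apply_rotary_pos_emb(q, q_freqs, fused=True, **rope_kwargs)
    k = apply_rotary_pos_emb(k, k_freqs, fused=True, **rope_kwargs)
    return q, k, v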
def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor:
    """Change sign so the last dimension becomes [-odd, +even]
@@ -393,3 +473,82 @@ def apply_rotary_pos_emb(
        tensor_format,
        interleaved=interleaved,
    )
def apply_fused_qkv_rotary_pos_emb(
    qkv: torch.Tensor,
    q_freqs: torch.Tensor,
    k_freqs: torch.Tensor,
    qkv_split_arg_list: List[int],
    tensor_format: str = "sbhd",
    start_positions: Union[torch.Tensor, None] = None,
    interleaved: bool = False,
    cu_seqlens: Union[torch.Tensor, None] = None,  # pylint: disable=unused-argument
    cp_size: int = 1,
    cp_rank: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Apply rotary positional embedding to the input combined QKV tensor.

    Support matrix:
    Fused:
        Training:
            qkv_formats: "bshd", "sbhd"
            context parallelism: yes
            start_positions: no
            interleaving: yes
        Inference:
            qkv_formats: "bshd", "sbhd"
            context parallelism: no
            start_positions: yes
            interleaving: yes

    Parameters
    ----------
    qkv: torch.Tensor
        Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which
        rotary positional embedding will be applied. This tensor has q, k, v concatenated
        along the last dimension.
    q_freqs: torch.Tensor
        Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    k_freqs: torch.Tensor
        Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
        with `s2 >= s` and `d2 <= d`.
    qkv_split_arg_list: List[int]
        List of three integers that specify how to split the last dimension of the qkv
        tensor: the number of elements belonging to q, k, and v, respectively. Their sum
        must equal the size of the last dimension of the qkv tensor.
    start_positions: torch.Tensor, default = None.
        Tokens in a sequence `i` should be applied with position encoding offset by
        `start_positions[i]`. If `start_positions=None`, there's no offset.
    tensor_format: {'sbhd', 'bshd'}, default = 'sbhd'
        is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is
        of shape `[seq, bs, ...]`.
    interleaved: bool, default = False
        Whether to use interleaved rotary position embedding.
    cp_size: int, default = 1.
        Context parallel world size.
    cp_rank: int, default = 0.
        Context parallel rank.
    """
    # `start_positions` is only supported for `cp_size=1` and inference.
    assert not (
        cp_size > 1 and start_positions is not None
    ), "start_positions != None with cp_size > 1 is not supported!"

    assert tensor_format != "thd", "'thd' tensor_format not supported currently."

    return FusedQKVRoPEFunc.apply(
        qkv,
        q_freqs,
        k_freqs,
        qkv_split_arg_list,
        start_positions,
        tensor_format,
        interleaved,
        cp_size,
        cp_rank,
    )
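Per the support matrix above, start_positions is an inference-only feature and requires cp_size == 1. A sketch of that path (illustrative shapes; the embeddings must be long enough to cover the sequence length plus the largest offset):

rope = RotaryPositionEmbedding(128)
freqs = rope(8 + 10)  # [18, 1, 1, 128]
qkv = torch.rand(2, 8, 64, 6 * 128, device="cuda")  # bshd layout
offsets = torch.randint(0, 10, (2,), dtype=torch.int32, device="cuda")
q, k, v = apply_fused_qkv_rotary_pos_emb(
    qkv, freqs, freqs, [4 * 128, 128, 128],
    tensor_format="bshd", start_positions=offsets,
)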
@@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
                               const int cp_rank);
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank);

at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
                                   const at::Tensor &k_freqs,
                                   const std::vector<int> &qkv_split_arg_list,
                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
                                   const int cp_size, const int cp_rank);
/***************************************************************************************************
 * Miscellaneous
 **************************************************************************************************/
@@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
  return output;
}
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_qkv_rope_forward(
    const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs,
    const std::optional<at::Tensor> start_positions, const std::vector<int> &qkv_split_arg_list,
    const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size,
    const int cp_rank) {
  TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");
  TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1,
              "expected the second and third dims of the freqs tensor equal 1");
  TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float,
              "Dtype of the freqs tensor must be float");

  // output
  auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device());
  auto q_out_size = qkv_input.sizes().vec();
  q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1];
  q_out_size[3] = qkv_split_arg_list[1];
  auto q_out = at::empty(q_out_size, act_options);
  auto k_out_size = qkv_input.sizes().vec();
  k_out_size[3] = qkv_split_arg_list[1];
  auto k_out = at::empty(k_out_size, act_options);
  auto v_out_size = qkv_input.sizes().vec();
  v_out_size[3] = qkv_split_arg_list[2];
  auto v_out = at::empty(v_out_size, act_options);

  auto qkv_cu = makeTransformerEngineTensor(qkv_input);
  auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
  auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
  auto q_out_cu = makeTransformerEngineTensor(q_out);
  auto k_out_cu = makeTransformerEngineTensor(k_out);
  auto v_out_cu = makeTransformerEngineTensor(v_out);

  auto start_positions_cu = TensorWrapper();  // empty start_positions tensor
  if (start_positions) {
    start_positions_cu = makeTransformerEngineTensor(start_positions.value());
  }

  TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor");
  TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous");
  const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
  const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1);
  const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0);
  const int h = qkv_input.size(2);
  const int d = qkv_split_arg_list[2];
  const int d2 = q_freqs.size(3);

  nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(),
                              start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(),
                              v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h,
                              d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1],
                              qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream());

  return std::make_tuple(q_out, k_out, v_out);
}
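Note the grouped-query head arithmetic in the output shapes: with h packed heads of width q_split + k_split + v_split, Q is emitted as h * q_split / k_split heads of width k_split, i.e. each Q head is assumed to have the same width as a K head. Restated in Python (illustrative):

def qkv_out_shapes(qkv_shape, splits):
    """Output shapes produced by fused_qkv_rope_forward for packed sbhd/bshd input."""
    dim0, dim1, h, _ = qkv_shape
    q_split, k_split, v_split = splits
    q_shape = (dim0, dim1, h * q_split // k_split, k_split)  # q_out_size above
    k_shape = (dim0, dim1, h, k_split)
    v_shape = (dim0, dim1, h, v_split)
    return q_shape, k_shape, v_shape

# e.g. (2048, 2, 64, 768) with splits [512, 128, 128] ->
# Q (2048, 2, 256, 128), K (2048, 2, 64, 128), V (2048, 2, 64, 128)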
at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
                               const NVTE_QKV_Format qkv_format, const bool interleaved,
                               const std::optional<at::Tensor> cu_seqlens, const int cp_size,
@@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs,
  return input_grads;
}
at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out,
                                   const at::Tensor &v_grad_out, const at::Tensor &q_freqs,
                                   const at::Tensor &k_freqs,
                                   const std::vector<int> &qkv_split_arg_list,
                                   const NVTE_QKV_Format qkv_format, const bool interleaved,
                                   const int cp_size, const int cp_rank) {
  auto act_options =
      at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device());
  auto qkv_grad_size = q_grad_out.sizes().vec();
  auto total_hd =
      (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3);
  auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2];
  qkv_grad_size[2] = total_hd / total_d;
  qkv_grad_size[3] = total_d;
  auto qkv_grad_input = at::empty(qkv_grad_size, act_options);

  const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD;
  const int s = is_sbhd ? q_grad_out.size(0) : q_grad_out.size(1);
  const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0);
  const int h = qkv_grad_input.size(2);
  const int d = qkv_split_arg_list[2];
  const int d2 = q_freqs.size(3);

  auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out);
  auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out);
  auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out);
  auto q_freqs_cu = makeTransformerEngineTensor(q_freqs);
  auto k_freqs_cu = makeTransformerEngineTensor(k_freqs);
  auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input);

  nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(),
                               q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format,
                               interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0],
                               qkv_split_arg_list[1], qkv_split_arg_list[2],
                               at::cuda::getCurrentCUDAStream());
  return qkv_grad_input;
}
}  // namespace transformer_engine::pytorch
@@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Fused Apply RoPE FWD", py::call_guard<py::gil_scoped_release>());
  m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward,
        "Fused Apply RoPE BWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward,
"Fused Apply QKV RoPE FWD", py::call_guard<py::gil_scoped_release>());
m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward,
"Fused Apply QKV RoPE BWD", py::call_guard<py::gil_scoped_release>());
  // fused router
  m.def("fused_topk_with_score_function_fwd",